@@ -491,6 +491,7 @@ def test_subgraph(self, n_vertices=100):
491
491
self .assertEqual (graph .plotting , self ._G .plotting )
492
492
493
493
def test_nngraph (self , n_vertices = 30 ):
494
+ """Test all the combinations of metric, kind, backend."""
494
495
features = np .random .RandomState (42 ).normal (size = (n_vertices , 3 ))
495
496
metrics = ['euclidean' , 'manhattan' , 'max_dist' , 'minkowski' ]
496
497
backends = ['scipy-kdtree' , 'scipy-ckdtree' , 'scipy-pdist' , 'nmslib' ,
@@ -499,46 +500,30 @@ def test_nngraph(self, n_vertices=30):
499
500
500
501
for backend in backends :
501
502
for metric in metrics :
502
- if ((backend == 'flann' and metric == 'max_dist' ) or
503
- (backend == 'nmslib' and metric == 'minkowski' )):
504
- self .assertRaises (ValueError , graphs .NNGraph , features ,
505
- kind = 'knn' , backend = backend ,
506
- metric = metric )
507
- self .assertRaises (ValueError , graphs .NNGraph , features ,
508
- kind = 'radius' , backend = backend ,
509
- metric = metric )
510
- else :
511
- if backend == 'nmslib' :
512
- self .assertRaises (ValueError , graphs .NNGraph , features ,
513
- kind = 'radius' , backend = backend ,
514
- metric = metric , order = order )
503
+ for kind in ['knn' , 'radius' ]:
504
+ params = dict (features = features , metric = metric ,
505
+ order = order , kind = kind , backend = backend )
506
+ # Unsupported combinations.
507
+ if backend == 'flann' and metric == 'max_dist' :
508
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
509
+ elif backend == 'nmslib' and metric == 'minkowski' :
510
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
511
+ elif backend == 'nmslib' and kind == 'radius' :
512
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
515
513
else :
516
- graphs .NNGraph (features , kind = 'radius' ,
517
- backend = backend ,
518
- metric = metric , order = order )
519
- graphs .NNGraph (features , kind = 'knn' ,
520
- backend = backend ,
521
- metric = metric , order = order )
522
- graphs .NNGraph (features , kind = 'knn' ,
523
- backend = backend ,
524
- metric = metric , order = order ,
525
- center = False )
526
- graphs .NNGraph (features , kind = 'knn' ,
527
- backend = backend ,
528
- metric = metric , order = order ,
529
- rescale = False )
530
- graphs .NNGraph (features , kind = 'knn' ,
531
- backend = backend ,
532
- metric = metric , order = order ,
533
- rescale = False , center = False )
514
+ graphs .NNGraph (** params , center = False )
515
+ graphs .NNGraph (** params , rescale = False )
516
+ graphs .NNGraph (** params , center = False , rescale = False )
517
+
518
+ # Invalid parameters.
519
+ self .assertRaises (ValueError , graphs .NNGraph , features ,
520
+ metric = 'invalid' )
534
521
self .assertRaises (ValueError , graphs .NNGraph , features ,
535
- kind = 'invalid' , backend = backend ,
536
- metric = metric )
522
+ kind = 'invalid' )
537
523
self .assertRaises (ValueError , graphs .NNGraph , features ,
538
- kind = 'knn' , backend = 'invalid' ,
539
- metric = metric )
524
+ backend = 'invalid' )
540
525
self .assertRaises (ValueError , graphs .NNGraph , features ,
541
- kind = 'knn' , k = n_vertices + 1 )
526
+ kind = 'knn' , k = n_vertices + 1 )
542
527
543
528
def test_nngraph_consistency (self ):
544
529
features = np .arange (90 ).reshape (30 , 3 )
0 commit comments