Skip to content

Commit 7b35e88

Browse files
mdeffnperraud
authored andcommitted
simplify test_nngraph (PR #21)
1 parent b01fd31 commit 7b35e88

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

pygsp/tests/test_graphs.py

+21-36
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ def test_subgraph(self, n_vertices=100):
491491
self.assertEqual(graph.plotting, self._G.plotting)
492492

493493
def test_nngraph(self, n_vertices=30):
494+
"""Test all the combinations of metric, kind, backend."""
494495
features = np.random.RandomState(42).normal(size=(n_vertices, 3))
495496
metrics = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
496497
backends = ['scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'nmslib',
@@ -499,46 +500,30 @@ def test_nngraph(self, n_vertices=30):
499500

500501
for backend in backends:
501502
for metric in metrics:
502-
if ((backend == 'flann' and metric == 'max_dist') or
503-
(backend == 'nmslib' and metric == 'minkowski')):
504-
self.assertRaises(ValueError, graphs.NNGraph, features,
505-
kind='knn', backend=backend,
506-
metric=metric)
507-
self.assertRaises(ValueError, graphs.NNGraph, features,
508-
kind='radius', backend=backend,
509-
metric=metric)
510-
else:
511-
if backend == 'nmslib':
512-
self.assertRaises(ValueError, graphs.NNGraph, features,
513-
kind='radius', backend=backend,
514-
metric=metric, order=order)
503+
for kind in ['knn', 'radius']:
504+
params = dict(features=features, metric=metric,
505+
order=order, kind=kind, backend=backend)
506+
# Unsupported combinations.
507+
if backend == 'flann' and metric == 'max_dist':
508+
self.assertRaises(ValueError, graphs.NNGraph, **params)
509+
elif backend == 'nmslib' and metric == 'minkowski':
510+
self.assertRaises(ValueError, graphs.NNGraph, **params)
511+
elif backend == 'nmslib' and kind == 'radius':
512+
self.assertRaises(ValueError, graphs.NNGraph, **params)
515513
else:
516-
graphs.NNGraph(features, kind='radius',
517-
backend=backend,
518-
metric=metric, order=order)
519-
graphs.NNGraph(features, kind='knn',
520-
backend=backend,
521-
metric=metric, order=order)
522-
graphs.NNGraph(features, kind='knn',
523-
backend=backend,
524-
metric=metric, order=order,
525-
center=False)
526-
graphs.NNGraph(features, kind='knn',
527-
backend=backend,
528-
metric=metric, order=order,
529-
rescale=False)
530-
graphs.NNGraph(features, kind='knn',
531-
backend=backend,
532-
metric=metric, order=order,
533-
rescale=False, center=False)
514+
graphs.NNGraph(**params, center=False)
515+
graphs.NNGraph(**params, rescale=False)
516+
graphs.NNGraph(**params, center=False, rescale=False)
517+
518+
# Invalid parameters.
519+
self.assertRaises(ValueError, graphs.NNGraph, features,
520+
metric='invalid')
534521
self.assertRaises(ValueError, graphs.NNGraph, features,
535-
kind='invalid', backend=backend,
536-
metric=metric)
522+
kind='invalid')
537523
self.assertRaises(ValueError, graphs.NNGraph, features,
538-
kind='knn', backend='invalid',
539-
metric=metric)
524+
backend='invalid')
540525
self.assertRaises(ValueError, graphs.NNGraph, features,
541-
kind='knn', k=n_vertices+1)
526+
kind='knn', k=n_vertices+1)
542527

543528
def test_nngraph_consistency(self):
544529
features = np.arange(90).reshape(30, 3)

0 commit comments

Comments
 (0)