Coverage for src/bob/fusion/base/tools/plotting.py: 93%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-14 22:15 +0100

1#!/usr/bin/env python 

2 

3import numpy as np 

4 

5from numpy.random import default_rng 

6 

7from bob.learn.em import KMeansMachine 

8 

9 

def grouping(scores, gformat="random", npoints=500, seed=None, **kwargs):
    """Reduce ``scores`` to at most ``npoints`` representative samples.

    Parameters
    ----------
    scores : array_like
        Samples to reduce. Converted with :func:`numpy.asarray`. Both
        strategies select along the first axis.
    gformat : str
        Reduction strategy:

        * ``"kmeans"`` fits a ``KMeansMachine`` with ``n_clusters=npoints``
          and returns its ``means``.
        * ``"random"`` draws ``npoints`` samples without replacement.
        * Any other value returns ``scores`` unchanged (as an ndarray).
    npoints : int
        Number of samples (or clusters) to keep.
    seed : int or None
        Seed passed to :func:`numpy.random.default_rng`. It is only used by
        the ``"random"`` strategy.
    **kwargs
        Accepted for call-site compatibility and ignored.

    Returns
    -------
    numpy.ndarray
        The reduced samples. If ``scores`` is empty, or already holds fewer
        than ``npoints`` samples, it is returned unchanged.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        return scores

    if gformat not in ("kmeans", "random"):
        # Unknown strategy: pass the samples through untouched.
        return scores

    # Nothing to reduce. Without this guard, sampling without replacement
    # raises ValueError, and k-means would be asked for more clusters than
    # there are samples.
    if len(scores) < npoints:
        return scores

    if gformat == "kmeans":
        kmeans_machine = KMeansMachine(
            n_clusters=npoints, convergence_threshold=0.1, max_iter=500
        )
        kmeans_machine.fit(scores)
        return kmeans_machine.means

    rng = default_rng(seed)
    return rng.choice(scores, npoints, replace=False)