Coverage for src/bob/learn/em/utils.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-06-16 14:34 +0200

1import logging 

2 

3import dask 

4import dask.array as da 

5import numpy as np 

6 

# Logger for this module; its name follows the module's import path.
logger = logging.getLogger(name=__name__)

8 

9 

def check_and_persist_dask_input(data, persist=True):
    """Normalize ``data`` and report whether it is a dask array.

    If ``data`` is a :class:`dask.array.Array`, it is optionally persisted,
    and, when a dask distributed client is currently available, the data held
    by the cluster is rebalanced across its workers. Any other input is
    converted with :func:`numpy.asarray`.

    Parameters
    ----------
    data : dask.array.Array or array_like
        Input data.
    persist : bool, optional
        If True (the default) and ``data`` is a dask array, replace it with
        ``data.persist()``.

    Returns
    -------
    input_is_dask : bool
        True if ``data`` is a dask array.
    data : dask.array.Array or numpy.ndarray
        The (possibly persisted) dask array, or ``np.asarray(data)`` for any
        other input.
    """
    if not isinstance(data, da.Array):
        return False, np.asarray(data)

    if persist:
        data = data.persist()

    # Best-effort rebalance when a dask distributed client is running.
    # ``import dask`` alone does not load ``dask.distributed``, so import it
    # explicitly; if the ``distributed`` package is missing, no client can
    # exist and there is nothing to rebalance.
    try:
        from dask.distributed import Client
    except ImportError:
        pass
    else:
        try:
            Client.current().rebalance()
        except ValueError:
            # ``Client.current()`` raises ValueError when no client is found.
            pass
    return True, data

27 

28 

def array_to_delayed_list(data, input_is_dask):
    """Turn a dask array into a flat list of per-chunk delayed objects.

    Parameters
    ----------
    data : dask.array.Array or object
        The data to split. Only accessed when ``input_is_dask`` is True.
    input_is_dask : bool
        Whether ``data`` is a dask array. When False, ``data`` is returned
        unchanged.

    Returns
    -------
    list or object
        ``data.to_delayed()`` flattened (NumPy's default C order) into a plain
        list when ``input_is_dask`` is True; otherwise ``data`` itself.
    """
    if not input_is_dask:
        return data
    # ``to_delayed`` yields an object ndarray laid out like the chunk grid;
    # ravel + tolist turns it into a flat Python list of chunks.
    chunks = data.to_delayed().ravel().tolist()
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    logger.debug("Got %d chunks.", len(chunks))
    return chunks