Coverage for src/bob/learn/em/utils.py: 100%
23 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:34 +0200
« prev ^ index » next coverage.py v7.0.5, created at 2023-06-16 14:34 +0200
1import logging
3import dask
4import dask.array as da
5import numpy as np
# Module-level logger named after this module; used for debug output below.
logger = logging.getLogger(__name__)
def check_and_persist_dask_input(data, persist=True):
    """Normalize ``data`` and report whether it is a dask array.

    If ``data`` is a ``dask.array.Array`` it is optionally persisted, and, when
    a dask distributed client is currently active, that client's data is
    rebalanced (this happens for any dask input, regardless of ``persist``).
    Any other input is converted with ``np.asarray``.

    Parameters
    ----------
    data
        A dask array, or anything accepted by ``np.asarray``.
    persist
        If True, call ``persist()`` on a dask array input.

    Returns
    -------
    input_is_dask : bool
        True if ``data`` was a dask array.
    data
        The (possibly persisted) dask array, or the numpy array built from
        the non-dask input.
    """
    if not isinstance(data, da.Array):
        return False, np.asarray(data)

    if persist:
        data = data.persist()
    _rebalance_current_client()
    return True, data


def _rebalance_current_client():
    """Best-effort: rebalance data on the active dask distributed client, if any."""
    try:
        # ``import dask`` alone does not load the ``dask.distributed``
        # submodule, so attribute access on ``dask.distributed`` can fail with
        # AttributeError; import it explicitly instead.
        from dask.distributed import Client
    except ImportError:
        # The optional ``distributed`` package is not installed, so no client
        # can be running and there is nothing to rebalance.
        return
    try:
        client = Client.current()
        client.rebalance()
    except ValueError:
        # ``Client.current()`` raises ValueError when no client is active.
        pass
def array_to_delayed_list(data, input_is_dask):
    """Return ``data`` as a flat list of delayed chunks when it is a dask array.

    Parameters
    ----------
    data
        The input data; expected to be a dask array when ``input_is_dask`` is
        true.
    input_is_dask
        Flag (typically from ``check_and_persist_dask_input``) saying whether
        ``data`` is a dask array.

    Returns
    -------
    A flat Python list built from ``data.to_delayed()`` when ``input_is_dask``
    is truthy; otherwise ``data`` itself, unchanged.
    """
    if not input_is_dask:
        return data

    delayed_chunks = data.to_delayed().ravel().tolist()
    logger.debug(f"Got {len(delayed_chunks)} chunks.")
    return delayed_chunks