1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3
4"""The main entry for bob ip binseg (click-based) scripts."""
5
6import logging
7import os
8import random
9import re
10import sys
11import tempfile
12import time
13import urllib.request
14
15import click
16import numpy
17import pkg_resources
18import torch
19
20from click_plugins import with_plugins
21from tqdm import tqdm
22
23from bob.extension.scripts.click_helper import AliasedGroup
24
25logger = logging.getLogger(__name__)
26
27
def setup_pytorch_device(name):
    """Sets-up the pytorch device to use


    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set, by restricting the environment variable
        ``CUDA_VISIBLE_DEVICES`` to it.  The plain ``cuda`` name selects the
        default CUDA device and requires ``CUDA_VISIBLE_DEVICES`` to be set
        already in the environment.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)


    Raises
    ------

    RuntimeError
        If a CUDA device is requested but CUDA is not available, or if plain
        ``cuda`` is requested while ``CUDA_VISIBLE_DEVICES`` is not set.

    """

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        # NOTE(review): presumably only effective if CUDA has not been
        # initialised yet in this process - confirm at call sites.
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    if name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        # explicit check instead of ``assert``: asserts vanish under ``-O``
        if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
            raise RuntimeError(
                f"You set device to '{name}', but the environment variable "
                f"CUDA_VISIBLE_DEVICES is not set"
            )
        # fail early, consistently with the ``cuda:N`` branch above
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )

    # cuda or cpu
    return torch.device(name)
68
69
def set_seeds(value, all_gpus=False):
    """Sets up all relevant random seeds (python, numpy, torch and cuda)

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU.  Defaults
        to ``False``.

    """

    random.seed(value)
    numpy.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)  # noop if cuda not available

    # set seeds for all gpus
    if all_gpus:
        torch.cuda.manual_seed_all(value)  # noop if cuda not available
100
101
def set_reproducible_cuda():
    """Turns-off the cuDNN optimizations that would affect reproducibility

    Restricts cuDNN to algorithms with a deterministic outcome and disables
    its benchmark (auto-tuning) mode.  For full reproducibility, also make
    sure not to use multiple (parallel) data loaders, that is, set
    ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.

    """

    cudnn = torch.backends.cudnn

    # only allow algorithms known to have a deterministic (not random) effect
    cudnn.deterministic = True

    # benchmark mode may select a different (fastest) algorithm on each run,
    # which breaks run-to-run reproducibility - keep it off
    cudnn.benchmark = False
120
121
def escape_name(v):
    """Escapes a name so it contains filesystem friendly characters only

    Every character that is not a word character (a letter, a digit or
    ``_``, Unicode-aware), a ``-``, a ``.`` or a space is replaced by ``-``.


    Parameters
    ----------

    v : str
        String to be escaped


    Returns
    -------

    s : str
        Escaped string, of the same length as ``v``

    """
    # negated class: anything outside {word chars, '.', ' ', '-'} is unsafe
    unsafe_chars = re.compile(r"[^\w. -]")
    return unsafe_chars.sub("-", v)
144
145
def save_sh_command(destfile):
    """Records command-line to reproduce this experiment

    This function can record the current command-line used to call the script
    being run.  It creates an executable shell script that records the date
    and time the script was run, the version of the package and the platform,
    followed by (commented-out) instructions to activate the current conda
    environment (if any) and to change into the current working directory,
    and finally the shell-quoted command-line itself.


    Parameters
    ----------

    destfile : str
        Path leading to the file where the commands to reproduce the current
        run will be recorded.  This file cannot be overwritten by this
        function: if it already exists, it is left untouched.  If needed, you
        should check and remove an existing file **before** calling this
        function.  Missing parent directories are created.

    """
    import shlex

    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")

    # resolved before opening the file, so a failure cannot leave a partial
    # script behind
    version = pkg_resources.require("bob.ip.binseg")[0].version

    dirname = os.path.dirname(destfile)
    if dirname:  # os.makedirs("") raises for bare file names
        os.makedirs(dirname, exist_ok=True)

    lines = [
        "#!/usr/bin/env sh",
        f"# date: {time.asctime()}",
        f"# version: {version} (bob.ip.binseg)",
        f"# platform: {sys.platform}",
        "",
    ]
    conda_env = os.environ.get("CONDA_DEFAULT_ENV")
    if conda_env is not None:
        lines.append(f"#conda activate {conda_env}")
    lines.append(f"#cd {os.path.realpath(os.curdir)}")
    # shlex.quote makes each argument safe for a POSIX shell (spaces, quotes,
    # ``$``, backticks, empty strings, ...)
    lines.append(" ".join(shlex.quote(arg) for arg in sys.argv))

    with open(destfile, "wt") as f:
        f.write("\n".join(lines) + "\n")

    os.chmod(destfile, 0o755)
192
193
def download_to_tempfile(url, progress=False):
    """Downloads a file to a temporary named file and returns it

    Parameters
    ----------

    url : str
        The URL pointing to the file to download

    progress : :py:class:`bool`, Optional
        If a progress bar should be displayed for downloading the URL.  The
        bar is only shown if the server reports a valid ``Content-Length``.


    Returns
    -------

    f : tempfile.NamedTemporaryFile
        A named temporary file that contains the downloaded URL, positioned
        at its start.  It is created with the default ``delete=True``, so
        closing it removes it from disk.

    """

    with urllib.request.urlopen(url) as response:
        # the size may be absent (e.g. chunked transfer) or malformed: in
        # both cases treat it as unknown (0)
        file_size = 0
        content_length = response.info().get_all("Content-Length")
        if content_length:
            try:
                file_size = int(content_length[0])
            except ValueError:
                file_size = 0

        # a progress bar is only meaningful if the total size is known
        progress = progress and file_size > 0

        f = tempfile.NamedTemporaryFile()
        try:
            with tqdm(total=file_size, disable=not progress) as pbar:
                while True:
                    buffer = response.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    pbar.update(len(buffer))
            f.flush()
            f.seek(0)
        except BaseException:
            # do not leak a half-written temporary file on failure
            f.close()
            raise

    return f
241
242
# Top-level ``binseg`` click command group.  No sub-commands are defined here:
# they are loaded (via click_plugins) from every entry point registered under
# the ``bob.ip.binseg.cli`` group.  ``AliasedGroup`` comes from bob.extension;
# presumably it adds command-name aliasing - confirm against its docs.
# The function docstring below is used by click as the group's help text.
@with_plugins(pkg_resources.iter_entry_points("bob.ip.binseg.cli"))
@click.group(cls=AliasedGroup)
def binseg():
    """Binary 2D Image Segmentation Benchmark commands."""
247
248
# Top-level ``detect`` click command group.  No sub-commands are defined here:
# they are loaded (via click_plugins) from every entry point registered under
# the ``bob.ip.detect.cli`` group.  ``AliasedGroup`` comes from bob.extension;
# presumably it adds command-name aliasing - confirm against its docs.
# The function docstring below is used by click as the group's help text.
@with_plugins(pkg_resources.iter_entry_points("bob.ip.detect.cli"))
@click.group(cls=AliasedGroup)
def detect():
    """Object Detection Benchmark commands."""