#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

"""The main entry for bob ip binseg (click-based) scripts."""
import logging
import os
import random
import re
import shlex
import sys
import tempfile
import time
import urllib.request

import click
import numpy
import pkg_resources
import torch

from click_plugins import with_plugins
from tqdm import tqdm

from bob.extension.scripts.click_helper import AliasedGroup
logger = logging.getLogger(__name__)
def setup_pytorch_device(name):
    """Sets-up the pytorch device to use

    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)


    Raises
    ------

    RuntimeError
        If a specific CUDA device was requested but CUDA is not available, or
        if ``cuda`` (default device) was requested without
        ``CUDA_VISIBLE_DEVICES`` being set.

    """

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        # was an ``assert`` before: asserts are stripped when running under
        # ``python -O``, so use an explicit check that always executes
        if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
            raise RuntimeError(
                f"You set device to '{name}', but the CUDA_VISIBLE_DEVICES "
                f"environment variable is not set"
            )

    # cuda or cpu
    return torch.device(name)
def set_seeds(value, all_gpus):
    """Sets up all relevant random seeds (numpy, python, cuda)

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU

    """

    # seed every generator we depend upon: python, numpy and pytorch (cpu)
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(value)

    # current cuda device (this is a no-op if cuda is unavailable)
    torch.cuda.manual_seed(value)

    # optionally, every visible cuda device (also a no-op without cuda)
    if all_gpus:
        torch.cuda.manual_seed_all(value)
def set_reproducible_cuda():
    """Turns-off all CUDA optimizations that would affect reproducibility

    For full reproducibility, also ensure not to use multiple (parallel) data
    loaders.  That is, setup ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.

    """

    cudnn = torch.backends.cudnn

    # restrict cudnn to algorithms known to have a deterministic effect
    # (some of the fastest convolution implementations are not)
    cudnn.deterministic = True

    # disable the auto-tuner that benchmarks several algorithms and picks
    # the fastest - its choice may vary between runs
    cudnn.benchmark = False
def escape_name(v):
    """Escapes a name so it contains filesystem friendly characters only

    Every character that is not a word character (letter, digit or ``_``),
    ``-``, ``.`` or space is replaced by an ``-``.


    Parameters
    ==========

    v : str
        String to be escaped


    Returns
    =======

    s : str
        Escaped string

    """
    unsafe = re.compile(r"[^\w\-_\. ]")
    return unsafe.sub("-", v)
def save_sh_command(destfile):
    """Records command-line to reproduce this experiment

    This function can record the current command-line used to call the script
    being run.  It creates an executable ``bash`` script setting up the
    current working directory and activating a conda environment, if needed.
    It records further information on the date and time the script was run
    and the version of the package.


    Parameters
    ----------

    destfile : str
        Path leading to the file where the commands to reproduce the current
        run will be recorded.  This file cannot be overwritten by this
        function.  If needed, you should check and remove an existing file
        **before** calling this function.

    """

    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")

    # only create intermediate directories if the target path has any:
    # ``os.makedirs("")`` raises FileNotFoundError for bare filenames
    dirname = os.path.dirname(destfile)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    with open(destfile, "wt") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        version = pkg_resources.require("bob.ip.binseg")[0].version
        f.write(f"# version: {version} (bob.ip.binseg)\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")
        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
        # shlex.quote escapes spaces *and* embedded quotes/shell
        # metacharacters; the previous naive double-quoting of
        # space-containing arguments produced broken shell lines for
        # arguments containing ``"`` themselves
        f.write(" ".join(shlex.quote(k) for k in sys.argv) + "\n")
    os.chmod(destfile, 0o755)
def download_to_tempfile(url, progress=False):
    """Downloads a file to a temporary named file and returns it

    Parameters
    ----------

    url : str
        The URL pointing to the file to download

    progress : :py:class:`bool`, Optional
        If a progress bar should be displayed for downloading the URL.


    Returns
    -------

    f : tempfile.NamedTemporaryFile
        A named temporary file that contains the downloaded URL

    """

    file_size = 0
    # ``with`` guarantees the HTTP response/connection is closed even if
    # reading fails midway (it was previously never closed explicitly)
    with urllib.request.urlopen(url) as response:
        meta = response.info()
        # compatibility shim: ``getheaders`` only exists on legacy header
        # objects; modern ``email.message.Message`` provides ``get_all``
        if hasattr(meta, "getheaders"):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")

        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # a progress bar is meaningless without a known total size
        progress &= bool(file_size)

        f = tempfile.NamedTemporaryFile()

        with tqdm(total=file_size, disable=not progress) as pbar:
            while True:
                buffer = response.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                pbar.update(len(buffer))

    f.flush()
    f.seek(0)
    return f
# Main command-line entry point: a click group whose sub-commands are
# discovered at runtime from packages registering the ``bob.ip.binseg.cli``
# setuptools entry-point group.  NOTE(review): ``AliasedGroup`` presumably
# allows abbreviated/aliased sub-command names - confirm against
# bob.extension's click helpers.
@with_plugins(pkg_resources.iter_entry_points("bob.ip.binseg.cli"))
@click.group(cls=AliasedGroup)
def binseg():
    """Binary 2D Image Segmentation Benchmark commands."""