Coverage for src/deepdraw/script/common.py: 60% (87 statements)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""The main entry for deepdraw (click-based) scripts."""

import logging
import os
import random
import re
import sys
import tempfile
import time
import urllib.request

import click
import numpy
import pkg_resources
import torch

from clapper.click import AliasedGroup
from click_plugins import with_plugins
from tqdm import tqdm

logger = logging.getLogger(__name__)
def setup_pytorch_device(name):
    """Sets up the pytorch device to use.

    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set.

    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)
    """
    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # let pytorch auto-select from the environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None

    # cuda or cpu
    return torch.device(name)
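
# A minimal usage sketch (hypothetical caller, not part of this module): pick
# a specific GPU for the process, falling back to the CPU when CUDA is
# unavailable.
#
#     try:
#         device = setup_pytorch_device("cuda:1")
#     except RuntimeError:  # raised when CUDA is not available
#         device = setup_pytorch_device("cpu")
#     model = torch.nn.Linear(10, 2).to(device)
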
def set_seeds(value, all_gpus):
    """Sets up all relevant random seeds (numpy, python, cuda).

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.

    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU
    """
    random.seed(value)
    numpy.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)  # noop if cuda not available

    # set seeds for all gpus
    if all_gpus:
        torch.cuda.manual_seed_all(value)  # noop if cuda not available
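
# A minimal sketch (the seed value is illustrative): call this once, before
# building datasets and models, so repeated runs draw the same random numbers.
#
#     set_seeds(42, all_gpus=False)  # single-GPU or CPU run
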
def set_reproducible_cuda():
    """Turns off all CUDA optimizations that would affect reproducibility.

    For full reproducibility, also ensure not to use multiple (parallel) data
    loaders.  That is, set ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
    """
    # ensure to use only optimization algos for cuda that are known to have
    # a deterministic effect (not random)
    torch.backends.cudnn.deterministic = True

    # turns off any optimization tricks
    torch.backends.cudnn.benchmark = False
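
# A typical reproducible-run preamble (a sketch only; ``dataset`` is a
# hypothetical placeholder, and ``num_workers=0`` disables parallel data
# loading as the docstring above recommends):
#
#     set_seeds(42, all_gpus=False)
#     set_reproducible_cuda()
#     loader = torch.utils.data.DataLoader(dataset, num_workers=0)
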
def escape_name(v):
    """Escapes a name so it contains filesystem-friendly characters only.

    This function replaces every character that is not a letter, digit,
    ``_``, ``-``, ``.`` or space with a ``-``.

    Parameters
    ----------

    v : str
        String to be escaped

    Returns
    -------

    s : str
        Escaped string
    """
    return re.sub(r"[^\w\-_\. ]", "-", v)
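
# Example (input chosen for illustration): each disallowed character is
# replaced one-for-one with a ``-``.
#
#     escape_name("model@2023/run#1")  # -> 'model-2023-run-1'
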
def save_sh_command(destfile):
    """Records the command-line needed to reproduce this experiment.

    This function records the current command-line used to call the script
    being run.  It creates an executable ``bash`` script setting up the
    current working directory and activating a conda environment, if needed.
    It records further information on the date and time the script was run
    and the version of the package.

    Parameters
    ----------

    destfile : str
        Path leading to the file where the commands to reproduce the current
        run will be recorded.  This file cannot be overwritten by this
        function.  If needed, you should check and remove an existing file
        **before** calling this function.
    """
    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
    os.makedirs(os.path.dirname(destfile), exist_ok=True)

    with open(destfile, "w") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        version = pkg_resources.require("deepdraw")[0].version
        f.write(f"# version: {version} (deepdraw)\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")
        # quote arguments containing spaces so the recorded line re-parses
        args = []
        for k in sys.argv:
            if " " in k:
                args.append(f'"{k}"')
            else:
                args.append(k)
        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
        f.write(" ".join(args) + "\n")

    os.chmod(destfile, 0o755)
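
# A minimal sketch (the path is hypothetical): record the reproduction script
# alongside the experiment outputs before any real work starts.
#
#     save_sh_command("results/experiment-1/command.sh")
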
def download_to_tempfile(url, progress=False):
    """Downloads a file to a temporary named file and returns it.

    Parameters
    ----------

    url : str
        The URL pointing to the file to download

    progress : :py:class:`bool`, Optional
        If a progress bar should be displayed for downloading the URL.

    Returns
    -------

    f : tempfile.NamedTemporaryFile
        A named temporary file that contains the downloaded URL
    """
    file_size = 0
    response = urllib.request.urlopen(url)
    meta = response.info()
    # figure out the payload size, coping with both legacy and current
    # header-object interfaces
    if hasattr(meta, "getheaders"):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")

    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # only show a progress bar if we know the total size
    progress &= bool(file_size)

    f = tempfile.NamedTemporaryFile()

    with tqdm(total=file_size, disable=not progress) as pbar:
        while True:
            buffer = response.read(8192)
            if len(buffer) == 0:
                break
            f.write(buffer)
            pbar.update(len(buffer))

    f.flush()
    f.seek(0)

    return f
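
# A minimal sketch (the URL is hypothetical): the returned handle is rewound
# to the start of the payload, and the temporary file is deleted when the
# handle is closed.
#
#     f = download_to_tempfile("https://example.com/weights.pth", progress=True)
#     data = f.read()
#     f.close()  # removes the temporary file
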
@with_plugins(pkg_resources.iter_entry_points("deepdraw.cli"))
@click.group(cls=AliasedGroup)
def deepdraw():
    """Binary 2D Image Segmentation Benchmark commands."""
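
# Subcommands are discovered through the ``deepdraw.cli`` entry-point group
# loaded above.  A plugin package could register one roughly like this
# (hypothetical project metadata, for illustration only):
#
#     # pyproject.toml
#     [project.entry-points."deepdraw.cli"]
#     train = "my_plugin.cli:train"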