Coverage for src/deepdraw/script/common.py: 60%

87 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5"""The main entry for deepdraw (click-based) scripts.""" 

6 

7import logging 

8import os 

9import random 

10import re 

11import sys 

12import tempfile 

13import time 

14import urllib.request 

15 

16import click 

17import numpy 

18import pkg_resources 

19import torch 

20 

21from clapper.click import AliasedGroup 

22from click_plugins import with_plugins 

23from tqdm import tqdm 

24 

25logger = logging.getLogger(__name__) 

26 

27 

def setup_pytorch_device(name):
    """Sets-up the pytorch device to use.

    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)


    Raises
    ------

    RuntimeError
        If a specific CUDA device was requested but CUDA is not available,
        or if a bare ``cuda`` device was requested while
        ``CUDA_VISIBLE_DEVICES`` is not set.
    """

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        # NOTE(review): setting CUDA_VISIBLE_DEVICES only has an effect if
        # CUDA has not yet been initialized in this process - confirm callers
        # invoke this before any CUDA work.
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        # BUGFIX: the original used ``assert`` here, which is silently
        # stripped under ``python -O``.  Raise explicitly instead, consistent
        # with the error handling of the branch above.
        if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
            raise RuntimeError(
                f"You set device to '{name}', but the environment variable "
                f"'CUDA_VISIBLE_DEVICES' is not set"
            )

    # cuda or cpu
    return torch.device(name)

66 

67 

def set_seeds(value, all_gpus):
    """Resets all relevant random number generator seeds (python, numpy,
    torch and cuda).

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU
    """

    # seed every generator that may influence training; the cuda calls are
    # no-ops when cuda is not available
    seeders = [
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ]
    if all_gpus:
        # also reseed every visible GPU (noop if cuda not available)
        seeders.append(torch.cuda.manual_seed_all)

    for seed in seeders:
        seed(value)

97 

98 

def set_reproducible_cuda():
    """Turns-off all CUDA optimizations that would affect reproducibility.

    For full reproducibility, also ensure not to use multiple (parallel) data
    loaders.  That is, setup ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
    """

    cudnn = torch.backends.cudnn

    # restrict cuDNN to algorithms known to produce deterministic results
    cudnn.deterministic = True

    # disable cuDNN auto-tuning, whose algorithm selection is not
    # reproducible across runs
    cudnn.benchmark = False

115 

116 

def escape_name(v):
    """Escapes a name so it contains filesystem friendly characters only.

    Every character that is not a word character (letter, digit or ``_``),
    ``-``, ``.`` or a space is replaced by a single ``-``.


    Parameters
    ==========

    v : str
        String to be escaped


    Returns
    =======

    s : str
        Escaped string
    """
    unsafe = re.compile(r"[^\w\-_\. ]")
    return unsafe.sub("-", v)

138 

139 

def save_sh_command(destfile):
    """Records command-line to reproduce this experiment.

    This function can record the current command-line used to call the script
    being run.  It creates an executable ``bash`` script setting up the
    current working directory and activating a conda environment, if needed.
    It records further information on the date and time the script was run
    and the version of the package.


    Parameters
    ----------

    destfile : str
        Path leading to the file where the commands to reproduce the current
        run will be recorded.  This file cannot be overwritten by this
        function.  If needed, you should check and remove an existing file
        **before** calling this function.
    """

    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")

    # BUGFIX: ``os.path.dirname()`` returns "" for a bare filename, and
    # ``os.makedirs("")`` raises FileNotFoundError - only create directories
    # when there actually is a directory component
    dirname = os.path.dirname(destfile)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    with open(destfile, "w") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        # best-effort: this file is a convenience record - do not crash the
        # whole experiment if the package metadata cannot be resolved
        try:
            version = pkg_resources.require("deepdraw")[0].version
        except Exception:
            version = "unknown"
        f.write(f"# version: {version} (deepdraw)\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")
        # naive shell quoting: wrap arguments containing blanks in
        # double-quotes (arguments containing quotes are not escaped)
        args = [f'"{k}"' if " " in k else k for k in sys.argv]
        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
        f.write(" ".join(args) + "\n")
    # make the record directly executable
    os.chmod(destfile, 0o755)

185 

186 

def download_to_tempfile(url, progress=False):
    """Downloads a file to a temporary named file and returns it.

    Parameters
    ----------

    url : str
        The URL pointing to the file to download

    progress : :py:class:`bool`, Optional
        If a progress bar should be displayed for downloading the URL.


    Returns
    -------

    f : tempfile.NamedTemporaryFile
        A named temporary file that contains the downloaded URL
    """

    response = urllib.request.urlopen(url)
    try:
        meta = response.info()
        # support both the legacy (Python 2) header API and the
        # email.message.Message API of Python 3
        if hasattr(meta, "getheaders"):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")

        file_size = 0
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # without a known total size, a progress bar is meaningless
        progress &= bool(file_size)

        f = tempfile.NamedTemporaryFile()

        with tqdm(total=file_size, disable=not progress) as pbar:
            while True:
                buffer = response.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                pbar.update(len(buffer))

        # rewind so callers can read the downloaded contents immediately
        f.flush()
        f.seek(0)
        return f
    finally:
        # BUGFIX: always release the network connection (the original code
        # never closed the response, leaking the socket)
        response.close()

233 

234 

# Main entry-point for the ``deepdraw`` command-line application: a click
# group whose sub-commands are discovered from the ``deepdraw.cli``
# setuptools entry-point (via click-plugins).
# NOTE(review): ``AliasedGroup`` presumably allows abbreviated/aliased
# sub-command names - confirm against the clapper documentation.
@with_plugins(pkg_resources.iter_entry_points("deepdraw.cli"))
@click.group(cls=AliasedGroup)
def deepdraw():
    """Binary 2D Image Segmentation Benchmark commands."""