Coverage for bob/ip/common/script/common.py: 63% of 90 statements (coverage.py v7.0.5, created at 2023-01-17 15:03 +0000)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

"""The main entry for bob ip binseg (click-based) scripts."""

import logging
import os
import random
import re
import sys
import tempfile
import time
import urllib.request

import click
import numpy
import pkg_resources
import torch

from click_plugins import with_plugins
from tqdm import tqdm

from bob.extension.scripts.click_helper import AliasedGroup

logger = logging.getLogger(__name__)


def setup_pytorch_device(name):
    """Sets up the pytorch device to use

    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure
        it is currently set.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)

    """

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):  # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None

    # cuda or cpu
    return torch.device(name)
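
# Example (illustrative sketch; the device names below are hypothetical and
# depend on your hardware):
#
#     device = setup_pytorch_device("cuda:0")  # pin to the first GPU
#     device = setup_pytorch_device("cpu")     # or run on the CPU
#     model = model.to(device)
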

def set_seeds(value, all_gpus):
    """Sets up all relevant random seeds (numpy, python, cuda)

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU

    """

    random.seed(value)
    numpy.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)  # noop if cuda not available

    # set seeds for all gpus
    if all_gpus:
        torch.cuda.manual_seed_all(value)  # noop if cuda not available
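
# Example (illustrative sketch; the seed value is arbitrary):
#
#     set_seeds(42, all_gpus=False)  # single-GPU or CPU-only runs
#     set_seeds(42, all_gpus=True)   # multi-GPU runs
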

def set_reproducible_cuda():
    """Turns off all CUDA optimizations that would affect reproducibility

    For full reproducibility, also ensure not to use multiple (parallel) data
    loaders.  That is, set ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    """

    # ensure to use only optimization algos for cuda that are known to have
    # a deterministic effect (not random)
    torch.backends.cudnn.deterministic = True

    # turns off any optimization tricks
    torch.backends.cudnn.benchmark = False
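
# Example of a fully reproducible setup (illustrative sketch; ``dataset`` and
# the seed value are hypothetical):
#
#     set_seeds(42, all_gpus=False)
#     set_reproducible_cuda()
#     loader = torch.utils.data.DataLoader(dataset, num_workers=0)
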

def escape_name(v):
    """Escapes a name so it contains filesystem friendly characters only

    This function escapes every character that's not a letter, digit, ``_``,
    ``-``, ``.`` or space with a ``-``.


    Parameters
    ----------

    v : str
        String to be escaped


    Returns
    -------

    s : str
        Escaped string

    """
    return re.sub(r"[^\w\-_\. ]", "-", v)
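
# Example (illustrative; the file name is made up):
#
#     escape_name("model output (v2).pth")  # -> 'model output -v2-.pth'
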

def save_sh_command(destfile):
    """Records command-line to reproduce this experiment

    This function can record the current command-line used to call the script
    being run.  It creates an executable ``bash`` script setting up the
    current working directory and activating a conda environment, if needed.
    It records further information on the date and time the script was run
    and the version of the package.


    Parameters
    ----------

    destfile : str
        Path leading to the file where the commands to reproduce the current
        run will be recorded.  This file cannot be overwritten by this
        function.  If needed, you should check and remove an existing file
        **before** calling this function.

    """

    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
    os.makedirs(os.path.dirname(destfile), exist_ok=True)

    with open(destfile, "wt") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        version = pkg_resources.require("bob.ip.binseg")[0].version
        f.write(f"# version: {version} (bob.ip.binseg)\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")
        args = []
        for k in sys.argv:
            if " " in k:
                args.append(f'"{k}"')
            else:
                args.append(k)
        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
        f.write(" ".join(args) + "\n")
    os.chmod(destfile, 0o755)
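
# Example (illustrative sketch; the output path is made up -- the parent
# directory is created if needed, and an existing file is never overwritten):
#
#     save_sh_command("results/experiment-1/command.sh")
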

def download_to_tempfile(url, progress=False):
    """Downloads a file to a temporary named file and returns it


    Parameters
    ----------

    url : str
        The URL pointing to the file to download

    progress : :py:class:`bool`, Optional
        If a progress bar should be displayed for downloading the URL.


    Returns
    -------

    f : tempfile.NamedTemporaryFile
        A named temporary file that contains the downloaded URL contents

    """

    file_size = 0
    response = urllib.request.urlopen(url)
    meta = response.info()
    if hasattr(meta, "getheaders"):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")

    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    progress &= bool(file_size)

    f = tempfile.NamedTemporaryFile()

    with tqdm(total=file_size, disable=not progress) as pbar:
        while True:
            buffer = response.read(8192)
            if len(buffer) == 0:
                break
            f.write(buffer)
            pbar.update(len(buffer))

    f.flush()
    f.seek(0)
    return f
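
# Example (illustrative sketch; the URL is hypothetical):
#
#     f = download_to_tempfile("https://example.com/weights.pth", progress=True)
#     data = f.read()  # the file is deleted once ``f`` is garbage-collected
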

@with_plugins(pkg_resources.iter_entry_points("bob.ip.binseg.cli"))
@click.group(cls=AliasedGroup)
def binseg():
    """Binary 2D Image Segmentation Benchmark commands."""


@with_plugins(pkg_resources.iter_entry_points("bob.ip.detect.cli"))
@click.group(cls=AliasedGroup)
def detect():
    """Object Detection Benchmark commands."""