Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/utils/model_zoo.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

57 statements  

1#!/usr/bin/env python 

2# vim: set fileencoding=utf-8 : 

3 

4# Adapted from: 

5# https://github.com/pytorch/pytorch/blob/master/torch/hub.py 

6# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py 

7 

8import hashlib 

9import os 

10import re 

11import shutil 

12import sys 

13import tempfile 

14from urllib.request import urlopen 

15from urllib.parse import urlparse 

16from tqdm import tqdm 

17 

# Backbones published in the torchvision model zoo. Their file names carry a
# ``-<hex>.`` suffix that ``cache_url`` uses as the expected SHA256 prefix.
_TORCHVISION_URLS = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}

# Backbones mirrored on the Idiap server. Upstream sources (same file names):
#   resnet50_SIN_IN: https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models
#   mobilenetv2:     https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar
# NOTE(review): these are served over plain http, and ``mobilenet_v2.pth.tar``
# has no hash suffix, so its download is not integrity-checked. Consider https.
_IDIAP_MIRROR_URLS = {
    "resnet50_SIN_IN": "http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
    "mobilenetv2": "http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/mobilenet_v2.pth.tar",
}

# Merge order keeps the torchvision entries first, followed by the mirrors.
modelurls = {**_TORCHVISION_URLS, **_IDIAP_MIRROR_URLS}
"""URLs of pre-trained models (backbones)"""

38 

39 

def download_url_to_file(url, dst, hash_prefix, progress):
    """Download ``url`` into the file ``dst``, optionally checking its SHA256.

    The response body is streamed in 8 KiB chunks into a temporary file that
    is created in the same directory as ``dst``. Only after the download has
    finished, and the hash check has passed when one was requested, is the
    temporary file moved to ``dst``. The temporary file is always removed on
    failure.

    Args:
        url (str): address to download; anything accepted by
            :func:`urllib.request.urlopen`.
        dst (str): final path of the downloaded file. Its parent directory
            must already exist.
        hash_prefix (str or None): if not ``None``, the leading hexadecimal
            characters the SHA256 digest of the downloaded content must
            start with.
        progress (bool): whether to display a tqdm progress bar.

    Raises:
        RuntimeError: if ``hash_prefix`` is given and does not match the
            SHA256 digest of the downloaded content.
    """
    # Keep the temporary file next to ``dst`` so the final move is a rename on
    # the same filesystem rather than a copy that could leave a partial file.
    dst_dir = os.path.dirname(os.path.abspath(dst))
    sha256 = hashlib.sha256() if hash_prefix is not None else None

    with urlopen(url) as u:
        # ``get_all`` returns None when the header is missing (e.g. chunked
        # responses); the progress bar then runs without a known total.
        content_length = u.info().get_all("Content-Length")
        file_size = int(content_length[0]) if content_length else None

        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if not digest.startswith(hash_prefix):
                    raise RuntimeError(
                        'invalid hash value (expected "{}", got "{}")'.format(
                            hash_prefix, digest
                        )
                    )
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless; the file is gone here if it was moved.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)

79 

80 

# Extracts the hexadecimal hash prefix embedded in a file name that follows the
# ``name-<hexdigest>.ext`` convention (e.g. ``vgg11-bbd30ac9.pth``). At least
# one hex digit is required: an empty capture (``name-.ext``) is not a hash and
# would otherwise "verify" any content.
HASH_REGEX = re.compile(r"-([a-f0-9]+)\.")

82 

83 

def cache_url(url, model_dir=None, progress=True):
    r"""Downloads the file at the given URL into a local cache directory.

    If a file with the same name is already present in ``model_dir``, it is
    reused and nothing is downloaded. The filename part of the URL may follow
    the naming convention ``filename-<sha256>.ext``, where ``<sha256>`` is
    conventionally the first eight or more hex digits of the SHA256 hash of the
    file contents. When such a suffix is present, it is used to verify the
    downloaded contents.

    The default value of ``model_dir`` is ``$TORCH_HOME/models``, where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar
            to stderr

    Returns:
        str: path of the (possibly freshly downloaded) file in ``model_dir``.

    Raises:
        ValueError: if the URL path has no file name component.
        RuntimeError: if the downloaded content does not match the hash
            encoded in its file name.
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
    # ``exist_ok`` avoids a race when several processes populate the cache.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    if not filename:
        # Without a file name the "cached file" would be model_dir itself.
        raise ValueError('cannot derive a file name from URL "{}"'.format(url))

    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match is not None else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return cached_file