1#!/usr/bin/env python
2# vim: set fileencoding=utf-8 :
3
# Adapted from:
5# https://github.com/pytorch/pytorch/blob/master/torch/hub.py
6# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py
7
8import hashlib
9import os
10import re
11import shutil
12import sys
13import tempfile
14from urllib.request import urlopen
15from urllib.parse import urlparse
16from tqdm import tqdm
17
# Pre-trained backbone weights, keyed by model name.
#
# The VGG / ResNet entries point at the official PyTorch download server and
# carry a ``-<hex>.`` tag in their filename (used for hash verification).
# The last two entries are served from an Idiap mirror; their upstream
# locations were:
#   resnet50_SIN_IN: https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar
#   mobilenetv2:     https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar
# The mobilenetv2 filename has no ``-<hex>.`` tag, so it is not hash-checked.
# NOTE(review): the Idiap mirror URLs use plain http; consider https if the
# server supports it.
modelurls = dict(
    vgg11="https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    vgg13="https://download.pytorch.org/models/vgg13-c768596a.pth",
    vgg16="https://download.pytorch.org/models/vgg16-397923af.pth",
    vgg19="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    vgg11_bn="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    vgg13_bn="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    vgg16_bn="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    vgg19_bn="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    resnet50_SIN_IN="http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
    mobilenetv2="http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/mobilenet_v2.pth.tar",
)
"""URLs of pre-trained models (backbones)"""
38
39
def download_url_to_file(url, dst, hash_prefix, progress):
    """Download the object at ``url`` and store it at the local path ``dst``.

    The response is streamed into a temporary file created in the same
    directory as ``dst``. Only once the transfer (and the optional hash check)
    has succeeded is that file moved to ``dst``. Because both paths are in
    the same directory, the final move is a rename. An interrupted or
    rejected download therefore never leaves a partial file at ``dst``.

    Args:
        url (str): URL of the object to download
        dst (str): path where the downloaded object is saved
        hash_prefix (str or None): if not ``None``, the SHA256 hex digest of
            the downloaded content must start with this string
        progress (bool): whether or not to display a progress bar

    Raises:
        RuntimeError: if ``hash_prefix`` is given and does not match the
            SHA256 digest of the downloaded content
    """
    # The context manager guarantees the connection is closed on every path.
    with urlopen(url) as u:
        # Content-Length may be missing; the progress bar then has no total.
        content_length = u.info().get_all("Content-Length")
        file_size = int(content_length[0]) if content_length else None

        sha256 = hashlib.sha256() if hash_prefix is not None else None

        # Create the temporary file next to ``dst`` (not in the system temp
        # dir) so that ``shutil.move`` below is a same-filesystem rename
        # rather than a copy that could be interrupted half-way.
        f = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(dst))
        try:
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if not digest.startswith(hash_prefix):
                    raise RuntimeError(
                        'invalid hash value (expected "{}", got "{}")'.format(
                            hash_prefix, digest
                        )
                    )
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless. The temporary file only still exists
            # here if something above failed before the move.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
79
80
# Extracts the SHA256 prefix embedded in a filename of the form
# ``name-<hexdigits>.ext`` (e.g. ``vgg11-bbd30ac9.pth`` -> ``bbd30ac9``).
# At least one hex digit is required. An empty ``-.`` segment is not a hash
# and must not be treated as a zero-length prefix that trivially matches.
HASH_REGEX = re.compile(r"-([a-f0-9]+)\.")
82
83
def cache_url(url, model_dir=None, progress=True):
    r"""Download the file at ``url`` into a local cache and return its path.

    If a file with the same name is already present in ``model_dir``, nothing
    is downloaded and its path is returned directly. Its contents are not
    re-verified. Otherwise the file is fetched with ``download_url_to_file``.

    If the filename part of the URL follows the naming convention
    ``filename-<sha256>.ext``, where ``<sha256>`` is a prefix (typically the
    first eight or more hex digits) of the SHA256 hash of the file contents,
    the downloaded data is checked against that prefix.

    The default value of ``model_dir`` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar
            to stderr

    Returns:
        str: path to the locally cached file

    Raises:
        RuntimeError: if a freshly downloaded file does not match the hash
            prefix embedded in its name
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
        # Expand ``~`` in $TORCH_MODEL_ZOO too; otherwise a literal "~"
        # directory would be created relative to the working directory.
        model_dir = os.path.expanduser(
            os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
        )
    # exist_ok avoids failing if the directory is created concurrently.
    os.makedirs(model_dir, exist_ok=True)

    # Only the path component names the file; query strings are ignored.
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match is not None else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return cached_file