Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import logging 

5import multiprocessing 

6import os 

7import sys 

8 

9import click 

10import torch 

11 

12from torch.utils.data import DataLoader 

13 

14from bob.extension.scripts.click_helper import ( 

15 ConfigCommand, 

16 ResourceOption, 

17 verbosity_option, 

18) 

19 

20from ..engine.predictor import run 

21from ..utils.checkpointer import Checkpointer 

22from .binseg import download_to_tempfile, setup_pytorch_device 

23 

24logger = logging.getLogger(__name__) 

25 

26 

@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:
\b
       $ bob binseg predict -vv m2unet drive --weight=path/to/model_final.pth --output-folder=path/to/predictions
\b
    2. To run prediction on a folder with your own images, you must first
       specify resizing, cropping, etc, so that the image can be correctly
       input to the model. Failing to do so will likely result in poor
       performance. To figure out such specifications, you must consult the
       dataset configuration used for **training** the provided model. Once
       you figured this out, do the following:
\b
       $ bob binseg config copy csv-dataset-example mydataset.py
       # modify "mydataset.py" to include the base path and required transforms
       $ bob binseg predict -vv m2unet mydataset.py --weight=path/to/model_final.pth --output-folder=path/to/predictions
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    # NOTE: short flag changed from "-d" to "-x" -- "-d" was already taken by
    # --dataset above, which made the short option ambiguous/shadowed.
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps on "
    "top of input images (store results as PNG files). If not set, or empty "
    "then do **NOT** output overlayed images. Otherwise, the parameter "
    "represents the name of a folder where to store those",
    show_default=True,
    default=None,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--multiproc-data-loading",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as processing cores as available in the system. Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def predict(
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    multiproc_data_loading,
    **kwargs,
):
    """Predicts vessel map (probabilities) on input images.

    Loads the model weights (from a local path or a URL), then runs inference
    on every non-underscore-prefixed dataset entry, writing probability maps
    (and, optionally, overlayed images) to ``output_folder``.
    """

    device = setup_pytorch_device(device)

    # Normalize to a dict so single datasets and dataset dictionaries are
    # handled uniformly below.
    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        # Remote weights: fetch into a temporary file first; keep a reference
        # to ``f`` alive so the temporary file is not deleted prematurely.
        logger.info(f"Temporarily downloading '{weight}'...")
        f = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(f.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath)

    # clean-up the overlayed path (empty string means "no overlays" downstream)
    if overlayed is not None:
        overlayed = overlayed.strip()

    # Data-loading parallelism is the same for every dataset split, so it is
    # computed once, outside the loop below:
    #   < 0: single-process loading; 0: one worker per CPU core; > 0: as given
    multiproc_kwargs = dict()
    if multiproc_data_loading < 0:
        multiproc_kwargs["num_workers"] = 0
    elif multiproc_data_loading == 0:
        multiproc_kwargs["num_workers"] = multiprocessing.cpu_count()
    else:
        multiproc_kwargs["num_workers"] = multiproc_data_loading

    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        # macOS defaults to "spawn" semantics anyway; being explicit avoids
        # fork-safety issues with multiprocessing data loaders.
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    for k, v in dataset.items():

        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        data_loader = DataLoader(
            dataset=v,
            batch_size=batch_size,
            shuffle=False,
            # pin host memory only when a CUDA device could consume it
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )
        run(model, data_loader, k, device, output_folder, overlayed)