Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/scripts/predict.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

77 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import os 

5import shutil 

6import tempfile 

7import copy 

8 

9import click 

10import torch 

11import numpy as np 

12from sklearn import metrics 

13from torch.utils.data import DataLoader, ConcatDataset 

14from matplotlib.backends.backend_pdf import PdfPages 

15 

16from bob.extension.scripts.click_helper import ( 

17 verbosity_option, 

18 ConfigCommand, 

19 ResourceOption, 

20) 

21 

22from ..engine.predictor import run 

23from ..utils.checkpointer import Checkpointer 

24 

25from .tb import download_to_tempfile 

26from ..utils.plot import relevance_analysis_plot 

27 

28import logging 

29logger = logging.getLogger(__name__) 

30 

31 

def _save_figure_as_pdf(figure, filepath):
    """Writes ``figure`` as a single page of a new PDF file at ``filepath``

    The parent directory of ``filepath`` is created if it does not exist.
    The PDF is handled through a context manager so the file is closed even
    if saving the page raises.

    ``figure`` is whatever ``relevance_analysis_plot`` returned; it is passed
    unchanged to ``PdfPages.savefig``.
    """
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with PdfPages(filepath) as pdf:
        pdf.savefig(figure)


@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:
\b
       $ bob tb predict -vv pasa montgomery --weight=path/to/model_final.pth --output-folder=path/to/predictions

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE: ``--dataset`` used to declare the short flag ``-d`` as well, which
# clashed with ``--device`` below.  click keeps the option registered last
# for a duplicated short flag, so ``-d`` already resolved to ``--device``;
# the dead alias is dropped here so the help output matches what the parser
# really does.
@click.option(
    "--dataset",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--relevance-analysis",
    "-r",
    help="If set, generate relevance analysis pdfs to indicate the relative "
    "importance of each feature",
    is_flag=True,
    cls=ResourceOption,
)
@click.option(
    "--grad-cams",
    "-g",
    help="If set, generate grad cams for each prediction (must use batch of 1)",
    is_flag=True,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def predict(output_folder, model, dataset, batch_size, device, weight,
        relevance_analysis, grad_cams, **kwargs):
    """Predicts Tuberculosis presence (probabilities) on input images"""
    # The docstring above is kept to one line on purpose: click displays it
    # as the command help text.
    #
    # Steps:
    #   1. load the weights from ``weight`` (local path or http(s) URL) into
    #      ``model``;
    #   2. for a model named "logistic_regression", plot its first parameter
    #      tensor to ``<output_folder>/LogReg_Weights.pdf``;
    #   3. run the predictor on every dataset whose key does not start with
    #      an underscore;
    #   4. if ``relevance_analysis`` is set, estimate the importance of each
    #      input feature by permuting it across samples and measuring the
    #      MSE against the unpermuted predictions, saving one plot per
    #      dataset to ``<output_folder>/<key>_RA.pdf``.

    # A single dataset is treated as a one-entry dictionary named "test"
    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        # Keep the handle under its own name: it is presumably what keeps
        # the temporary file on disk until the weights are loaded (TODO
        # confirm against ``download_to_tempfile``).
        weight_file = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(weight_file.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath, strict=False)

    # Logistic regressor weights
    if model.name == "logistic_regression":
        logger.info("Logistic regression identified: saving model weights")
        # Only the first parameter tensor is plotted, flattened to 1-D
        first_param = next(iter(model.parameters()))
        model_weights = np.array(first_param.data.reshape(-1))
        filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
        logger.info(f"Creating and saving weights plot at {filepath}...")
        _save_figure_as_pdf(
            relevance_analysis_plot(
                model_weights,
                title="LogReg model weights",
            ),
            filepath,
        )

    for k, v in dataset.items():

        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        data_loader = DataLoader(
            dataset=v,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
        )
        predictions = run(model, data_loader, k, device, output_folder, grad_cams)

        # Relevance analysis using permutation feature importance
        if not relevance_analysis:
            continue

        if isinstance(v, ConcatDataset) or not isinstance(
            v._samples[0].data["data"], list
        ):
            logger.info("Relevance analysis only possible with radiological signs as input. Cancelling...")
            continue

        nb_features = len(v._samples[0].data["data"])

        if nb_features == 1:
            logger.info("Relevance analysis not possible with one feature")
        else:
            logger.info(f"Starting relevance analysis for subset '{k}'...")

            # Second column of the reference predictions; it does not change
            # between features, so it is extracted once.
            reference = np.array(predictions)[:, 1]
            temp_folder = output_folder + "_temp"

            all_mse = []
            try:
                for feature_index in range(nb_features):

                    # Randomly permute feature values from all samples.  The
                    # permutation is applied to a deep copy so the dataset
                    # handed in by the caller is never modified.
                    permuted = copy.deepcopy(v)
                    permuted.random_permute(feature_index)

                    permuted_loader = DataLoader(
                        dataset=permuted,
                        batch_size=batch_size,
                        shuffle=False,
                        pin_memory=torch.cuda.is_available(),
                    )

                    permuted_predictions = run(
                        model,
                        permuted_loader,
                        k,
                        device,
                        temp_folder,
                    )

                    # Compute MSE between original and new predictions
                    all_mse.append(
                        metrics.mean_squared_error(
                            reference,
                            np.array(permuted_predictions)[:, 1],
                        )
                    )
            finally:
                # Remove temporary folder, even if a prediction run failed
                shutil.rmtree(temp_folder, ignore_errors=True)

            filepath = os.path.join(output_folder, k + "_RA.pdf")
            logger.info(f"Creating and saving plot at {filepath}...")
            _save_figure_as_pdf(
                relevance_analysis_plot(
                    all_mse,
                    title=k.capitalize() + " set relevance analysis",
                ),
                filepath,
            )