Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/engine/predictor.py: 60%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

109 statements  

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import os 

5import time 

6import datetime 

7import csv 

8import shutil 

9 

10import PIL 

11import numpy 

12from tqdm import tqdm 

13import matplotlib.pyplot as plt 

14from matplotlib.patches import Rectangle 

15from matplotlib.gridspec import GridSpec 

16 

17import torch 

18import torchvision.transforms.functional as VF 

19from torchvision import transforms 

20 

21from ..utils.grad_cams import GradCAM 

22 

23import logging 

24 

25logger = logging.getLogger(__name__) 

26 

27colors = [[(47, 79, 79), "Cardiomegaly"], 

28 [(255, 0, 0), "Emphysema"], 

29 [(0, 128, 0), "Pleural effusion"], 

30 [(0, 0, 128), "Hernia"], 

31 [(255, 84, 0), "Infiltration"], 

32 [(222, 184, 135), "Mass"], 

33 [(0, 255, 0), "Nodule"], 

34 [(0, 191, 255), "Atelectasis"], 

35 [(0, 0, 255), "Pneumothorax"], 

36 [(255, 0, 255), "Pleural thickening"], 

37 [(255, 255, 0), "Pneumonia"], 

38 [(126, 0, 255), "Fibrosis"], 

39 [(255, 20, 147), "Edema"], 

40 [(0, 255, 180), "Consolidation"]] 

41 

42def run(model, data_loader, name, device, output_folder, grad_cams=False): 

43 """ 

44 Runs inference on input data, outputs HDF5 files with predictions 

45 

46 Parameters 

47 --------- 

48 model : :py:class:`torch.nn.Module` 

49 neural network model (e.g. pasa) 

50 

51 data_loader : py:class:`torch.torch.utils.data.DataLoader` 

52 

53 name : str 

54 the local name of this dataset (e.g. ``train``, or ``test``), to be 

55 used when saving measures files. 

56 

57 device : str 

58 device to use ``cpu`` or ``cuda:0`` 

59 

60 output_folder : str 

61 folder where to store output prediction and model 

62 summary 

63 

64 grad_cams : bool 

65 if we export grad cams for every prediction (must be used along 

66 a batch size of 1 with the DensenetRS model) 

67 

68 Returns 

69 ------- 

70 

71 all_predictions : list 

72 All the predictions associated with filename and groundtruth 

73 

74 """ 

75 

76 output_folder = os.path.join(output_folder, name) 

77 

78 logger.info(f"Output folder: {output_folder}") 

79 os.makedirs(output_folder, exist_ok=True) 

80 

81 logger.info(f"Device: {device}") 

82 

83 logfile_name = os.path.join(output_folder, "predictions.csv") 

84 logfile_fields = ( 

85 "filename", 

86 "likelihood", 

87 "ground_truth" 

88 ) 

89 

90 if os.path.exists(logfile_name): 

91 backup = logfile_name + "~" 

92 if os.path.exists(backup): 

93 os.unlink(backup) 

94 shutil.move(logfile_name, backup) 

95 

96 if grad_cams: 

97 grad_folder = os.path.join(output_folder, "cams") 

98 logger.info(f"Grad cams folder: {grad_folder}") 

99 os.makedirs(grad_folder, exist_ok=True) 

100 

101 with open(logfile_name, "a+", newline="") as logfile: 

102 logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields) 

103 

104 logwriter.writeheader() 

105 

106 model.eval() # set evaluation mode 

107 model.to(device) # set/cast parameters to device 

108 

109 # Setup timers 

110 start_total_time = time.time() 

111 times = [] 

112 len_samples = [] 

113 

114 all_predictions = [] 

115 

116 for samples in tqdm( 

117 data_loader, desc="batches", leave=False, disable=None, 

118 ): 

119 

120 names = samples[0] 

121 images = samples[1].to( 

122 device=device, non_blocking=torch.cuda.is_available() 

123 ) 

124 

125 # Gradcams generation 

126 allowed_models = ["DensenetRS"] 

127 if grad_cams and model.name in allowed_models: 

128 gcam = GradCAM(model=model) 

129 probs, ids = gcam.forward(images) 

130 

131 # To store signs overlays 

132 cams_img = dict() 

133 

134 # Top k number of radiological signs for which we generate cams 

135 topk = 1 

136 

137 for i in range(topk): 

138 

139 # Keep only "positive" signs 

140 if probs[:, [i]] > 0.5: 

141 

142 # Grad-CAM 

143 b = ids[:, [i]] 

144 gcam.backward(ids=ids[:, [i]]) 

145 regions = gcam.generate(target_layer="model_ft.features.denseblock4.denselayer16.conv2") 

146 

147 for j in range(len(images)): 

148 

149 current_cam = regions[j, 0].cpu().numpy() 

150 current_cam[current_cam < 0.75] = 0.0 

151 current_cam[current_cam >= 0.75] = 1.0 

152 current_cam = PIL.Image.fromarray(numpy.uint8(current_cam * 255) , 'L') 

153 cams_img[b.item()] = [current_cam, round(probs[:, [i]].item(), 2)] 

154 

155 if len(cams_img) > 0: 

156 

157 # Convert original image tensor into PIL Image 

158 original_image = transforms.ToPILImage(mode='RGB')(images[0]) 

159 

160 for sign_id, label_prob in cams_img.items(): 

161 

162 label = label_prob[0] 

163 

164 # Create the colored overlay for current sign 

165 colored_sign = PIL.ImageOps.colorize( 

166 label.convert("L"), (0, 0, 0), colors[sign_id][0] 

167 ) 

168 

169 # blend image and label together - first blend to get signs drawn with a 

170 # slight "label_color" tone on top, then composite with original image, to 

171 # avoid loosing brightness. 

172 retval = PIL.Image.blend(original_image, colored_sign, 0.5) 

173 composite_mask = PIL.ImageOps.invert(label.convert("L")) 

174 original_image = PIL.Image.composite(original_image, retval, composite_mask) 

175 

176 handles = [] 

177 labels = [] 

178 for i, v in enumerate(colors): 

179 # If sign present on image 

180 if cams_img.get(i) is not None: 

181 handles.append(Rectangle((0,0),1,1, color = tuple((v/255 for v in v[0])))) 

182 labels.append(v[1] + " (" + str(cams_img[i][1]) + ")") 

183 

184 gs = GridSpec(6,1) 

185 fig = plt.figure(figsize = (10,11)) 

186 ax1 = fig.add_subplot(gs[:-1,:]) # For the plot 

187 ax2 = fig.add_subplot(gs[-1,:]) # For the legend 

188 

189 ax1.imshow(original_image) 

190 ax1.axis('off') 

191 ax2.legend(handles,labels, mode='expand', ncol=3, frameon=False) 

192 ax2.axis('off') 

193 

194 original_filename = samples[0][0].split('/')[-1].split('.')[0] 

195 cam_filename = os.path.join(grad_folder, original_filename + "_cam.png") 

196 fig.savefig(cam_filename) 

197 

198 with torch.no_grad(): 

199 

200 start_time = time.perf_counter() 

201 outputs = model(images) 

202 probabilities = torch.sigmoid(outputs) 

203 

204 # necessary check for HED architecture that uses several outputs 

205 # for loss calculation instead of just the last concatfuse block 

206 if isinstance(outputs, list): 

207 outputs = outputs[-1] 

208 

209 # predictions = sigmoid(outputs) 

210 

211 batch_time = time.perf_counter() - start_time 

212 times.append(batch_time) 

213 len_samples.append(len(images)) 

214 

215 logdata = ( 

216 ("filename", f"{names[0]}"), 

217 ("likelihood", f"{torch.flatten(probabilities).data.cpu().numpy()}"), 

218 ("ground_truth", f"{torch.flatten(samples[2]).data.cpu().numpy()}"), 

219 ) 

220 

221 logwriter.writerow(dict(k for k in logdata)) 

222 logfile.flush() 

223 tqdm.write(" | ".join([f"{k}: {v}" for (k, v) in logdata[:4]])) 

224 

225 # Keep prediction for relevance analysis 

226 all_predictions.append([ 

227 names[0], 

228 torch.flatten(probabilities).data.cpu().numpy(), 

229 torch.flatten(samples[2]).data.cpu().numpy() 

230 ]) 

231 

232 # report operational summary 

233 total_time = datetime.timedelta(seconds=int(time.time() - start_total_time)) 

234 logger.info(f"Total time: {total_time}") 

235 

236 average_batch_time = numpy.mean(times) 

237 logger.info(f"Average batch time: {average_batch_time:g}s") 

238 

239 average_image_time = numpy.sum(numpy.array(times) * len_samples) / float( 

240 sum(len_samples) 

241 ) 

242 logger.info(f"Average image time: {average_image_time:g}s") 

243 

244 return all_predictions