1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import os
5import time
6import datetime
7import csv
8import shutil
9
10import PIL
11import numpy
12from tqdm import tqdm
13import matplotlib.pyplot as plt
14from matplotlib.patches import Rectangle
15from matplotlib.gridspec import GridSpec
16
17import torch
18import torchvision.transforms.functional as VF
19from torchvision import transforms
20
21from ..utils.grad_cams import GradCAM
22
23import logging
24
logger = logging.getLogger(__name__)

# Overlay colour and display name of each radiological sign.  The position of
# an entry in this list is the sign id used to index it (``colors[sign_id]``);
# each entry is ``[(r, g, b), name]`` with 8-bit colour channels.
colors = [
    [rgb, sign]
    for rgb, sign in (
        ((47, 79, 79), "Cardiomegaly"),
        ((255, 0, 0), "Emphysema"),
        ((0, 128, 0), "Pleural effusion"),
        ((0, 0, 128), "Hernia"),
        ((255, 84, 0), "Infiltration"),
        ((222, 184, 135), "Mass"),
        ((0, 255, 0), "Nodule"),
        ((0, 191, 255), "Atelectasis"),
        ((0, 0, 255), "Pneumothorax"),
        ((255, 0, 255), "Pleural thickening"),
        ((255, 255, 0), "Pneumonia"),
        ((126, 0, 255), "Fibrosis"),
        ((255, 20, 147), "Edema"),
        ((0, 255, 180), "Consolidation"),
    )
]
41
# Values of ``model.name`` for which Grad-CAM export is attempted.
_GRAD_CAM_MODELS = ("DensenetRS",)

# Layer name handed to ``GradCAM.generate`` to build the activation maps.
_GRAD_CAM_TARGET_LAYER = "model_ft.features.denseblock4.denselayer16.conv2"


def _export_grad_cam(model, images, sample_name, grad_folder):
    """
    Saves a Grad-CAM overlay figure for the top predicted sign of one sample

    The overlay is only produced when the top-ranked probability returned by
    ``GradCAM.forward`` is above 0.5; otherwise nothing is written.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        model to explain

    images : :py:class:`torch.Tensor`
        input batch.  The boolean test on ``probs`` and the ``.item()`` calls
        below only work on single-element tensors, so the batch must contain
        exactly one image.

    sample_name : str
        name of the sample; the text between the last ``/`` and the first
        following ``.`` is used to build the output filename

    grad_folder : str
        existing folder where ``<name>_cam.png`` is written

    """

    # ``import PIL`` alone does not load submodules; import the ones used here
    # explicitly instead of relying on another package having loaded them.
    import PIL.Image
    import PIL.ImageOps

    # NOTE(review): a new GradCAM wrapper is built for every batch, as in the
    # previous implementation - confirm whether it can be created once.
    gcam = GradCAM(model=model)
    probs, ids = gcam.forward(images)

    # Maps sign id -> [binary mask (PIL "L" image), rounded probability]
    cams_img = {}

    # Top k number of radiological signs for which we generate cams
    topk = 1

    for rank in range(topk):

        # Keep only "positive" signs
        if probs[:, [rank]] > 0.5:

            sign_ids = ids[:, [rank]]
            gcam.backward(ids=sign_ids)
            regions = gcam.generate(target_layer=_GRAD_CAM_TARGET_LAYER)

            for j in range(len(images)):
                cam = regions[j, 0].cpu().numpy()
                # Binarize the map: 255 where activation >= 0.75, else 0
                mask = numpy.where(cam >= 0.75, 255, 0).astype(numpy.uint8)
                cams_img[sign_ids.item()] = [
                    PIL.Image.fromarray(mask, "L"),
                    round(probs[:, [rank]].item(), 2),
                ]

    if not cams_img:
        return

    # Convert original image tensor into PIL Image
    overlay = transforms.ToPILImage(mode="RGB")(images[0])

    for sign_id, (mask_image, _) in cams_img.items():
        gray = mask_image.convert("L")

        # Create the colored overlay for current sign
        colored_sign = PIL.ImageOps.colorize(
            gray, (0, 0, 0), colors[sign_id][0]
        )

        # Blend image and mask together - first blend to get signs drawn with
        # a slight "label_color" tone on top, then composite with the original
        # image (through the inverted mask) to avoid losing brightness.
        blended = PIL.Image.blend(overlay, colored_sign, 0.5)
        overlay = PIL.Image.composite(
            overlay, blended, PIL.ImageOps.invert(gray)
        )

    # Legend entries, only for the signs drawn on the image
    handles = []
    labels = []
    for sign_id, (rgb, sign_name) in enumerate(colors):
        if sign_id in cams_img:
            handles.append(
                Rectangle((0, 0), 1, 1, color=tuple(c / 255 for c in rgb))
            )
            labels.append(f"{sign_name} ({cams_img[sign_id][1]})")

    gs = GridSpec(6, 1)
    fig = plt.figure(figsize=(10, 11))
    try:
        ax_image = fig.add_subplot(gs[:-1, :])  # For the plot
        ax_legend = fig.add_subplot(gs[-1, :])  # For the legend

        ax_image.imshow(overlay)
        ax_image.axis("off")
        ax_legend.legend(handles, labels, mode="expand", ncol=3, frameon=False)
        ax_legend.axis("off")

        original_filename = sample_name.split("/")[-1].split(".")[0]
        cam_filename = os.path.join(grad_folder, original_filename + "_cam.png")
        fig.savefig(cam_filename)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # one figure per sample would accumulate over the whole data loader.
        plt.close(fig)


def run(model, data_loader, name, device, output_folder, grad_cams=False):
    """
    Runs inference on input data, outputs a CSV file with predictions

    Predictions are written to ``<output_folder>/<name>/predictions.csv``
    (columns ``filename``, ``likelihood`` and ``ground_truth``).  An existing
    file of that name is first moved to ``predictions.csv~``.  One row is
    written per batch: it is labelled with the first filename of the batch
    and holds the flattened probabilities and ground-truth of that batch.

    Parameters
    ---------
    model : :py:class:`torch.nn.Module`
        neural network model (e.g. pasa)

    data_loader : py:class:`torch.torch.utils.data.DataLoader`
        yields samples indexable as ``(names, images, ground_truth)``

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.

    device : str
        device to use ``cpu`` or ``cuda:0``

    output_folder : str
        folder where to store output prediction and model
        summary

    grad_cams : bool
        if we export grad cams for every prediction (must be used along
        a batch size of 1 with the DensenetRS model).  The flag is ignored
        for models whose ``name`` is not ``DensenetRS``.

    Returns
    -------

    all_predictions : list
        All the predictions associated with filename and groundtruth, as
        ``[filename, probabilities, ground_truth]`` entries (the last two
        being flat :py:class:`numpy.ndarray`)

    """

    output_folder = os.path.join(output_folder, name)

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    logger.info(f"Device: {device}")

    logfile_name = os.path.join(output_folder, "predictions.csv")
    logfile_fields = ("filename", "likelihood", "ground_truth")

    # Keep one backup of a previous run instead of appending to it
    if os.path.exists(logfile_name):
        backup = logfile_name + "~"
        if os.path.exists(backup):
            os.unlink(backup)
        shutil.move(logfile_name, backup)

    grad_folder = None
    if grad_cams:
        grad_folder = os.path.join(output_folder, "cams")
        logger.info(f"Grad cams folder: {grad_folder}")
        os.makedirs(grad_folder, exist_ok=True)

    with open(logfile_name, "a+", newline="") as logfile:
        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

        logwriter.writeheader()

        model.eval()  # set evaluation mode
        model.to(device)  # set/cast parameters to device

        # Setup timers
        start_total_time = time.time()
        times = []
        len_samples = []

        all_predictions = []

        for samples in tqdm(
            data_loader, desc="batches", leave=False, disable=None,
        ):

            names = samples[0]
            images = samples[1].to(
                device=device, non_blocking=torch.cuda.is_available()
            )

            # Gradcams generation (needs gradients, hence outside no_grad)
            if grad_cams and model.name in _GRAD_CAM_MODELS:
                _export_grad_cam(model, images, names[0], grad_folder)

            with torch.no_grad():

                start_time = time.perf_counter()
                outputs = model(images)

                # necessary check for HED architecture that uses several
                # outputs for loss calculation instead of just the last
                # concatfuse block; must happen before the sigmoid, which
                # only accepts a tensor
                if isinstance(outputs, list):
                    outputs = outputs[-1]

                probabilities = torch.sigmoid(outputs)

                batch_time = time.perf_counter() - start_time
                times.append(batch_time)
                len_samples.append(len(images))

                likelihood = torch.flatten(probabilities).detach().cpu().numpy()
                ground_truth = torch.flatten(samples[2]).detach().cpu().numpy()

                row = {
                    "filename": f"{names[0]}",
                    "likelihood": f"{likelihood}",
                    "ground_truth": f"{ground_truth}",
                }

                logwriter.writerow(row)
                logfile.flush()
                tqdm.write(" | ".join(f"{k}: {v}" for k, v in row.items()))

                # Keep prediction for relevance analysis
                all_predictions.append([names[0], likelihood, ground_truth])

        # report operational summary
        total_time = datetime.timedelta(
            seconds=int(time.time() - start_total_time)
        )
        logger.info(f"Total time: {total_time}")

        if times:
            average_batch_time = numpy.mean(times)
            logger.info(f"Average batch time: {average_batch_time:g}s")

            # Per-batch times weighted by the number of images in each batch
            average_image_time = numpy.sum(
                numpy.array(times) * len_samples
            ) / float(sum(len_samples))
            logger.info(f"Average image time: {average_image_time:g}s")
        else:
            # Avoids NaN averages and numpy warnings on an empty data loader
            logger.warning("No batches were processed: data loader is empty")

        return all_predictions