1#!/usr/bin/env python
2# coding=utf-8
3
4import os
5import shutil
6import tempfile
7import copy
8
9import click
10import torch
11import numpy as np
12from sklearn import metrics
13from torch.utils.data import DataLoader, ConcatDataset
14from matplotlib.backends.backend_pdf import PdfPages
15
16from bob.extension.scripts.click_helper import (
17 verbosity_option,
18 ConfigCommand,
19 ResourceOption,
20)
21
22from ..engine.predictor import run
23from ..utils.checkpointer import Checkpointer
24
25from .tb import download_to_tempfile
26from ..utils.plot import relevance_analysis_plot
27
28import logging
29logger = logging.getLogger(__name__)
30
31
def _make_loader(dataset, batch_size):
    """Builds a sequential (non-shuffled) ``DataLoader`` over ``dataset``."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )


def _save_relevance_plot(values, title, filepath):
    """Renders ``values`` with ``relevance_analysis_plot`` into a PDF file.

    The parent directory of ``filepath`` is created if it does not exist.
    """
    folder = os.path.dirname(filepath)
    if folder:  # os.makedirs("") raises, so only create a real directory
        os.makedirs(folder, exist_ok=True)
    figure = relevance_analysis_plot(values, title=title)
    # The context manager closes the PDF even if savefig() raises
    with PdfPages(filepath) as pdf:
        pdf.savefig(figure)


def _permutation_mse(
    model, dataset, name, predictions, batch_size, device, temp_folder
):
    """Computes one MSE per input feature via permutation feature importance.

    For each feature index, a deep copy of ``dataset`` has that feature
    permuted (``random_permute``), inference is re-run, and the mean squared
    error between column 1 of the original and of the new predictions is
    recorded.  Column 1 is presumably the predicted probability -- confirm
    against ``engine.predictor.run``.

    ``temp_folder`` receives the throw-away outputs of the extra inference
    runs and is always removed before returning, even on failure.
    """
    reference = np.array(predictions)[:, 1]
    nb_features = len(dataset._samples[0].data["data"])

    all_mse = []
    try:
        for feature in range(nb_features):
            # Permute a private copy so the caller's dataset is left intact
            permuted = copy.deepcopy(dataset)
            permuted.random_permute(feature)

            permuted_predictions = run(
                model,
                _make_loader(permuted, batch_size),
                name,
                device,
                temp_folder,
            )

            all_mse.append(
                metrics.mean_squared_error(
                    reference,
                    np.array(permuted_predictions)[:, 1],
                )
            )
    finally:
        shutil.rmtree(temp_folder, ignore_errors=True)

    return all_mse


@click.command(
    entry_point_group="bob.med.tb.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:
\b
       $ bob tb predict -vv pasa montgomery --weight=path/to/model_final.pth --output-folder=path/to/predictions

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE(review): "-d" is declared both here and by --device below, so the
# short flag is ambiguous -- confirm which option should own it.
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--relevance-analysis",
    "-r",
    help="If set, generate relevance analysis pdfs to indicate the relative "
    "importance of each feature",
    is_flag=True,
    cls=ResourceOption,
)
@click.option(
    "--grad-cams",
    "-g",
    help="If set, generate grad cams for each prediction (must use batch of 1)",
    is_flag=True,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def predict(output_folder, model, dataset, batch_size, device, weight,
            relevance_analysis, grad_cams, **kwargs):
    """Predicts Tuberculosis presence (probabilities) on input images"""
    # The docstring above is kept short on purpose: click uses the callback
    # docstring as the command help text.  Parameters:
    #
    #   output_folder      -- directory receiving predictions and PDF plots
    #   model              -- network to evaluate; its ``name`` attribute and
    #                         ``parameters()`` are read below
    #   dataset            -- a single dataset, or a dict of name -> dataset;
    #                         names starting with "_" are skipped
    #   batch_size         -- batch size used by every DataLoader
    #   device             -- device string forwarded to ``run``
    #   weight             -- local path or http(s) URL of the checkpoint
    #   relevance_analysis -- if true, run permutation feature importance
    #   grad_cams          -- forwarded to ``run`` for the main inference pass
    #   kwargs             -- remaining click options (e.g. verbosity), unused

    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        # Keep the handle bound for the whole function; presumably the
        # temporary file is removed once the object is released -- confirm
        # against download_to_tempfile.
        downloaded_weight = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(downloaded_weight.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath, strict=False)

    # Logistic regressor weights
    if model.name == "logistic_regression":
        logger.info("Logistic regression identified: saving model weights")
        # Only the first parameter tensor is plotted (flattened)
        first_param = next(iter(model.parameters()))
        model_weights = np.array(first_param.data.reshape(-1))
        filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
        logger.info(f"Creating and saving weights plot at {filepath}...")
        _save_relevance_plot(
            model_weights,
            title="LogReg model weights",
            filepath=filepath,
        )

    for k, v in dataset.items():

        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        predictions = run(
            model,
            _make_loader(v, batch_size),
            k,
            device,
            output_folder,
            grad_cams,
        )

        # Relevance analysis using permutation feature importance
        if not relevance_analysis:
            continue

        # Only datasets whose first sample stores its data as a list qualify
        if isinstance(v, ConcatDataset) or not isinstance(
            v._samples[0].data["data"], list
        ):
            logger.info(
                "Relevance analysis only possible with radiological signs "
                "as input. Cancelling..."
            )
            continue

        if len(v._samples[0].data["data"]) == 1:
            logger.info("Relevance analysis not possible with one feature")
            continue

        logger.info(f"Starting relevance analysis for subset '{k}'...")

        all_mse = _permutation_mse(
            model,
            v,
            k,
            predictions,
            batch_size,
            device,
            output_folder + "_temp",
        )

        filepath = os.path.join(output_folder, k + "_RA.pdf")
        logger.info(f"Creating and saving plot at {filepath}...")
        _save_relevance_plot(
            all_mse,
            title=k.capitalize() + " set relevance analysis",
            filepath=filepath,
        )