Coverage for src/deepdraw/engine/predictor.py: 100%
61 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import datetime
6import logging
7import os
8import time
10import h5py
11import numpy
12import torch
13import torchvision.transforms.functional as VF
15from tqdm import tqdm
17from ..data.utils import overlayed_image
# Module-level logger, named after this module so its output can be filtered
# through the standard ``logging`` hierarchy.
logger = logging.getLogger(name=__name__)
def _save_hdf5(stem, prob, output_folder):
    """Saves a prediction map as a gzip-compressed HDF5 file.

    The file is written to ``<output_folder>/<stem>.hdf5`` and contains a
    single dataset named ``array``.  Missing parent directories (including
    any sub-directories embedded in ``stem``) are created on demand.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    prob : torch.Tensor
        Prediction map of a single sample.  It is detached, moved to the CPU
        and its leading dimension is squeezed (only if it has size 1) before
        being stored as a numpy array.

    output_folder : str
        path where to store predictions
    """

    fullpath = os.path.join(output_folder, f"{stem}.hdf5")
    tqdm.write(f"Saving {fullpath}...")

    # ``os.makedirs("")`` raises FileNotFoundError, which would happen when
    # ``output_folder`` is empty and ``stem`` has no directory component.
    parent = os.path.dirname(fullpath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Convert *before* opening the file so that a conversion failure does not
    # leave an empty/truncated HDF5 file behind.  ``detach()`` keeps this
    # working even if the caller passes a tensor that still requires grad.
    data = prob.detach().cpu().squeeze(0).numpy()

    with h5py.File(fullpath, "w") as f:
        f.create_dataset(
            "array", data=data, compression="gzip", compression_opts=9
        )
def _save_image(stem, extension, data, output_folder):
    """Saves a PIL image into a file.

    The image is written to ``<output_folder>/<stem><extension>``.  Missing
    parent directories (including any sub-directories embedded in ``stem``)
    are created on demand.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    extension : str
        an extension for the file to be saved, including the leading dot
        (e.g. ``.png``); it is appended verbatim to ``stem``

    data : PIL.Image.Image
        image to be saved (``data.save`` is called with the full path only,
        so the output format follows ``extension``)

    output_folder : str
        path where to store results
    """

    fullpath = os.path.join(output_folder, stem + extension)
    tqdm.write(f"Saving {fullpath}...")

    # ``os.makedirs("")`` raises FileNotFoundError, which would happen when
    # ``output_folder`` is empty and ``stem`` has no directory component.
    parent = os.path.dirname(fullpath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    data.save(fullpath)
def _save_overlayed_png(stem, image, prob, output_folder):
    """Combines a prediction map with its input image and saves it as PNG.

    Both tensors are converted to PIL images, merged through
    ``overlayed_image`` and written to ``<output_folder>/<stem>.png``.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    image : torch.Tensor
        Tensor with the RGB input image

    prob : torch.Tensor
        Tensor with the prediction map for ``image`` (presumably a single
        channel -- confirm against the model output); moved to the CPU
        before conversion

    output_folder : str
        path where to store results
    """

    # arguments are evaluated left to right: the input image is converted
    # first, then the prediction map
    overlay = overlayed_image(
        VF.to_pil_image(image), VF.to_pil_image(prob.cpu())
    )
    _save_image(stem, ".png", overlay, output_folder)
def run(model, data_loader, name, device, output_folder, overlayed_folder):
    """Runs inference on input data, outputs HDF5 files with predictions.

    Every batch yielded by ``data_loader`` is moved to ``device`` and fed to
    ``model`` (in evaluation mode, without gradient tracking).  When the model
    returns a list or tuple, only its last element is used.  A sigmoid is
    applied to that output and, for each sample, the result is saved as
    ``<output_folder>/<name>/<stem>.hdf5``.  If ``overlayed_folder`` is not
    ``None``, an overlay of the prediction on the input image is also saved as
    ``<overlayed_folder>/<name>/<stem>.png``.  Timing statistics are logged
    once all batches have been processed.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        neural network model (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        iterable over batches; element ``0`` of each batch holds the sample
        stems (file names without extension) and element ``1`` the batched
        input images

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), used as
        the name of the sub-folder where results are saved

    device : :py:class:`torch.device`
        device to use

    output_folder : str
        folder where to store output prediction maps (HDF5 files)

    overlayed_folder : str or None
        folder where to store overlayed images (PNG files); if ``None``, no
        overlayed images are produced
    """

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    logger.info(f"Device: {device}")

    model.eval()  # set evaluation mode
    model.to(device)  # set/cast parameters to device
    sigmoid = torch.nn.Sigmoid()  # use sigmoid for predictions

    # CUDA kernels are launched asynchronously: without explicit
    # synchronisation, ``time.perf_counter()`` would only measure the time
    # needed to *queue* the work, not to execute it.
    torch_device = torch.device(device)
    sync_cuda = torch_device.type == "cuda"

    # Setup timers
    start_total_time = time.time()
    times = []
    len_samples = []

    output_folder = os.path.join(output_folder, name)
    if overlayed_folder is not None:
        overlayed_folder = os.path.join(overlayed_folder, name)

    for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
        names = samples[0]
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        with torch.no_grad():
            if sync_cuda:
                # do not attribute the (possibly non-blocking) host-to-device
                # copy to the forward pass
                torch.cuda.synchronize(torch_device)
            start_time = time.perf_counter()
            outputs = model(images)

            # necessary check for HED/Little W-Net architecture that use
            # several outputs for loss calculation instead of just the last one
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]

            predictions = sigmoid(outputs)

            if sync_cuda:
                # wait for the forward pass to actually finish on the GPU
                torch.cuda.synchronize(torch_device)
            batch_time = time.perf_counter() - start_time
            times.append(batch_time)
            len_samples.append(len(images))

            for stem, img, prob in zip(names, images, predictions):
                _save_hdf5(stem, prob, output_folder)
                if overlayed_folder is not None:
                    _save_overlayed_png(stem, img, prob, overlayed_folder)

    # report operational summary
    total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
    logger.info(f"Total time: {total_time}")

    total_samples = sum(len_samples)
    if total_samples == 0:
        # averaging over nothing would log ``nan`` and trigger numpy
        # RuntimeWarnings (mean of empty array, division by zero)
        logger.warning("No samples were processed: skipping timing averages")
        return

    average_batch_time = numpy.mean(times)
    logger.info(f"Average batch time: {average_batch_time:g}s")

    # weight each batch time by its size so the last (possibly smaller) batch
    # does not skew the per-image average
    average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(
        total_samples
    )
    logger.info(f"Average image time: {average_image_time:g}s")