Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import datetime
5import logging
6import os
7import time
9import h5py
10import numpy
11import torch
12import torchvision.transforms.functional as VF
14from tqdm import tqdm
16from ..data.utils import overlayed_image
18logger = logging.getLogger(__name__)
def _save_hdf5(stem, prob, output_folder):
    """
    Saves a prediction map into an HDF5 file

    The file is written to ``<output_folder>/<stem>.hdf5`` and contains a
    single gzip-compressed dataset named ``array``.


    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    prob : torch.Tensor
        Tensor with the prediction map.  It is moved to the CPU and has its
        first dimension squeezed out (when that dimension has size 1) before
        being stored.

    output_folder : str
        path where to store predictions

    """

    fullpath = os.path.join(output_folder, f"{stem}.hdf5")
    tqdm.write(f"Saving {fullpath}...")

    # ``os.makedirs("")`` raises FileNotFoundError, so the parent is only
    # created when the destination actually has a directory component
    parent = os.path.dirname(fullpath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # convert before opening the file: opening in "w" mode creates/truncates
    # it, so a failing conversion would otherwise leave an empty file behind
    data = prob.cpu().squeeze(0).numpy()

    with h5py.File(fullpath, "w") as f:
        f.create_dataset(
            "array", data=data, compression="gzip", compression_opts=9
        )
def _save_image(stem, extension, data, output_folder):
    """Saves a PIL image into a file

    The destination is ``<output_folder>/<stem><extension>``; missing parent
    directories are created on the fly.


    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    extension : str
        an extension for the file to be saved (e.g. ``.png``)

    data : PIL.Image.Image
        the image to be saved; only its ``save(path)`` method is used here

    output_folder : str
        path where to store results

    """

    fullpath = os.path.join(output_folder, stem + extension)
    tqdm.write(f"Saving {fullpath}...")

    # ``os.makedirs("")`` raises FileNotFoundError, so the parent is only
    # created when the destination actually has a directory component
    parent = os.path.dirname(fullpath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    data.save(fullpath)
def _save_overlayed_png(stem, image, prob, output_folder):
    """Overlays the prediction map on the original image, saves it as PNG

    Both tensors are converted to PIL images, combined through
    ``overlayed_image()`` and the result is stored at
    ``<output_folder>/<stem>.png``.


    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    image : torch.Tensor
        Tensor with RGB input image

    prob : torch.Tensor
        Tensor with the prediction map

    output_folder : str
        path where to store results

    """

    # both tensors may come from the inference device: bring each of them to
    # the CPU explicitly, instead of relying on the PIL conversion to cope
    # with device tensors for one of them only (no-op for CPU tensors)
    image = VF.to_pil_image(image.cpu())
    prob = VF.to_pil_image(prob.cpu())
    _save_image(stem, ".png", overlayed_image(image, prob), output_folder)
def run(model, data_loader, name, device, output_folder, overlayed_folder):
    """
    Runs inference on input data, outputs HDF5 files with predictions

    For every batch, the model output goes through a sigmoid and each
    resulting prediction map is saved as an HDF5 file.  Optionally, a PNG with
    the prediction overlayed on the input image is saved as well.  Timing
    information is logged at the end.


    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        neural network model (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        iterable over batches.  Each batch must be indexable, with the sample
        names (stems) at position 0 and the input images tensor at position 1.

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``).  It is
        used as the name of the sub-folder, inside ``output_folder`` and
        ``overlayed_folder``, where files are saved.

    device : :py:class:`torch.device`
        device to use

    output_folder : str
        base folder where to store output prediction maps (HDF5 files).  It
        is created if it does not exist.

    overlayed_folder : str
        base folder where to store output images (PNG files).  If set to
        ``None``, then overlayed images are not saved.

    """

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    logger.info(f"Device: {device}")

    model.eval()  # set evaluation mode
    model.to(device)  # set/cast parameters to device
    sigmoid = torch.nn.Sigmoid()  # use sigmoid for predictions

    # Setup timers
    start_total_time = time.time()
    times = []  # wall-clock duration of each batch (forward pass + sigmoid)
    len_samples = []  # number of images in each batch

    # outputs for this dataset go into a sub-folder named after it
    output_folder = os.path.join(output_folder, name)
    if overlayed_folder is not None:
        overlayed_folder = os.path.join(overlayed_folder, name)

    for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):

        names = samples[0]
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        with torch.no_grad():

            # NOTE(review): no device synchronisation is done around the
            # timed region; on asynchronous devices the measured time may not
            # cover the full computation - confirm if this matters
            start_time = time.perf_counter()
            outputs = model(images)

            # necessary check for HED/Little W-Net architecture that use
            # several outputs for loss calculation instead of just the last one
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]

            predictions = sigmoid(outputs)

            batch_time = time.perf_counter() - start_time
            times.append(batch_time)
            len_samples.append(len(images))

            for stem, img, prob in zip(names, images, predictions):
                _save_hdf5(stem, prob, output_folder)
                if overlayed_folder is not None:
                    _save_overlayed_png(stem, img, prob, overlayed_folder)

    # report operational summary
    total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
    logger.info(f"Total time: {total_time}")

    if not times:
        # an empty data loader would otherwise yield "nan" averages (and
        # numpy runtime warnings) from the mean of an empty list and a 0/0
        logger.warning("No batches were processed: skipping time averages")
        return

    average_batch_time = numpy.mean(times)
    logger.info(f"Average batch time: {average_batch_time:g}s")

    # batch-size-weighted mean of the batch times, i.e., the average, over
    # all images, of the time taken by the batch each image belonged to.
    # NOTE(review): this is not "total batch time / number of images" -
    # confirm which of the two is the intended measure
    average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(
        sum(len_samples)
    )
    logger.info(f"Average image time: {average_image_time:g}s")