Coverage for src/deepdraw/engine/predictor.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import datetime 

6import logging 

7import os 

8import time 

9 

10import h5py 

11import numpy 

12import torch 

13import torchvision.transforms.functional as VF 

14 

15from tqdm import tqdm 

16 

17from ..data.utils import overlayed_image 

18 

# Module-level logger, named after this module so that callers can tune its
# verbosity through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)

20 

21 

def _save_hdf5(stem, prob, output_folder):
    """Saves a prediction map into an HDF5 file.

    The map is written, gzip-compressed (level 9), to the dataset ``array``
    of the file ``<output_folder>/<stem>.hdf5``.  Any missing intermediate
    directories are created first.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset (it
        may contain sub-directory components)

    prob : torch.Tensor
        tensor with the prediction map; it is detached, moved to the CPU and
        its leading dimension is squeezed away if it has size 1 before being
        saved

    output_folder : str
        path where to store predictions
    """

    fullpath = os.path.join(output_folder, f"{stem}.hdf5")
    tqdm.write(f"Saving {fullpath}...")

    # Convert before opening the output file, so that a failing conversion
    # does not leave a truncated/empty HDF5 file behind.  detach() makes this
    # safe even for tensors that still track gradients.
    data = prob.detach().cpu().squeeze(0).numpy()

    # dirname() is empty when the path has no directory component, and
    # os.makedirs("") would raise FileNotFoundError.
    dirname = os.path.dirname(fullpath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    with h5py.File(fullpath, "w") as f:
        f.create_dataset(
            "array", data=data, compression="gzip", compression_opts=9
        )

45 

46 

def _save_image(stem, extension, data, output_folder):
    """Saves an image object into a file.

    The file is written to ``<output_folder>/<stem><extension>``.  Any
    missing intermediate directories are created first.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset (it
        may contain sub-directory components)

    extension : str
        an extension for the file to be saved, including the leading dot
        (e.g. ``.png``); it is appended verbatim to ``stem``

    data : PIL.Image.Image
        image to save; only its ``save(path)`` method is called (PIL infers
        the output format from the file extension)

    output_folder : str
        path where to store results
    """

    fullpath = os.path.join(output_folder, stem + extension)
    tqdm.write(f"Saving {fullpath}...")

    # dirname() is empty when the path has no directory component, and
    # os.makedirs("") would raise FileNotFoundError.
    dirname = os.path.dirname(fullpath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    data.save(fullpath)

70 

71 

def _save_overlayed_png(stem, image, prob, output_folder):
    """Overlays a prediction map on the original test image and saves it.

    The overlay is written as ``<output_folder>/<stem>.png``.

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    image : torch.Tensor
        tensor with the RGB input image; it may live on any device

    prob : torch.Tensor
        tensor with the prediction map (presumably single-channel,
        ``1 x H x W`` -- confirm against the model output); it may live on
        any device

    output_folder : str
        path where to store results
    """

    # Both tensors may live on an accelerator (inputs are moved to the model
    # device before inference), so bring *both* to the CPU before the PIL
    # conversion, instead of only the prediction map.
    image = VF.to_pil_image(image.cpu())
    prob = VF.to_pil_image(prob.cpu())
    _save_image(stem, ".png", overlayed_image(image, prob), output_folder)

94 

95 

def run(model, data_loader, name, device, output_folder, overlayed_folder):
    """Runs inference on input data, outputs HDF5 files with predictions.

    For every sample, the model output (its last element, if the model
    returns a list or tuple) is passed through a sigmoid.  The result is
    saved as an HDF5 file under ``<output_folder>/<name>/``.  If
    ``overlayed_folder`` is not ``None``, a PNG with the prediction overlayed
    on the input image is also saved under ``<overlayed_folder>/<name>/``.
    Total, average per-batch and average per-image inference times are
    logged at the end.

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        neural network model (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        loader yielding batches whose first element is the sequence of sample
        names (file stems) and whose second element is the batch of input
        images

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), used as
        a sub-folder of ``output_folder`` (and of ``overlayed_folder``)

    device : :py:class:`torch.device` or str
        device to use

    output_folder : str
        folder where to store output prediction maps (HDF5 files)

    overlayed_folder : str or None
        folder where to store output images (PNG files); if ``None``, no
        overlay images are produced
    """

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    logger.info(f"Device: {device}")

    model.eval()  # set evaluation mode
    model.to(device)  # set/cast parameters to device
    sigmoid = torch.nn.Sigmoid()  # use sigmoid for predictions

    # CUDA kernels run asynchronously: without explicit synchronisation,
    # perf_counter() would only measure kernel launch time, not inference.
    sync_cuda = torch.device(device).type == "cuda"

    # Setup timers
    start_total_time = time.time()
    times = []  # forward-pass wall-clock time of each batch
    len_samples = []  # number of samples in each batch

    output_folder = os.path.join(output_folder, name)
    if overlayed_folder is not None:
        overlayed_folder = os.path.join(overlayed_folder, name)

    for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):
        names = samples[0]
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        with torch.no_grad():
            if sync_cuda:
                # do not charge the (possibly non-blocking) copy to the model
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            outputs = model(images)

            # necessary check for HED/Little W-Net architecture that use
            # several outputs for loss calculation instead of just the last one
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]

            predictions = sigmoid(outputs)

            if sync_cuda:
                torch.cuda.synchronize(device)
            batch_time = time.perf_counter() - start_time
            times.append(batch_time)
            len_samples.append(len(images))

            for stem, img, prob in zip(names, images, predictions):
                _save_hdf5(stem, prob, output_folder)
                if overlayed_folder is not None:
                    _save_overlayed_png(stem, img, prob, overlayed_folder)

    # report operational summary
    total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
    logger.info(f"Total time: {total_time}")

    if not times:
        # Empty loader: averages are undefined (numpy would emit warnings and
        # report NaN), so skip them.
        logger.warning("No batches were processed, skipping timing averages")
        return

    average_batch_time = numpy.mean(times)
    logger.info(f"Average batch time: {average_batch_time:g}s")

    # Time per image is the total forward time over the number of images.
    # Weighting batch times by batch size would instead yield yet another
    # average *batch* time (identical to the above for uniform batches).
    average_image_time = float(sum(times)) / sum(len_samples)
    logger.info(f"Average image time: {average_image_time:g}s")