
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
import logging
import os
import time

import h5py
import numpy
import torch
import torchvision.transforms.functional as VF

from tqdm import tqdm

from ..data.utils import overlayed_image

logger = logging.getLogger(__name__)



def _save_hdf5(stem, prob, output_folder):
    """
    Saves the prediction map as an HDF5 file with the probability array

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    prob : torch.Tensor
        tensor with the monochrome prediction map

    output_folder : str
        path where to store predictions

    """

    fullpath = os.path.join(output_folder, f"{stem}.hdf5")
    tqdm.write(f"Saving {fullpath}...")
    os.makedirs(os.path.dirname(fullpath), exist_ok=True)
    with h5py.File(fullpath, "w") as f:
        data = prob.cpu().squeeze(0).numpy()
        f.create_dataset(
            "array", data=data, compression="gzip", compression_opts=9
        )


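# A minimal read-back sketch (not part of the original module): shows how a
# prediction map written by ``_save_hdf5`` above could be loaded again.  The
# dataset name "array" matches the writer; the helper name is an assumption
# for illustration only.
def _example_load_hdf5(fullpath):
    """Loads a prediction map saved by :py:func:`_save_hdf5` (illustrative)."""
    with h5py.File(fullpath, "r") as f:
        # returns a numpy array with per-pixel probabilities
        return f["array"][:]

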

def _save_image(stem, extension, data, output_folder):
    """Saves a PIL image into a file

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    extension : str
        an extension for the file to be saved (e.g. ``.png``)

    data : PIL.Image.Image
        image to be saved, preloaded

    output_folder : str
        path where to store results

    """

    fullpath = os.path.join(output_folder, stem + extension)
    tqdm.write(f"Saving {fullpath}...")
    os.makedirs(os.path.dirname(fullpath), exist_ok=True)
    data.save(fullpath)


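# A minimal usage sketch (not part of the original module): converts a
# single-channel probability tensor to a grayscale PIL image and saves it as
# a PNG through ``_save_image`` above.  The helper name is an assumption for
# illustration only.
def _example_save_probability_png(stem, prob, output_folder):
    """Saves ``prob`` (a 1 x H x W tensor in [0, 1]) as a PNG (illustrative)."""
    _save_image(stem, ".png", VF.to_pil_image(prob.cpu()), output_folder)

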

def _save_overlayed_png(stem, image, prob, output_folder):
    """Overlays the predicted vessel tree on the original test image

    Parameters
    ----------

    stem : str
        the name of the file without extension on the original dataset

    image : torch.Tensor
        tensor with the RGB input image

    prob : torch.Tensor
        tensor with the 1-D prediction map

    output_folder : str
        path where to store results

    """

    image = VF.to_pil_image(image)
    prob = VF.to_pil_image(prob.cpu())
    _save_image(stem, ".png", overlayed_image(image, prob), output_folder)


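# A minimal usage sketch (not part of the original module): how the helper
# above could be called for a single sample.  The tensor shapes (3 x H x W
# image, 1 x H x W probabilities) mirror what ``run`` below passes in; the
# stem, folder and helper name are assumptions for illustration only.
def _example_overlay_one(stem="drive/test/01", output_folder="overlayed"):
    """Builds random tensors and writes one overlay PNG (illustrative only)."""
    image = torch.rand(3, 64, 64)  # fake RGB input image
    prob = torch.rand(1, 64, 64)  # fake probability map
    _save_overlayed_png(stem, image, prob, output_folder)

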

def run(model, data_loader, name, device, output_folder, overlayed_folder):
    """
    Runs inference on input data, outputs HDF5 files with predictions

    Parameters
    ----------
    model : :py:class:`torch.nn.Module`
        neural network model (e.g. driu, hed, unet)

    data_loader : :py:class:`torch.utils.data.DataLoader`
        data loader with the samples on which to run inference

    name : str
        the local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.

    device : :py:class:`torch.device`
        device to use

    output_folder : str
        folder where to store output prediction maps (HDF5 files) and model
        summary

    overlayed_folder : str
        folder where to store output images (PNG files)

    """

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    logger.info(f"Device: {device}")

    model.eval()  # set evaluation mode
    model.to(device)  # set/cast parameters to device
    sigmoid = torch.nn.Sigmoid()  # use sigmoid for predictions

    # Setup timers
    start_total_time = time.time()
    times = []
    len_samples = []

    output_folder = os.path.join(output_folder, name)
    overlayed_folder = (
        os.path.join(overlayed_folder, name)
        if overlayed_folder is not None
        else overlayed_folder
    )

    for samples in tqdm(data_loader, desc="batches", leave=False, disable=None):

        names = samples[0]
        images = samples[1].to(
            device=device, non_blocking=torch.cuda.is_available()
        )

        with torch.no_grad():

            start_time = time.perf_counter()
            outputs = model(images)

            # necessary check for HED/Little W-Net architectures that use
            # several outputs for loss calculation instead of just the last one
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]

            predictions = sigmoid(outputs)

            batch_time = time.perf_counter() - start_time
            times.append(batch_time)
            len_samples.append(len(images))

            for stem, img, prob in zip(names, images, predictions):
                _save_hdf5(stem, prob, output_folder)
                if overlayed_folder is not None:
                    _save_overlayed_png(stem, img, prob, overlayed_folder)

    # report operational summary
    total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
    logger.info(f"Total time: {total_time}")

    average_batch_time = numpy.mean(times)
    logger.info(f"Average batch time: {average_batch_time:g}s")

    average_image_time = numpy.sum(numpy.array(times) * len_samples) / float(
        sum(len_samples)
    )
    logger.info(f"Average image time: {average_image_time:g}s")
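

# A minimal invocation sketch (not part of the original module): how ``run``
# above might be driven.  The helper name and folder names are assumptions;
# any ``torch.nn.Module`` producing one logit map per pixel and any dataset
# whose loader yields ``(names, images, ...)`` batches would do.
def _example_run(model, dataset):
    """Runs inference on ``dataset`` and writes HDF5/PNG outputs (illustrative)."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    run(
        model,
        loader,
        name="test",
        device=device,
        output_folder="results/predictions",
        overlayed_folder="results/overlayed",
    )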