Coverage for src/deepdraw/script/predict.py: 94%

48 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5import multiprocessing 

6import os 

7import sys 

8 

9import click 

10import torch 

11 

12from clapper.click import ConfigCommand, ResourceOption, verbosity_option 

13from clapper.logging import setup 

14from torch.utils.data import DataLoader 

15 

# Package-wide logger, configured before any package-local module is imported
# so that messages those modules emit at import time already use this format.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

# NOTE(review): these imports intentionally come *after* the logger setup
# (see comment above) — presumably to capture import-time log records;
# confirm before moving them to the top of the file.
from ..engine.predictor import run
from ..utils.checkpointer import Checkpointer
from .common import download_to_tempfile, setup_pytorch_device

21 

22 

def _dataloader_multiproc_kwargs(parallel):
    """Translate the ``--parallel`` option into ``DataLoader`` keyword args.

    Parameters
    ----------

    parallel : int
        ``< 0``: disable multiprocessing data loading (0 workers);
        ``0``: use as many workers as there are processing cores;
        ``>= 1``: use exactly that many worker processes.

    Returns
    -------

    kwargs : dict
        Keyword arguments to splat into :py:class:`torch.utils.data.DataLoader`.
        On macOS, the "spawn" start method is forced whenever workers are
        enabled, since the default "fork" method is unreliable there.
    """
    kwargs = {}
    if parallel < 0:
        kwargs["num_workers"] = 0
    else:
        # 0 means "one worker per available core"
        kwargs["num_workers"] = parallel or multiprocessing.cpu_count()

    if kwargs["num_workers"] > 0 and sys.platform.startswith("darwin"):
        kwargs["multiprocessing_context"] = multiprocessing.get_context("spawn")

    return kwargs


@click.command(
    entry_point_group="deepdraw.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
 1. Runs prediction on an existing dataset configuration:

 .. code:: sh

    $ deepdraw predict -vv m2unet drive --weight=path/to/model_final_epoch.pth --output-folder=path/to/predictions

\b
 2. To run prediction on a folder with your own images, you must first
    specify resizing, cropping, etc, so that the image can be correctly
    input to the model. Failing to do so will likely result in poor
    performance. To figure out such specifications, you must consult the
    dataset configuration used for **training** the provided model. Once
    you figured this out, do the following:

 .. code:: sh

    $ deepdraw config copy csv-dataset-example mydataset.py
    # modify "mydataset.py" to include the base path and required transforms
    $ deepdraw predict -vv m2unet mydataset.py --weight=path/to/model_final_epoch.pth --output-folder=path/to/predictions
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    # NOTE(review): the original short flag "-d" collided with the one used
    # by --dataset above; click lets one declaration silently shadow the
    # other, so the short form was effectively broken.  Only the long form
    # is kept here.
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps on "
    "top of input images (store results as PNG files). If not set, or empty "
    "then do **NOT** output overlayed images. Otherwise, the parameter "
    "represents the name of a folder where to store those",
    show_default=True,
    default=None,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as processing cores as available in the system. Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def predict(
    ctx,
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    parallel,
    verbose,
    **kwargs,
):
    """Predicts vessel map (probabilities) on input images.

    Loads the model weights (from a local path or, transparently, from a
    URL), then runs inference over every non-underscore-prefixed split of
    ``dataset``, delegating per-split work to
    :py:func:`..engine.predictor.run`.
    """

    device = setup_pytorch_device(device)

    # a bare dataset is treated as a dictionary with a single "test" split
    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        f = download_to_tempfile(weight, progress=True)
        # "f" stays referenced until this function returns, keeping the
        # temporary file alive while the checkpoint is loaded below
        weight_fullpath = os.path.abspath(f.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath)

    # clean-up the overlayed path
    if overlayed is not None:
        overlayed = overlayed.strip()

    # the data-loading configuration depends only on --parallel, not on the
    # split being processed — compute it once instead of once per iteration
    multiproc_kwargs = _dataloader_multiproc_kwargs(parallel)

    for k, v in dataset.items():
        # underscore-prefixed splits are auxiliary and not to be predicted on
        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        data_loader = DataLoader(
            dataset=v,
            batch_size=batch_size,
            shuffle=False,  # keep sample order so outputs match input names
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )

        run(model, data_loader, k, device, output_folder, overlayed)