Coverage for src/deepdraw/script/predict.py: 94%
48 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5import multiprocessing
6import os
7import sys
9import click
10import torch
12from clapper.click import ConfigCommand, ResourceOption, verbosity_option
13from clapper.logging import setup
14from torch.utils.data import DataLoader
# Logger named after the top-level package (the first dotted component of
# this module's ``__name__``); records are rendered as "<LEVEL>: <message>".
logger = setup(__name__.partition(".")[0], format="%(levelname)s: %(message)s")
18from ..engine.predictor import run
19from ..utils.checkpointer import Checkpointer
20from .common import download_to_tempfile, setup_pytorch_device
# ``predict`` command: loads ``weight`` into ``model`` and runs
# ``..engine.predictor.run`` on every dataset whose key does not start with
# an underscore, writing results under ``output_folder``.
#
# Parameters (all provided by click / clapper's ConfigCommand):
#   ctx            -- click context (unused here).
#   output_folder  -- folder passed to ``run`` for the predictions.
#   model          -- network whose weights are loaded via ``Checkpointer``.
#   dataset        -- a single dataset, or a dict of name -> dataset; a single
#                     dataset is processed under the name "test".
#   batch_size     -- DataLoader batch size (>= 1).
#   device         -- device string, resolved by ``setup_pytorch_device``.
#   weight         -- local path or http(s) URL of the model weights.
#   overlayed      -- folder name for overlay images, or None/blank to skip.
#   parallel       -- -1: no worker processes; 0: one per CPU core;
#                     >= 1: that many DataLoader workers.
#   verbose        -- verbosity level (consumed by ``verbosity_option``).
#
# NOTE: the command's docstring is its ``--help`` text, so it is kept short
# and developer documentation lives in these comments instead.
@click.command(
    entry_point_group="deepdraw.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:

       .. code:: sh

          $ deepdraw predict -vv m2unet drive --weight=path/to/model_final_epoch.pth --output-folder=path/to/predictions

\b
    2. To run prediction on a folder with your own images, you must first
       specify resizing, cropping, etc, so that the image can be correctly
       input to the model. Failing to do so will likely result in poor
       performance. To figure out such specifications, you must consult the
       dataset configuration used for **training** the provided model. Once
       you figured this out, do the following:

       .. code:: sh

          $ deepdraw config copy csv-dataset-example mydataset.py
          # modify "mydataset.py" to include the base path and required transforms
          $ deepdraw predict -vv m2unet mydataset.py --weight=path/to/model_final_epoch.pth --output-folder=path/to/predictions
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    # No "-d" short flag: "-d" belongs to --device below.  It used to be
    # declared on both options; click's parser keeps the last registration,
    # so "-d" has always resolved to --device anyway.
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps on "
    "top of input images (store results as PNG files). If not set, or empty "
    "then do **NOT** output overlayed images. Otherwise, the parameter "
    "represents the name of a folder where to store those",
    show_default=True,
    default=None,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as processing cores as available in the system. Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def predict(
    ctx,
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    parallel,
    verbose,
    **kwargs,
):
    """Predicts vessel map (probabilities) on input images."""
    device = setup_pytorch_device(device)

    # Treat a single dataset as a one-entry mapping named "test".
    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    # Only real http(s) URLs are downloaded; a bare "http" prefix would also
    # match local paths such as "httpmodels/model.pth".
    if weight.startswith(("http://", "https://")):
        logger.info(f"Temporarily downloading '{weight}'...")
        # Keep ``f`` referenced while the weights are loaded.
        # NOTE(review): presumably a temporary file that disappears once
        # closed/garbage-collected -- confirm in ``download_to_tempfile``.
        f = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(f.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath)

    # Clean up the overlayed path.  Per the --overlayed help, a blank value
    # means "do not output overlayed images", so map it to None instead of
    # handing an empty folder name to ``run``.
    if overlayed is not None:
        overlayed = overlayed.strip() or None

    # DataLoader worker settings do not depend on the dataset: compute once.
    multiproc_kwargs = dict()
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        # parallel == 0 means "one worker per CPU core".
        multiproc_kwargs["num_workers"] = parallel or multiprocessing.cpu_count()

    if multiproc_kwargs["num_workers"] > 0 and sys.platform.startswith("darwin"):
        # On macOS, explicitly request the "spawn" start method for workers.
        multiproc_kwargs["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    pin_memory = torch.cuda.is_available()

    for k, v in dataset.items():
        # Keys starting with "_" are not processed.
        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        # PyTorch dataloader
        data_loader = DataLoader(
            dataset=v,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=pin_memory,
            **multiproc_kwargs,
        )

        run(model, data_loader, k, device, output_folder, overlayed)