Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import logging
5import multiprocessing
6import os
7import sys
9import click
10import torch
12from torch.utils.data import DataLoader
14from bob.extension.scripts.click_helper import (
15 ConfigCommand,
16 ResourceOption,
17 verbosity_option,
18)
20from ..engine.predictor import run
21from ..utils.checkpointer import Checkpointer
22from .binseg import download_to_tempfile, setup_pytorch_device
24logger = logging.getLogger(__name__)
@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:
\b
       $ bob binseg predict -vv m2unet drive --weight=path/to/model_final.pth --output-folder=path/to/predictions
\b
    2. To run prediction on a folder with your own images, you must first
       specify resizing, cropping, etc, so that the image can be correctly
       input to the model. Failing to do so will likely result in poor
       performance. To figure out such specifications, you must consult the
       dataset configuration used for **training** the provided model. Once
       you figured this out, do the following:
\b
       $ bob binseg config copy csv-dataset-example mydataset.py
       # modify "mydataset.py" to include the base path and required transforms
       $ bob binseg predict -vv m2unet mydataset.py --weight=path/to/model_final.pth --output-folder=path/to/predictions
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
# NOTE(review): "-d" below is also the short flag of --dataset above; click
# will bind -d to only one of the two options — confirm this is intended.
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--overlayed",
    "-O",
    help="Creates overlayed representations of the output probability maps on "
    "top of input images (store results as PNG files). If not set, or empty "
    "then do **NOT** output overlayed images. Otherwise, the parameter "
    "represents the name of a folder where to store those",
    show_default=True,
    default=None,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--multiproc-data-loading",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading. Set to 0 to enable as many data
    loading instances as processing cores as available in the system. Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def predict(
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    multiproc_data_loading,
    **kwargs,
):
    """Predicts vessel map (probabilities) on input images"""

    device = setup_pytorch_device(device)

    # Normalize the input to a dictionary of datasets so a single dataset
    # and a multi-split configuration are handled by the same loop below.
    if not isinstance(dataset, dict):
        dataset = dict(test=dataset)

    # Resolve the model weights: URLs are fetched into a temporary file
    # (kept alive by ``tmp`` for the duration of this function), local
    # paths are made absolute.
    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        tmp = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(tmp.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    Checkpointer(model).load(weight_fullpath)

    # clean-up the overlayed path
    if overlayed is not None:
        overlayed = overlayed.strip()

    # Translate the -P option into DataLoader worker settings:
    #   < 0 -> single-process loading, 0 -> one worker per CPU core,
    #   > 0 -> exactly that many workers.
    if multiproc_data_loading < 0:
        workers = 0
    elif multiproc_data_loading == 0:
        workers = multiprocessing.cpu_count()
    else:
        workers = multiproc_data_loading

    loader_extras = {"num_workers": workers}
    if workers > 0 and sys.platform == "darwin":
        # macOS requires the "spawn" start method for worker processes
        loader_extras["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    for name, subset in dataset.items():
        # Keys prefixed with "_" mark splits excluded from evaluation
        if name.startswith("_"):
            logger.info(f"Skipping dataset '{name}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{name}' set...")

        data_loader = DataLoader(
            dataset=subset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            **loader_extras,
        )
        run(model, data_loader, name, device, output_folder, overlayed)