1#!/usr/bin/env python
2# coding=utf-8
3
4import logging
5import multiprocessing
6import os
7import sys
8
9import torch
10
11from torch.utils.data import DataLoader
12
13from ..utils.checkpointer import Checkpointer
14from .common import download_to_tempfile, setup_pytorch_device
15
16logger = logging.getLogger(__name__)
17
18
def base_predict(
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    parallel,
    detection,
    **kwargs,
):
    """Run inference with a model over one or more dataset splits.

    Loads the checkpoint at ``weight`` into ``model``, then runs the
    task-appropriate prediction engine over every (non-skipped) split in
    ``dataset``.

    Parameters
    ----------
    output_folder : str
        Folder where predictions are written (forwarded to the engine's
        ``run`` function).
    model : torch.nn.Module
        Network to run inference with; weights are loaded from ``weight``.
    dataset : dict or torch.utils.data.Dataset
        Either a single dataset (treated as the ``test`` split) or a dict
        mapping split names to datasets.  Splits whose name starts with
        ``_`` are skipped.
    batch_size : int
        Number of samples per inference batch.
    device : str
        Device specifier understood by ``setup_pytorch_device`` (e.g.
        ``"cpu"`` or ``"cuda:0"``).
    weight : str
        Filesystem path or HTTP(S) URL of the checkpoint to load.
    overlayed : str or None
        Path for overlayed output images, or ``None`` to disable overlays.
    parallel : int
        ``< 0``: load data in the current process (no workers); ``0``:
        one worker per CPU core; ``> 0``: exactly that many workers.
    detection : bool
        If ``True``, use the object-detection engine (with a collate
        function that keeps variable-size targets as tuples); otherwise
        use the binary-segmentation engine.
    """
    device = setup_pytorch_device(device)

    # normalize a bare dataset into a single-entry split dict
    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        # NOTE: keep ``f`` referenced so the temporary file is not
        # garbage-collected (and possibly removed) before the checkpoint
        # is loaded below
        f = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(f.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath)

    # clean-up the overlayed path
    if overlayed is not None:
        overlayed = overlayed.strip()

    # Data-loading configuration and engine selection do not depend on
    # the split being processed, so set them up once, before the loop.
    multiproc_kwargs = dict()
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        # parallel == 0 means "one worker per available CPU core"
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        # "fork" is unreliable on macOS; force "spawn" for data workers
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    if detection:
        from ...detect.engine.predictor import run

        def _collate_fn(batch):
            # detection targets have variable sizes, so batches are
            # tuples of per-sample entries rather than stacked tensors
            return tuple(zip(*batch))

        loader_kwargs = dict(collate_fn=_collate_fn)
    else:
        from ...binseg.engine.predictor import run

        loader_kwargs = dict()

    for k, v in dataset.items():

        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        data_loader = DataLoader(
            dataset=v,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            **loader_kwargs,
            **multiproc_kwargs,
        )

        run(model, data_loader, k, device, output_folder, overlayed)