Coverage for bob/ip/common/script/predict.py: 88%

40 statements  

coverage.py v7.0.5, created at 2023-01-17 15:03 +0000

#!/usr/bin/env python
# coding=utf-8

import logging
import multiprocessing
import os
import sys

import torch

from torch.utils.data import DataLoader

from ..utils.checkpointer import Checkpointer
from .common import download_to_tempfile, setup_pytorch_device

logger = logging.getLogger(__name__)


def base_predict(
    output_folder,
    model,
    dataset,
    batch_size,
    device,
    weight,
    overlayed,
    parallel,
    detection,
    **kwargs,
):
    """Create base predict function for segmentation / detection tasks."""
    device = setup_pytorch_device(device)

    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    if weight.startswith("http"):
        logger.info(f"Temporarily downloading '{weight}'...")
        f = download_to_tempfile(weight, progress=True)
        weight_fullpath = os.path.abspath(f.name)
    else:
        weight_fullpath = os.path.abspath(weight)

    checkpointer = Checkpointer(model)
    checkpointer.load(weight_fullpath)

    # clean-up the overlayed path
    if overlayed is not None:
        overlayed = overlayed.strip()

    for k, v in dataset.items():

        if k.startswith("_"):
            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
            continue

        logger.info(f"Running inference on '{k}' set...")

        # PyTorch dataloader
        multiproc_kwargs = dict()
        if parallel < 0:
            multiproc_kwargs["num_workers"] = 0
        else:
            multiproc_kwargs["num_workers"] = (
                parallel or multiprocessing.cpu_count()
            )

        if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
            multiproc_kwargs[
                "multiprocessing_context"
            ] = multiprocessing.get_context("spawn")

        if detection:
            from ...detect.engine.predictor import run

            def _collate_fn(batch):
                return tuple(zip(*batch))

            data_loader = DataLoader(
                dataset=v,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=torch.cuda.is_available(),
                collate_fn=_collate_fn,
                **multiproc_kwargs,
            )
        else:
            from ...binseg.engine.predictor import run

            data_loader = DataLoader(
                dataset=v,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=torch.cuda.is_available(),
                **multiproc_kwargs,
            )

        run(model, data_loader, k, device, output_folder, overlayed)
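
For context, a minimal sketch of how base_predict might be invoked directly. The model, dataset, and checkpoint path below are illustrative stand-ins only (in the package these values are normally wired up by its command-line front-end and dataset/model configurations); a real run needs a dataset whose samples match what the selected predictor expects.

# Hypothetical usage sketch -- model, dataset and checkpoint path are placeholders.
import torch
from bob.ip.common.script.predict import base_predict

model = torch.nn.Conv2d(3, 1, kernel_size=1)  # stand-in torch.nn.Module
test_set = torch.utils.data.TensorDataset(    # stand-in dataset; real datasets
    torch.rand(2, 3, 64, 64)                  # yield samples the predictor understands
)

base_predict(
    output_folder="results",
    model=model,
    dataset={"test": test_set},              # dict of split name -> dataset
    batch_size=1,
    device="cpu",                            # or e.g. "cuda:0"
    weight="checkpoints/model_final.pth",    # local path or http(s) URL
    overlayed=None,                          # or a folder for overlay images
    parallel=-1,                             # negative => num_workers=0
    detection=False,                         # False selects the binseg predictor
)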