Coverage for src/deepdraw/configs/datasets/shenzhen/__init__.py: 96%
50 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-11-30 15:00 +0100
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
def _maker(protocol, n):
    """Return the non-augmented Shenzhen dataset for ``protocol``.

    The raw subsets of the protocol are handed to the package-level
    ``make_dataset`` helper together with the preprocessing list
    ``[ShrinkIntoSquare(), Resize((n, n))]``.

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets`` (presumably a string
        protocol identifier -- confirm against callers).
    n : int
        Target side length given to ``Resize((n, n))``.

    Returns
    -------
    object
        Whatever ``make_dataset`` returns for these subsets and transforms.
    """
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import Resize
    from ....data.transforms import ShrinkIntoSquare
    from .. import make_dataset as mk

    # Load the protocol's raw subsets first, then build the preprocessing
    # pipeline (same evaluation order as a single-expression call).
    raw_subsets = raw.subsets(protocol)
    preprocessing = [ShrinkIntoSquare(), Resize((n, n))]
    return mk(raw_subsets, preprocessing)
def _maker_augmented(protocol, n):
    """Build the Shenzhen dataset for ``protocol`` with training augmentation.

    Every subset returned by ``raw.subsets(protocol)`` is wrapped by
    ``make_subset`` with the evaluation pipeline ``ShrinkIntoSquare()`` then
    ``Resize((n, n))``.  The ``"train"`` subset is additionally exposed as
    ``"__train__"`` with a single ``Compose`` of that same pipeline followed
    by ``RandomRotation(degrees=15, p=0.5)``, ``RandomHorizontalFlip(p=0.5)``,
    ``ColorJitter(p=0.5)`` and ``GaussianBlur(p=0.5)``.

    ``"__valid__"`` aliases the ``"validation"`` subset when the protocol
    has one; otherwise it falls back to the *un-augmented* ``"train"``
    subset, so validation never sees randomly augmented samples.

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets`` (presumably a string
        protocol identifier -- confirm against callers).
    n : int
        Target side length given to ``Resize((n, n))``.

    Returns
    -------
    dict
        The wrapped subsets keyed by their original names, plus the special
        ``"__train__"`` / ``"__valid__"`` entries when applicable.
    """
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from ....data.transforms import ShrinkIntoSquare as _shrinkintosq
    from .. import make_subset

    def _preprocessing():
        # Deterministic transforms used by both pipelines; fresh instances
        # on every call so evaluation and training share no objects.
        return [_shrinkintosq(), _resize((n, n))]

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        # Wrap every subset with ``all_transforms``; "train" also gets an
        # augmented "__train__" copy and "validation" is aliased as
        # "__valid__".
        retval = {}
        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if (
            "__train__" in retval
            and "train" in retval
            and "__valid__" not in retval
        ):
            # No validation subset: validate on the un-augmented training
            # data. "__train__" applies random rotation/flip/jitter/blur,
            # which would make validation metrics vary between runs.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=_preprocessing(),
        train_transforms=[
            _compose(
                _preprocessing()
                + [
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )
def _maker_augmented_gt_box(protocol, n):
    """Build the Shenzhen dataset cropped around the ground truth, with augmentation.

    Every subset returned by ``raw.subsets(protocol)`` is wrapped by
    ``make_subset`` with the evaluation pipeline ``ShrinkIntoSquare()``,
    ``GroundTruthCrop(extra_area=0.2)`` and ``Resize((n, n))``, in that order.
    The ``"train"`` subset is additionally exposed as ``"__train__"`` with a
    single ``Compose`` of that same pipeline followed by
    ``RandomRotation(degrees=15, p=0.5)``, ``RandomHorizontalFlip(p=0.5)``,
    ``ColorJitter(p=0.5)`` and ``GaussianBlur(p=0.5)``.

    ``"__valid__"`` aliases the ``"validation"`` subset when the protocol
    has one; otherwise it falls back to the *un-augmented* ``"train"``
    subset, so validation never sees randomly augmented samples.

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets`` (presumably a string
        protocol identifier -- confirm against callers).
    n : int
        Target side length given to ``Resize((n, n))``.

    Returns
    -------
    dict
        The wrapped subsets keyed by their original names, plus the special
        ``"__train__"`` / ``"__valid__"`` entries when applicable.
    """
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import GroundTruthCrop as _gtcrop
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from ....data.transforms import ShrinkIntoSquare as _shrinkintosq
    from .. import make_subset

    def _preprocessing():
        # Deterministic transforms used by both pipelines; fresh instances
        # on every call so evaluation and training share no objects.
        return [
            _shrinkintosq(),
            _gtcrop(extra_area=0.2),
            _resize((n, n)),
        ]

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        # Wrap every subset with ``all_transforms``; "train" also gets an
        # augmented "__train__" copy and "validation" is aliased as
        # "__valid__".
        retval = {}
        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if (
            "__train__" in retval
            and "train" in retval
            and "__valid__" not in retval
        ):
            # No validation subset: validate on the un-augmented training
            # data. "__train__" applies random rotation/flip/jitter/blur,
            # which would make validation metrics vary between runs.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=_preprocessing(),
        train_transforms=[
            _compose(
                _preprocessing()
                + [
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )