Coverage for src/deepdraw/configs/datasets/shenzhen/__init__.py: 96%

50 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later



def _maker(protocol, n):
    """Build the non-augmented Shenzhen dataset dictionary for ``protocol``.

    The raw subsets for ``protocol`` are handed to the package-level
    ``make_dataset`` together with a two-step transform pipeline:
    ``ShrinkIntoSquare()`` followed by ``Resize((n, n))``.

    Parameters
    ----------
    protocol
        Protocol identifier, forwarded unchanged to ``raw.subsets``.
    n : int
        Side length handed to ``Resize`` as ``(n, n)``.

    Returns
    -------
    Whatever ``make_dataset`` returns for these subsets and transforms.
    """
    # Imports are deferred so that merely importing this config package does
    # not pull in the dataset/transform modules.
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import Resize, ShrinkIntoSquare
    from .. import make_dataset as mk

    protocol_subsets = raw.subsets(protocol)
    pipeline = [ShrinkIntoSquare(), Resize((n, n))]
    return mk(protocol_subsets, pipeline)

12 

13 

def _maker_augmented(protocol, n):
    """Build an augmented Shenzhen dataset dictionary for ``protocol``.

    Every subset returned by ``raw.subsets(protocol)`` is wrapped with the
    evaluation pipeline ``ShrinkIntoSquare() -> Resize((n, n))``.  If the
    protocol has a ``"train"`` subset, it is additionally exposed as
    ``"__train__"``.  That entry uses a single ``Compose`` of the same
    geometric steps followed by ``RandomRotation(degrees=15, p=0.5)``,
    ``RandomHorizontalFlip(p=0.5)``, ``ColorJitter(p=0.5)`` and
    ``GaussianBlur(p=0.5)``.

    Parameters
    ----------
    protocol
        Protocol identifier, forwarded unchanged to ``raw.subsets``.
    n : int
        Side length handed to ``Resize`` as ``(n, n)``.

    Returns
    -------
    dict
        Maps each subset name of the protocol to its (unaugmented) subset.
        When a ``"train"`` subset exists, the dict also contains:

        * ``"__train__"``: the augmented training subset;
        * ``"__valid__"``: the ``"validation"`` subset if present, otherwise
          the unaugmented ``"train"`` subset.
    """
    # Imports are deferred so that merely importing this config package does
    # not pull in the dataset/transform modules.
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from ....data.transforms import ShrinkIntoSquare as _shrinkintosq
    from .. import make_subset

    def _geometry():
        # Return *fresh* transform instances on each call, so the evaluation
        # and training pipelines never share transform objects.
        return [_shrinkintosq(), _resize((n, n))]

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        retval = {}

        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if "__train__" in retval and "__valid__" not in retval:
            # No dedicated validation subset: validate on the *unaugmented*
            # training subset.  Using "__train__" here would push validation
            # samples through the random (p=0.5) augmentations, making the
            # validation metric change from pass to pass.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=_geometry(),
        train_transforms=[
            _compose(
                _geometry()
                + [
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )

60 

61 

def _maker_augmented_gt_box(protocol, n, extra_area=0.2):
    """Build an augmented, ground-truth-cropped Shenzhen dataset dictionary.

    Every subset returned by ``raw.subsets(protocol)`` is wrapped with the
    evaluation pipeline ``ShrinkIntoSquare() ->
    GroundTruthCrop(extra_area=extra_area) -> Resize((n, n))``.  If the
    protocol has a ``"train"`` subset, it is additionally exposed as
    ``"__train__"``.  That entry uses a single ``Compose`` of the same three
    steps followed by ``RandomRotation(degrees=15, p=0.5)``,
    ``RandomHorizontalFlip(p=0.5)``, ``ColorJitter(p=0.5)`` and
    ``GaussianBlur(p=0.5)``.

    Parameters
    ----------
    protocol
        Protocol identifier, forwarded unchanged to ``raw.subsets``.
    n : int
        Side length handed to ``Resize`` as ``(n, n)``.
    extra_area : float, optional
        Forwarded to ``GroundTruthCrop`` in both pipelines.  Defaults to
        ``0.2``, the value previously hard-coded here.

    Returns
    -------
    dict
        Maps each subset name of the protocol to its (unaugmented) subset.
        When a ``"train"`` subset exists, the dict also contains:

        * ``"__train__"``: the augmented training subset;
        * ``"__valid__"``: the ``"validation"`` subset if present, otherwise
          the unaugmented ``"train"`` subset.
    """
    # Imports are deferred so that merely importing this config package does
    # not pull in the dataset/transform modules.
    from ....data.shenzhen import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import GroundTruthCrop as _gtcrop
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from ....data.transforms import ShrinkIntoSquare as _shrinkintosq
    from .. import make_subset

    def _geometry():
        # Return *fresh* transform instances on each call, so the evaluation
        # and training pipelines never share transform objects.
        return [
            _shrinkintosq(),
            _gtcrop(extra_area=extra_area),
            _resize((n, n)),
        ]

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        retval = {}

        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if "__train__" in retval and "__valid__" not in retval:
            # No dedicated validation subset: validate on the *unaugmented*
            # training subset.  Using "__train__" here would push validation
            # samples through the random (p=0.5) augmentations, making the
            # validation metric change from pass to pass.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=_geometry(),
        train_transforms=[
            _compose(
                _geometry()
                + [
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )