Coverage for src/deepdraw/configs/datasets/cxr8/__init__.py: 88%

48 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-11-30 15:00 +0100

1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> 

2# 

3# SPDX-License-Identifier: GPL-3.0-or-later 

4 

5 

def _maker(protocol, n):
    """Build the non-augmented CXR8 dataset dictionary for ``protocol``.

    The raw subsets of the requested protocol are handed to the package-level
    ``make_dataset`` helper together with a single ``Resize((n, n))``
    transform.

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets()`` (presumably one of the
        CXR8 protocols -- confirm against ``deepdraw.data.cxr8``).
    n : int
        Side length used for the ``Resize((n, n))`` transform.

    Returns
    -------
    Whatever ``make_dataset`` returns for these subsets and transforms.
    """
    from ....data.cxr8 import dataset as raw
    from ....data.transforms import Resize
    from .. import make_dataset as mk

    protocol_subsets = raw.subsets(protocol)
    resize_only = [Resize((n, n))]
    return mk(protocol_subsets, resize_only)

12 

13 

def _maker_augmented(protocol, n):
    """Build the CXR8 dataset dictionary with a data-augmented training set.

    Every raw subset is wrapped with a resize-only transform. Additionally,
    the ``"train"`` subset is wrapped a second time, under the ``"__train__"``
    key, with an augmentation pipeline (resize, random rotation, horizontal
    flip, color jitter and Gaussian blur).

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets()`` (presumably one of the
        CXR8 protocols -- confirm against ``deepdraw.data.cxr8``).
    n : int
        Side length used for every ``Resize((n, n))`` transform.

    Returns
    -------
    dict
        One entry per raw subset key (resize-only), plus:

        * ``"__train__"``: the ``"train"`` subset with the augmentation
          pipeline (only when a ``"train"`` subset exists);
        * ``"__valid__"``: the ``"validation"`` subset when it exists,
          otherwise the *unaugmented* ``"train"`` subset.
    """
    from ....data.cxr8 import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from .. import make_subset

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        """Wrap ``subsets`` and add the ``__train__``/``__valid__`` aliases."""
        retval = {}

        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                # Separate augmented view of the training samples; the plain
                # "train" entry above keeps the resize-only transform.
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if ("train" in retval) and ("__valid__" not in retval):
            # No explicit validation subset: monitor on the *unaugmented*
            # training samples. Using "__train__" here would validate on
            # randomly augmented images, making the metric noisy and
            # non-reproducible between evaluations.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=[_resize((n, n))],
        train_transforms=[
            _compose(
                [
                    _resize((n, n)),
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )

58 

59 

def _maker_augmented_gt_box(protocol, n):
    """Build the CXR8 dataset dictionary with ground-truth cropping and augmentation.

    Every raw subset is wrapped with ``GroundTruthCrop(extra_area=0.2)``
    followed by ``Resize((n, n))``. Additionally, the ``"train"`` subset is
    wrapped a second time, under the ``"__train__"`` key, with the same crop
    and resize followed by random rotation, horizontal flip, color jitter and
    Gaussian blur.

    Parameters
    ----------
    protocol : str
        Protocol name forwarded to ``raw.subsets()`` (presumably one of the
        CXR8 protocols -- confirm against ``deepdraw.data.cxr8``).
    n : int
        Side length used for every ``Resize((n, n))`` transform.

    Returns
    -------
    dict
        One entry per raw subset key (crop + resize only), plus:

        * ``"__train__"``: the ``"train"`` subset with the augmentation
          pipeline (only when a ``"train"`` subset exists);
        * ``"__valid__"``: the ``"validation"`` subset when it exists,
          otherwise the *unaugmented* ``"train"`` subset.
    """
    from ....data.cxr8 import dataset as raw
    from ....data.transforms import ColorJitter as _jitter
    from ....data.transforms import Compose as _compose
    from ....data.transforms import GaussianBlur as _blur
    from ....data.transforms import GroundTruthCrop as _gtcrop
    from ....data.transforms import RandomHorizontalFlip as _hflip
    from ....data.transforms import RandomRotation as _rotation
    from ....data.transforms import Resize as _resize
    from .. import make_subset

    def mk_aug_subset(subsets, train_transforms, all_transforms):
        """Wrap ``subsets`` and add the ``__train__``/``__valid__`` aliases."""
        retval = {}

        for key, subset in subsets.items():
            retval[key] = make_subset(subset, transforms=all_transforms)
            if key == "train":
                # Separate augmented view of the training samples; the plain
                # "train" entry above keeps the crop + resize transforms only.
                retval["__train__"] = make_subset(
                    subset,
                    transforms=train_transforms,
                )
            elif key == "validation":
                retval["__valid__"] = retval[key]

        if ("train" in retval) and ("__valid__" not in retval):
            # No explicit validation subset: monitor on the *unaugmented*
            # training samples. Using "__train__" here would validate on
            # randomly augmented images, making the metric noisy and
            # non-reproducible between evaluations.
            retval["__valid__"] = retval["train"]

        return retval

    return mk_aug_subset(
        subsets=raw.subsets(protocol),
        all_transforms=[_gtcrop(extra_area=0.2), _resize((n, n))],
        train_transforms=[
            _compose(
                [
                    _gtcrop(extra_area=0.2),
                    _resize((n, n)),
                    _rotation(degrees=15, p=0.5),
                    _hflip(p=0.5),
                    _jitter(p=0.5),
                    _blur(p=0.5),
                ]
            )
        ],
    )