Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/data/utils.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

50 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4 

5"""Common utilities""" 

6 

7import contextlib 

8 

9import torch 

10import torch.utils.data 

11import PIL 

12import numpy as np 

13from torchvision.transforms import Compose, ToTensor 

14 

15 

class SampleListDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around Sample lists

    A list of transforms can be passed that will be applied to the sample
    data (``data["data"]``) when it is a :py:class:`PIL.Image.Image`.
    Sample data stored as a plain :py:class:`list` (e.g. a feature vector)
    is converted to a :py:class:`torch.FloatTensor` instead, and the
    transforms are not applied to it.

    It supports indexing such that ``dataset[i]`` can be used to get the
    i-th sample. Slicing (``dataset[a:b]``) returns a list of samples.

    Parameters
    ----------

    samples : list
        A list of :py:class:`bob.med.tb.data.sample.Sample` objects (exposing
        ``.key`` and ``.data``), or plain :py:class:`dict` objects holding a
        ``"data"`` entry and, optionally, a ``"label"`` entry.

    transforms : :py:class:`list`, Optional
        A list of transformations to be applied to image data. Notice a
        last transform (:py:class:`torchvision.transforms.ToTensor`) is
        always appended unless the list already contains one - you do not
        need to add it. Defaults to no extra transforms.

    """

    def __init__(self, samples, transforms=None):

        self._samples = samples
        # ``None`` instead of a mutable ``[]`` default argument
        self.transforms = [] if transforms is None else transforms

    @property
    def transforms(self):
        """User-visible transforms: the pipeline minus its trailing ToTensor"""
        steps = list(self._transforms.transforms)
        # Only strip the final step when it really is the ``ToTensor`` that
        # closes the pipeline. A user-provided ``ToTensor`` placed earlier in
        # the list must not cause the last real transform to be dropped.
        if steps and isinstance(steps[-1], ToTensor):
            steps.pop()
        return steps

    @transforms.setter
    def transforms(self, value):
        # Copy: accepts tuples/iterables and does not alias the caller's list
        steps = list(value)
        if not any(isinstance(t, ToTensor) for t in steps):
            steps.append(ToTensor())
        self._transforms = Compose(steps)

    def copy(self, transforms=None):
        """Returns a shallow copy of itself, optionally resetting transforms

        The copy shares the same underlying sample list (and sample objects)
        with this dataset. Only the transform pipeline is rebuilt.

        Parameters
        ----------

        transforms : :py:class:`list`, Optional
            An optional list of transforms to set in the copy. If not
            specified (``None``), use ``self.transforms``. An explicit empty
            list leaves the copy with only the implicit ``ToTensor``.

        Returns
        -------

        dataset : SampleListDataset
            The new dataset
        """

        if transforms is None:
            transforms = self.transforms
        return SampleListDataset(self._samples, transforms)

    def random_permute(self, feature):
        """Randomly permute feature values from all samples

        Useful for permutation feature importance computation. Only samples
        whose data (``data["data"]``) is a :py:class:`list` of features take
        part in the permutation. Other samples (e.g. images) are left
        untouched. The feature lists are modified in place, and permuted
        values are written back as :py:class:`numpy.float64`.

        Parameters
        ----------

        feature : int
            The position of the feature
        """
        # Fetch each feature list once, so values are written back to the
        # very objects they were read from.
        # NOTE(review): if ``Sample.data`` re-loads data on every access,
        # these in-place changes would not persist - confirm it caches.
        feature_lists = []
        for sample in self._samples:
            # plain dicts are their own data; Sample objects load it via .data
            data = sample if isinstance(sample, dict) else sample.data
            if isinstance(data["data"], list):
                feature_lists.append(data["data"])

        feature_values = np.array(
            [features[feature] for features in feature_lists], dtype=float
        )
        np.random.shuffle(feature_values)

        for features, value in zip(feature_lists, feature_values):
            features[feature] = value

    def __len__(self):
        """

        Returns
        -------

        size : int
            size of the dataset

        """
        return len(self._samples)

    def __getitem__(self, key):
        """

        Parameters
        ----------

        key : int, slice

        Returns
        -------

        sample : list
            The sample data: ``[key, data, label]``, or ``[key, data]`` when
            the sample has no label. ``key`` is the sample's own key for
            Sample objects and the index used for plain dict samples. For a
            slice, a list of such entries is returned.

        """

        if isinstance(key, slice):
            return [self[k] for k in range(*key.indices(len(self)))]

        # we try it as an int
        item = data = self._samples[key]
        if not isinstance(data, dict):
            key = item.key
            data = item.data  # triggers data loading

        retval = data["data"]

        # NOTE(review): ``import PIL`` alone does not load the ``PIL.Image``
        # submodule; this relies on it being imported elsewhere (e.g. by
        # torchvision) - an explicit ``import PIL.Image`` would be safer.
        if self._transforms and isinstance(retval, PIL.Image.Image):
            retval = self._transforms(retval)
        elif isinstance(retval, list):
            retval = torch.FloatTensor(retval)

        if "label" in data:
            label = data["label"]
            if isinstance(label, list):
                label = torch.FloatTensor(label)
            return [key, retval, label]

        # ``key`` rather than ``item.key``: plain dict samples have no ``.key``
        return [key, retval]