Coverage for src/bob/pad/face/transformer/histogram.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 23:20 +0100

1import numpy as np 

2 

3from sklearn.base import BaseEstimator, TransformerMixin 

4from sklearn.utils import check_array 

5 

6 

7def _get_cropping_size(image_size, patch_size): 

8 # How many pixels missing to cover the whole image 

9 r = image_size % patch_size 

10 # Spit gap into two evenly 

11 before = r // 2 

12 after = image_size - (r - before) 

13 return before, after 

14 

15 

16def _extract_patches(image, patch_size): 

17 # https://stackoverflow.com/a/16858283 

18 h, w = image.shape 

19 nrows, ncols = patch_size 

20 if h % nrows != 0 or w % ncols != 0: 

21 w_left, w_right = _get_cropping_size(w, ncols) 

22 h_top, h_bottom = _get_cropping_size(h, nrows) 

23 # Perform center crop 

24 image = image[h_top:h_bottom, w_left:w_right] 

25 return ( 

26 image.reshape(h // nrows, nrows, -1, ncols) 

27 .swapaxes(1, 2) 

28 .reshape(-1, nrows, ncols) 

29 ) 

30 

31 

class SpatialHistogram(TransformerMixin, BaseEstimator):
    """
    Split images into a grid of patches, compute histogram on each one of them
    and concatenate them to obtain the final descriptor.

    Each image is center-cropped so that it tiles into exactly
    ``grid_size`` patches, so every descriptor has length
    ``grid_y * grid_x * nbins`` regardless of the image size.
    """

    def __init__(self, grid_size=(4, 4), range=(0, 256), nbins=256):
        """
        Constructor

        :param grid_size: Tuple `(grid_y, grid_x)` indicating the number of
            patches to extract in each direction
        :param range: Tuple `(h_min, h_max)` indicating the histogram range,
            cf numpy.histogram
        :param nbins: Number of bins in the histogram, cf numpy.histogram
        """
        # NOTE: ``range`` shadows the builtin but must keep its name:
        # scikit-learn derives get_params/set_params from the constructor's
        # parameter names.
        self.grid_size = grid_size
        self.range = range
        self.nbins = nbins

    def fit(self, X, y=None):
        """No-op: this transformer learns nothing from the data.

        ``y`` defaults to ``None`` so that ``fit(X)`` and
        ``fit_transform(X)`` work without targets.

        :return: ``self``
        """
        return self

    def transform(self, X):
        """Compute the spatial histogram descriptor of every image in ``X``.

        :param X: Array-like of 2-D images, i.e. shape ``(N, H, W)``; it is
            validated with ``sklearn.utils.check_array(allow_nd=True)``.
        :return: Array of shape ``(N, grid_y * grid_x * nbins)``.
        """
        X = check_array(X, allow_nd=True)
        return np.asarray([self._spatial_histogram(sample) for sample in X])

    def _spatial_histogram(self, image):
        """Compute the concatenated per-patch histograms of one 2-D image.

        :param image: 2-D array of shape ``(H, W)``. Each axis must be at
            least as long as the matching ``grid_size`` entry; otherwise the
            patch size is 0 and patch extraction fails with a division by zero.
        :return: 1-D array of length ``grid_y * grid_x * nbins``.
        """
        patch_size = [s // g for s, g in zip(image.shape, self.grid_size)]
        # Center-crop each axis to exactly ``grid * patch`` pixels so the image
        # tiles into exactly ``grid_size`` patches. Without this, an axis whose
        # remainder ``size % grid`` is at least one patch long (e.g. size 10,
        # grid 4 -> patch 2) would be split into more than ``grid`` patches,
        # making the descriptor length depend on the image size. The centering
        # rule (floor half of the excess removed at the start) matches the one
        # used for the crop inside patch extraction.
        crop = []
        for size, grid, patch in zip(image.shape, self.grid_size, patch_size):
            kept = grid * patch
            start = (size - kept) // 2
            crop.append(slice(start, start + kept))
        image = image[tuple(crop)]
        patches = _extract_patches(image=image, patch_size=patch_size)
        hist = [
            np.histogram(
                patch, bins=self.nbins, range=self.range, density=True
            )[0]
            for patch in patches
        ]
        return np.asarray(hist).reshape(-1)

    def _more_tags(self):
        """scikit-learn estimator tags: no state is learned and ``transform``
        may be called without a prior ``fit``."""
        return {"stateless": True, "requires_fit": False}