1#!/usr/bin/env python
2# coding=utf-8
3
import glob
import logging
import os

import numpy
import skimage.color
import skimage.io
import skimage.measure
import skimage.morphology
import skimage.util

from bob.extension import rc
15
16logger = logging.getLogger(__name__)
17
18
def base_mkmask(dataset, globs, threshold, output_folder, **kwargs):
    """Create rough binary masks for every image matching a set of globs.

    Each matching image is thresholded, cleaned up with a morphological
    operation and saved as a PNG file under ``output_folder``, keeping the
    path of the image relative to the dataset base directory.


    Arguments
    =========

    dataset : str
        Name used to look up the ``bob.ip.binseg.<dataset>.datadir`` entry
        of the ``rc`` configuration.  If that entry is missing or empty,
        ``dataset`` itself is used as the base directory.

    globs : list
        Glob patterns, relative to the base directory, selecting the images
        to process

    threshold : int
        Threshold to apply on the images (anything ``int()`` accepts)

    output_folder : str
        Directory where the resulting masks are written

    **kwargs
        Accepted for call compatibility; not used by this function.

    """

    def threshold_and_closing(input_path, t, width=5):
        """Creates a "rough" mask from the input image, returns binary equivalent

        The mask is created by running a simple threshold operation followed
        by a morphological operation with a disk-shaped footprint.

        NOTE(review): despite the function name, the operation applied is
        ``skimage.morphology.binary_opening`` (not a closing) -- confirm
        which one is intended before changing it, as it alters the masks.


        Arguments
        =========

        input_path : str
            The path leading to the image from where the mask needs to be
            extracted

        t : int
            Threshold to apply on the image, after it has been loaded in
            gray-scale and converted to unsigned 8-bit values

        width : int
            Value passed to ``skimage.morphology.disk`` to build the
            footprint of the morphological operation


        Returns
        =======

        mask : numpy.ndarray
            A 2D array, with the same size as the input image, where ``True``
            pixels correspond to the "valid" regions of the mask.

        """

        img = skimage.util.img_as_ubyte(
            skimage.io.imread(input_path, as_gray=True)
        )
        mask = img > t
        return skimage.morphology.binary_opening(
            mask, skimage.morphology.disk(width)
        )

    def count_blobs(mask):
        """Counts "white" blobs in a binary mask, outputs counts


        Arguments
        =========

        mask : numpy.ndarray
            A 2D array, with the same size as the input image, where ``255``
            pixels correspond to the "valid" regions of the mask.  ``0`` means
            background.


        Returns
        =======

        count : int
            The number of labels reported by ``skimage.measure.label`` for
            the provided mask.

        """
        # label() returns (labelled_image, number_of_labels) with return_num
        return skimage.measure.label(mask, return_num=True)[1]

    def process_glob(base_path, use_glob, output_path, threshold):
        """Process all images matching a set of globs

        Arguments
        =========

        base_path : str
            The base directory where to look for files matching a certain
            name pattern

        use_glob : list
            A list of globs to use for matching filenames inside
            ``base_path``.  Patterns are expanded by ``glob.glob`` without
            its ``recursive`` option.

        output_path : str
            Where to place the results of processing

        threshold : int
            Threshold forwarded to ``threshold_and_closing``

        """

        files = []
        for pattern in use_glob:
            files += glob.glob(os.path.join(base_path, pattern))

        if not files:
            logger.warning(
                "No files found in '%s' matching %s", base_path, use_glob
            )

        total = len(files)
        for i, path in enumerate(files):
            basename = os.path.relpath(path, base_path)
            basename_without_extension = os.path.splitext(basename)[0]
            logger.info(
                "Processing %s (%d/%d)...",
                basename_without_extension,
                i + 1,
                total,
            )
            dest = os.path.join(
                output_path, basename_without_extension + ".png"
            )
            destdir = os.path.dirname(dest)
            # exist_ok avoids the check-then-create race; an empty dirname
            # means "current directory" and needs no creation
            if destdir:
                os.makedirs(destdir, exist_ok=True)
            mask = threshold_and_closing(path, threshold)
            immask = mask.astype(numpy.uint8) * 255
            nblobs = count_blobs(immask)
            if nblobs != 1:
                logger.warning(
                    " -> WARNING: found %s blobs in the saved mask "
                    "(should be one)",
                    nblobs,
                )
            skimage.io.imsave(dest, immask)

    # Prefer the configured data directory; fall back to using ``dataset``
    # directly as a path when nothing is configured for it
    base_path = rc.get("bob.ip.binseg." + dataset + ".datadir") or dataset

    process_glob(
        base_path=base_path,
        use_glob=list(globs),
        output_path=output_folder,
        threshold=int(threshold),
    )