Coverage for src/bob/bio/base/database/csv_database.py: 92%

190 statements  

coverage.py v7.6.5, created at 2024-11-14 21:41 +0100

import csv
import functools
import logging
import os

from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, TextIO

import sklearn.pipeline
import sklearn.preprocessing

from sklearn.base import BaseEstimator, TransformerMixin

from bob.bio.base.pipelines.abstract_classes import Database
from bob.pipelines import (
    DelayedSample,
    FileListDatabase,
    Sample,
    SampleSet,
    check_parameters_for_validity,
)
from bob.pipelines.dataset import open_definition_file

from ..utils.annotations import read_annotation_file

logger = logging.getLogger(__name__)


def _sample_sets_to_samples(sample_sets):
    return functools.reduce(
        lambda x, y: x + y, (s.samples for s in sample_sets), []
    )


def _add_key(samples: list[Sample]) -> list[Sample]:
    """Adds a ``key`` attribute to all samples if ``key`` is not present.

    Will use ``path`` to create a unique ``key``.
    Note that this will not create unique keys if the same path appears in several
    samples. This is problematic, as keys are expected to be unique.
    """

    out = []
    for sample in samples:
        if isinstance(sample, SampleSet):
            out.append(SampleSet(samples=_add_key(sample), parent=sample))
            continue
        if not hasattr(sample, "key"):
            if hasattr(sample, "path"):
                sample.key = sample.path
            else:
                raise ValueError(
                    f"Sample has no 'key' and no 'path' to infer it. {sample=}"
                )
        out.append(sample)
    return out
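
# Example (illustrative only): a Sample with path="subject1_image1.png" and no
# ``key`` attribute would receive key="subject1_image1.png" after _add_key runs.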



def validate_bio_samples(samples):
    """Validates Samples or SampleSets for backwards compatibility reasons.

    This will add a ``key`` attribute (if not already present) to each sample, copied
    from the path.
    """

    for sample in samples:
        if isinstance(sample, SampleSet):
            validate_bio_samples(sample.samples)
            if not hasattr(sample, "template_id"):
                raise ValueError(
                    f"SampleSet must have a template_id attribute, got {sample}"
                )
            if not hasattr(sample, "subject_id"):
                raise ValueError(
                    f"SampleSet must have a subject_id attribute, got {sample}"
                )
            continue

        if not hasattr(sample, "key"):
            if hasattr(sample, "path"):
                sample.key = sample.path
            else:
                raise ValueError(
                    f"Sample must have a key or a path attribute, got {sample}"
                )

        if not hasattr(sample, "subject_id"):
            raise ValueError(
                f"Sample must have a subject_id attribute, got {sample}"
            )


class CSVDatabase(FileListDatabase, Database):
    """A CSV file database.

    The database protocol files should provide the following files:

    .. code-block:: text

        dataset_protocols_path/
        dataset_protocols_path/my_protocol/train/for_background_model.csv
        dataset_protocols_path/my_protocol/train/for_znorm.csv
        dataset_protocols_path/my_protocol/train/for_tnorm.csv
        dataset_protocols_path/my_protocol/dev/for_enrolling.csv
        dataset_protocols_path/my_protocol/dev/for_probing.csv
        dataset_protocols_path/my_protocol/dev/for_matching.csv
        dataset_protocols_path/my_protocol/eval/for_enrolling.csv
        dataset_protocols_path/my_protocol/eval/for_probing.csv
        dataset_protocols_path/my_protocol/eval/for_matching.csv
        ...

    The ``for_background_model`` file should contain the following columns::

        key,subject_id
        subject1_image1.png,1
        subject1_image2.png,1
        subject2_image1.png,2
        subject2_image2.png,2

    In all the CSV files, you can have a column called ``path`` which will be
    used as ``key`` if the ``key`` is not specified. For example::

        path,subject_id
        subject1_image1.png,1
        subject1_image2.png,1
        subject2_image1.png,2
        subject2_image2.png,2

    or::

        path,subject_id,key
        subject1_audio1.wav,1,subject1_audio1_channel1
        subject1_audio1.wav,1,subject1_audio1_channel2
        subject1_audio2.wav,1,subject1_audio2_channel1
        subject1_audio2.wav,1,subject1_audio2_channel2

    The ``key`` column will be used to checkpoint each sample into a unique file and
    must therefore be unique across the whole dataset.

    The ``for_enrolling.csv`` file should contain the following columns::

        key,subject_id,template_id
        subject3_image1.png,3,template_1
        subject3_image2.png,3,template_1
        subject3_image3.png,3,template_2
        subject3_image4.png,3,template_2
        subject4_image1.png,4,template_3
        subject4_image2.png,4,template_3
        subject4_image3.png,4,template_4
        subject4_image4.png,4,template_4

    The ``for_probing.csv`` file should contain the following columns::

        key,subject_id,template_id
        subject5_image1.png,5,template_5
        subject5_image2.png,5,template_5
        subject5_image3.png,5,template_6
        subject5_image4.png,5,template_6
        subject6_image1.png,6,template_7
        subject6_image2.png,6,template_7
        subject6_image3.png,6,template_8
        subject6_image4.png,6,template_8

    Subject identity (``subject_id``) is a unique identifier for one identity (one
    person).
    Template identity (``template_id``) is an identifier used to group samples when
    they need to be enrolled or scored together.
    :class:`~bob.bio.base.pipelines.BioAlgorithm` will process these templates.

    By default, each enroll ``template_id`` will be compared against each
    probe ``template_id`` to produce one score per pair. If you want to specify exact
    comparisons (sparse scoring), you can add a ``for_matching.csv`` file with the
    following columns::

        enroll_template_id,probe_template_id
        template_1,template_5
        template_2,template_6
        template_3,template_5
        template_3,template_7
        template_4,template_5
        template_4,template_8

    ``for_znorm.csv`` and ``for_tnorm.csv`` files are optional and are used for score
    normalization. See :class:`~bob.bio.base.pipelines.PipelineScoreNorm`.
    ``for_znorm.csv`` has the same format as ``for_probing.csv`` and
    ``for_tnorm.csv`` has the same format as ``for_enrolling.csv``.
    """


    def __init__(
        self,
        *,
        name: str,
        protocol: str,
        dataset_protocols_path: Optional[str] = None,
        transformer: Optional[sklearn.pipeline.Pipeline] = None,
        templates_metadata: Optional[list[str]] = None,
        annotation_type: Optional[str] = None,
        fixed_positions: Optional[dict[str, tuple[float, float]]] = None,
        memory_demanding=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        name
            The name of the database.
        protocol
            Name of the protocol folder in the CSV definition structure.
        dataset_protocols_path
            Path to the CSV files structure (see :ref:`bob.bio.base.database_interface`
            for more info).
        transformer
            An sklearn pipeline or equivalent transformer that handles some light
            preprocessing of the samples (this will always run locally).
        templates_metadata
            Metadata that originate from the samples and must be present in the
            templates (SampleSet), e.g. ``["gender", "age"]``. This should be metadata
            that is common to all the samples in a template.
        annotation_type
            A string describing the annotations passed to the annotation loading
            function.
        fixed_positions
            TODO Why is it here? What does it do exactly?
            --> move it when the FaceCrop annotator is implemented correctly.
        memory_demanding
            Flag that indicates that experiments using this should not run on low-mem
            workers.
        """
        if not hasattr(self, "name"):
            self.name = name
        transformer = sklearn.pipeline.make_pipeline(
            # FunctionTransformer comes from sklearn.preprocessing
            sklearn.preprocessing.FunctionTransformer(_add_key),
            transformer,
        )
        super().__init__(
            name=name,  # For FileListDatabase
            protocol=protocol,
            dataset_protocols_path=dataset_protocols_path,
            transformer=transformer,
            annotation_type=annotation_type,
            fixed_positions=fixed_positions,
            memory_demanding=memory_demanding,
            **kwargs,
        )
        if self.list_file("dev", "for_matching") is None:
            self.score_all_vs_all = True
        else:
            self.score_all_vs_all = False

        self.templates_metadata = []
        if templates_metadata is not None:
            self.templates_metadata = templates_metadata

    def list_file(self, group: str, name: str) -> TextIO:
        """Returns a definition file containing one sample per row.

        Overrides the ``bob.pipelines`` ``list_file`` because here the group is a
        directory.
        """

        try:
            list_file = open_definition_file(
                search_pattern=Path(group) / (name + ".csv"),
                database_name=self.name,
                protocol=self.protocol,
                database_filename=self.dataset_protocols_path.name,
                base_dir=self.dataset_protocols_path.parent,
                subdir=".",
            )
            return list_file
        except FileNotFoundError:
            return None
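
    # Example (illustrative): list_file("dev", "for_enrolling") looks for
    # ``<dataset_protocols_path>/<protocol>/dev/for_enrolling.csv`` and returns the
    # open file, or None if the protocol does not define that list.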


    def get_reader(self, group: str, name: str) -> Iterable:
        """Returns an :any:`Iterable` containing :class:`Sample` or :class:`SampleSet` objects."""
        key = (self.protocol, group, name)
        if key not in self.readers:
            list_file = self.list_file(group, name)
            self.readers[key] = None
            if list_file is not None:
                self.readers[key] = self.reader_cls(
                    list_file=list_file, transformer=self.transformer
                )

        reader = self.readers[key]
        return reader

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _background_model_samples(self, protocol):
        reader = self.get_reader("train", "for_background_model")
        if reader is None:
            return []
        samples = list(reader)
        validate_bio_samples(samples)
        return samples

    def background_model_samples(self):
        return self._background_model_samples(self.protocol)

    def _sample_sets(self, group, name):
        # we need protocol as input so we can cache the result
        reader = self.get_reader(group, name)
        if reader is None:
            return []
        # create SampleSets from samples given their unique enroll_template_id/probe_template_id
        samples_grouped_by_template_id = defaultdict(list)
        for sample in reader:
            samples_grouped_by_template_id[sample.template_id].append(sample)
        sample_sets = []
        for (
            template_id,
            samples_for_template_id,
        ) in samples_grouped_by_template_id.items():
            # since all samples inside one sampleset have the same subject_id,
            # we add that as well.
            samples = list(samples_for_template_id)
            subject_id = samples[0].subject_id
            metadata = {
                m: getattr(samples[0], m) for m in self.templates_metadata
            }
            sample_sets.append(
                SampleSet(
                    samples,
                    template_id=template_id,
                    subject_id=subject_id,
                    key=f"template_{template_id}",
                    **metadata,
                )
            )
        validate_bio_samples(sample_sets)
        return sample_sets

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _references(self, protocol, group):  # TODO: protocol
        return self._sample_sets(group, "for_enrolling")

    def references(self, group="dev"):
        return self._references(self.protocol, group)

    def _add_all_references(self, sample_sets, group):
        references = [s.template_id for s in self.references(group)]
        for sample_set in sample_sets:
            sample_set.references = references

    # cached methods should be based on protocol as well
    @functools.lru_cache(maxsize=None)
    def _probes(self, protocol, group):  # TODO: protocol
        sample_sets = self._sample_sets(group, "for_probing")

        # if there are no probes
        if not sample_sets:
            return sample_sets

        # populate .references for each sample set
        matching_file = self.list_file(group, "for_matching")
        if matching_file is None:
            self._add_all_references(sample_sets, group)

        # read the matching file
        else:
            # references is a dict where key is probe_template_id and value is a
            # list of enroll_template_ids
            references = defaultdict(list)
            reader = csv.DictReader(matching_file)
            for row in reader:
                references[row["probe_template_id"]].append(
                    row["enroll_template_id"]
                )
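            # e.g. with the for_matching.csv shown in the class docstring,
            # references would end up as (illustrative):
            #   {"template_5": ["template_1", "template_3", "template_4"],
            #    "template_6": ["template_2"],
            #    "template_7": ["template_3"],
            #    "template_8": ["template_4"]}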


            for sample_set in sample_sets:
                sample_set.references = references[sample_set.template_id]

        return sample_sets

    def probes(self, group="dev"):
        return self._probes(self.protocol, group)

    def samples(self, groups=None):
        """Get samples of a certain group.

        Parameters
        ----------
        groups : :obj:`str`, optional
            A str or list of str to be used for filtering samples, by default None.

        Returns
        -------
        list
            A list containing the samples loaded from the CSV files.
        """
        groups = check_parameters_for_validity(
            groups, "groups", self.groups(), self.groups()
        )
        all_samples = []

        if "train" in groups:
            all_samples.extend(self.background_model_samples())
            groups.remove("train")

        for grp in groups:
            all_samples.extend(_sample_sets_to_samples(self.references(grp)))
            all_samples.extend(_sample_sets_to_samples(self.probes(grp)))

        # Add znorm samples. Returning znorm samples for one group of dev or
        # eval is enough because they are duplicated.
        for grp in groups:
            all_samples.extend(_sample_sets_to_samples(self.zprobes(grp)))
            break

        # Add tnorm samples.
        all_samples.extend(_sample_sets_to_samples(self.treferences()))

        return all_samples

    def all_samples(self, groups=None):
        return self.samples(groups)

    @functools.lru_cache(maxsize=None)
    def _zprobes(self, protocol, group):
        sample_sets = self._sample_sets("train", "for_znorm")
        if not sample_sets:
            return sample_sets

        self._add_all_references(sample_sets, group)

        return sample_sets

    def zprobes(self, group="dev", proportion=1.0):
        sample_sets = self._zprobes(self.protocol, group)
        if not sample_sets:
            return sample_sets

        sample_sets = sample_sets[: int(len(sample_sets) * proportion)]
        return sample_sets
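
    # e.g. zprobes(group="dev", proportion=0.5) keeps roughly the first half of the
    # z-norm probe templates (illustrative; the exact count is rounded down).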


    @functools.lru_cache(maxsize=None)
    def _treferences(self, protocol):
        sample_sets = self._sample_sets("train", "for_tnorm")
        return sample_sets

    def treferences(self, proportion=1.0):
        sample_sets = self._treferences(self.protocol)
        if not sample_sets:
            return sample_sets

        sample_sets = sample_sets[: int(len(sample_sets) * proportion)]
        return sample_sets


class FileSampleLoader(BaseEstimator, TransformerMixin):
    """Loads file-based samples into :class:`~bob.pipelines.DelayedSample` objects.

    Given the :attr:`sample.path` attribute, a ``dataset_original_directory``, and an
    ``extension``, this transformer will lazily load the samples from their files.
    The ``data_loader`` is used to load the data.

    The resulting :class:`~bob.pipelines.DelayedSample` objects will call
    ``data_loader`` when their :attr:`data` is accessed.

    This transformer will not access the data files.

    Parameters
    ----------
    data_loader
        A callable to load the sample, given the full path to the file.
    dataset_original_directory
        Path of where the raw data files are stored. This will be prepended to the
        ``path`` attribute of the samples.
    extension
        File extension of the raw data files. This will be appended to the ``path``
        attribute of the samples.
    """


    def __init__(
        self,
        data_loader: Callable[[str], Any],
        dataset_original_directory: str = "",
        extension: str = "",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_loader = data_loader
        self.dataset_original_directory = dataset_original_directory
        self.extension = extension

    def transform(self, samples: list[Sample]) -> list[DelayedSample]:
        """Prepares the data into :class:`~bob.pipelines.DelayedSample` objects.

        Transforms :class:`~bob.pipelines.Sample` objects with a ``path`` attribute to
        :class:`~bob.pipelines.DelayedSample` with data ready to be loaded (lazily) by
        :attr:`data_loader`.

        When needed (access to the :class:`DelayedSample`\\ 's :attr:`data` attribute),
        :attr:`data_loader` will be called with the path (extended with
        :attr:`dataset_original_directory` and :attr:`extension`) as argument.

        Parameters
        ----------
        samples
            :class:`~bob.pipelines.Sample` objects with their ``path`` attribute
            containing a path to a file to load.
        """
        output = []
        for sample in samples:
            path = getattr(sample, "path")
            delayed_sample = DelayedSample(
                functools.partial(
                    self.data_loader,
                    os.path.join(
                        # we append ./ to path to make sure that the path is
                        # relative to the dataset_original_directory
                        self.dataset_original_directory,
                        f"./{path + self.extension}",
                    ),
                ),
                parent=sample,
            )
            output.append(delayed_sample)
        return output
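
    # e.g. a sample with path="subject1_image1", dataset_original_directory="/data/raw"
    # and extension=".png" would lazily load "/data/raw/./subject1_image1.png"
    # (illustrative values).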


    def _more_tags(self):
        return {"requires_fit": False}


class AnnotationsLoader(TransformerMixin, BaseEstimator):
    """Prepares annotations to be loaded from a path in ``delayed_attributes``.

    Metadata loader that loads samples' annotations using
    :py:func:`~bob.bio.base.utils.annotations.read_annotation_file`. This assumes that
    the annotation files follow the same folder structure and naming as the raw data
    files, although the base location and the extension may differ and are specified
    by :attr:`annotation_directory` and :attr:`annotation_extension`.

    Parameters
    ----------
    annotation_directory
        Path where the annotations are stored.

    annotation_extension : str
        Extension of the annotations.

    annotation_type : str
        Annotations type passed to
        :func:`~bob.bio.base.utils.annotations.read_annotation_file`.
    """


    def __init__(
        self,
        annotation_directory: Optional[str] = None,
        annotation_extension: str = ".json",
        annotation_type: str = "json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def transform(self, X: list[DelayedSample]) -> list[DelayedSample]:
        """Edits the samples to lazily load annotation files.

        Parameters
        ----------
        X
            The samples to augment.
        """
        if self.annotation_directory is None:
            return None

        annotated_samples = []
        for x in X:
            # we use .key here because .path might not be unique for all
            # samples. Also, the ``bob bio annotate-samples`` command dictates
            # how annotations are stored.
            annotation_file = os.path.join(
                self.annotation_directory, x.key + self.annotation_extension
            )

            annotated_samples.append(
                DelayedSample.from_sample(
                    x,
                    delayed_attributes=dict(
                        # bind annotation_file now via functools.partial so each
                        # sample loads its own annotation file (a plain lambda here
                        # would late-bind and always read the last file of the loop)
                        annotations=functools.partial(
                            read_annotation_file,
                            annotation_file,
                            self.annotation_type,
                        )
                    ),
                )
            )

        return annotated_samples

    def _more_tags(self):
        return {"requires_fit": False}