#!/usr/bin/env python
# encoding: utf-8
import os
from bob.db.base import utils
from .models import *
from .driver import Interface
INFO = Interface()
SQLITE_FILE = INFO.files()[0]
import bob.db.base
class Database(bob.db.base.SQLiteDatabase):
    """Class representing the database.

    See parent class :py:class:`bob.db.base.SQLiteDatabase` for more details ...

    Attributes
    ----------
    original_directory: str
        Path where the database is stored
    original_extension: str
        Extension of files in the database
    annotation_directory: str
        Path where the annotations are stored
    annotation_extension: str
        Extension of annotation files
    protocol: str
        Protocol name stored on the instance (``'mc-rgb'`` by default)
    """

    def __init__(self,
                 original_directory=None,
                 original_extension=None,
                 annotation_directory=None,
                 annotation_extension=None,
                 protocol='mc-rgb'):
        """Init function

        Parameters
        ----------
        original_directory: str
            Path where the database is stored
        original_extension: str
            Extension of files in the database
        annotation_directory: str
            Path where the annotations are stored
        annotation_extension: str
            Extension of annotation files
        protocol: str
            Protocol name, kept as ``self.protocol``
        """
        super(Database, self).__init__(
            SQLITE_FILE, File, original_directory, original_extension)
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.protocol = protocol

    @property
    def modalities(self):
        """The list of modality names accepted by :py:meth:`objects`."""
        return ['rgb', 'nir', 'depth']

    def groups(self, protocol=None):
        """Returns the names of all registered groups

        Parameters
        ----------
        protocol: str
            ignored, since the groups are the same across protocols.
        """
        return ProtocolPurpose.group_choices

    def clients(self, protocol=None, groups=None):
        """Returns a list of clients for the specific query by the user.

        Parameters
        ----------
        protocol: str
            Ignored since the clients are identical for all protocols.
        groups: str or tuple of str
            The groups to which the clients belong ('world', 'dev', 'eval').

        Returns
        -------
        lst:
            list containing clients which have the given properties.
        """
        groups = self.check_parameters_for_validity(
            groups, "group", self.groups())

        # One query per group, always in the fixed world/dev/eval order, so
        # the result order does not depend on the order of ``groups``.
        retval = []
        for group in ('world', 'dev', 'eval'):
            if group in groups:
                retval += list(
                    self.query(Client).filter(Client.group == group))
        return retval

    def models(self, protocol=None, groups=None):
        """Returns a list of models for the specific query by the user.

        Models and clients are the same objects here: this simply forwards
        to :py:meth:`clients`.

        Parameters
        ----------
        protocol
            Ignored since the models are identical for all protocols.
        groups
            The groups to which the subjects attached to the models belong

        Returns
        -------
        lst:
            A list containing all the models which have the given properties.
        """
        return self.clients(protocol, groups)

    def model_ids(self, protocol=None, groups=None):
        """Returns a list of models ids for the specific query by the user.

        Parameters
        ----------
        protocol
            Ignored since the models are identical for all protocols.
        groups
            The groups to which the subjects attached to the models belong

        Returns
        -------
        lst:
            A list containing all the models ids which have the given
            properties.
        """
        return [model.id for model in self.models(protocol, groups)]

    def client(self, id):
        """Returns the client object of the specified id.

        Parameters
        ----------
        id: int
            The client id.

        Raises
        ------
        Error:
            if the client does not exist (raised by the query's ``one()``;
            presumably sqlalchemy's ``NoResultFound`` -- TODO confirm).
        """
        return self.query(Client).filter(Client.id == id).one()

    def protocol_names(self):
        """Returns all registered protocol names, as a list of ``str``."""
        return [str(protocol.name) for protocol in self.protocols()]

    def protocols(self):
        """Returns all registered protocols"""
        return list(self.query(Protocol))

    def purposes(self):
        """Returns purposes"""
        return ProtocolPurpose.purpose_choices

    def objects(self, protocol=None, purposes=None, model_ids=None,
                groups=None, modality=None):
        """Returns a list of Files for the specific query by the user.

        Parameters
        ----------
        protocol: str
            One of the FARGO protocols.
        purposes: str or tuple of str
            The purposes required to be retrieved ('enroll', 'probe', 'train')
            or a tuple with several of them. If 'None' is given (this is the
            default), it is considered the same as a tuple with all possible
            values. This field is ignored for the data from the "world" group.
        model_ids: int or tuple of int
            Only retrieves the files for the provided list of model ids.
            If 'None' is given, no filter over the model_ids is performed.
        groups: str or tuple of str
            One of the groups ('dev', 'eval', 'world') or a tuple with several
            of them. If 'None' is given (this is the default), it is
            considered the same as a tuple with all possible values.
        modality: str or tuple
            One of the three modalities 'rgb', 'nir' and 'depth'

        Returns
        -------
        lst:
            A list of files which have the given properties, without
            duplicates and sorted.
        """
        from sqlalchemy import and_
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Keep a fallback for old interpreters.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        protocol = self.check_parameters_for_validity(
            protocol, "protocol", self.protocol_names())
        purposes = self.check_parameters_for_validity(
            purposes, "purpose", self.purposes())
        groups = self.check_parameters_for_validity(
            groups, "group", self.groups())
        modality = self.check_parameters_for_validity(
            modality, "modality", self.modalities)

        # Normalise model_ids to an iterable; empty means "no filtering".
        if model_ids is None:
            model_ids = ()
        elif not isinstance(model_ids, Iterable):
            model_ids = (model_ids,)

        # Now query the database
        retval = []

        if 'world' in groups:
            q = self.query(File).join(Client).join(
                (ProtocolPurpose, File.protocolPurposes)).join(Protocol)
            q = q.filter(Client.group == 'world').filter(
                and_(Protocol.name.in_(protocol),
                     ProtocolPurpose.group == 'world'))
            if model_ids:
                q = q.filter(Client.id.in_(model_ids))
            q = q.order_by(File.client_id)
            q = q.filter(File.modality.in_(modality))
            retval += list(q)

        # NOTE(review): unlike the 'world' query above, the dev/eval queries
        # do not filter on ``modality`` -- confirm whether that is intended.
        if 'dev' in groups or 'eval' in groups:
            for purpose in ('enroll', 'probe'):
                if purpose not in purposes:
                    continue
                q = self.query(File).join(Client).join(
                    (ProtocolPurpose, File.protocolPurposes)).join(Protocol)
                q = q.filter(
                    and_(Protocol.name.in_(protocol),
                         ProtocolPurpose.group.in_(groups),
                         ProtocolPurpose.purpose == purpose))
                # dense probing -> only enroll files are filtered by model_ids
                if purpose == 'enroll' and model_ids:
                    q = q.filter(Client.id.in_(model_ids))
                q = q.order_by(File.client_id)
                retval += list(q)

        # remove duplicates and sort the list
        return sorted(set(retval))