Source code for beat.cmdline.databases

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################


import glob
import logging
import os
import random

import click
import simplejson
import zmq

from beat.core import dock
from beat.core import inputs
from beat.core import utils
from beat.core.data import RemoteDataSource
from beat.core.database import Database
from beat.core.database import Storage
from beat.core.hash import hashDataset
from beat.core.hash import toPath
from beat.core.utils import NumpyJSONEncoder

from . import commands
from . import common
from .click_helper import AliasedGroup
from .click_helper import AssetCommand
from .click_helper import AssetInfo
from .decorators import raise_on_error

logger = logging.getLogger(__name__)


CMD_DB_INDEX = "index"
CMD_VIEW_OUTPUTS = "databases_provider"


# ----------------------------------------------------------


def load_database_sets(configuration, database_name):
    # Process the name of the database
    parts = database_name.split("/")

    if len(parts) == 2:
        db_name = os.path.join(*parts[:2])
        protocol_filter = None
        set_filter = None
    elif len(parts) == 3:
        db_name = os.path.join(*parts[:2])
        protocol_filter = parts[2]
        set_filter = None
    elif len(parts) == 4:
        db_name = os.path.join(*parts[:2])
        protocol_filter = parts[2]
        set_filter = parts[3]
    else:
        logger.error(
            "Database specification should have the format "
            "`<database>/<version>/[<protocol>/[<set>]]', the value "
            "you passed (%s) is not valid",
            database_name,
        )
        # return a 3-tuple like the other error paths, so callers can unpack
        return (None, None, None)

    # Load the dataformat
    dataformat_cache = {}
    database = Database(configuration.path, db_name, dataformat_cache)
    if not database.valid:
        logger.error("Failed to load the database `%s':", db_name)
        for e in database.errors:
            logger.error("  * %s", e)
        return (None, None, None)

    # Filter the protocols
    protocols = database.protocol_names
    if protocol_filter is not None:
        if protocol_filter not in protocols:
            logger.error(
                "The database `%s' does not have the protocol `%s' - "
                "choose one of `%s'",
                db_name,
                protocol_filter,
                ", ".join(protocols),
            )
            return (None, None, None)

        protocols = [protocol_filter]

    # Filter the sets
    loaded_sets = []

    for protocol_name in protocols:
        sets = database.set_names(protocol_name)

        if set_filter is not None:
            if set_filter not in sets:
                logger.error(
                    "The database/protocol `%s/%s' does not have the "
                    "set `%s' - choose one of `%s'",
                    db_name,
                    protocol_name,
                    set_filter,
                    ", ".join(sets),
                )
                return (None, None, None)

            sets = [z for z in sets if z == set_filter]

        loaded_sets.extend(
            [
                (protocol_name, set_name, database.set(protocol_name, set_name))
                for set_name in sets
            ]
        )

    return (db_name, database, loaded_sets)
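# A minimal usage sketch for the helper above (not part of the original
# module). "simple/1/double" is a hypothetical specification and `config`
# stands for an already-constructed beat.cmdline Configuration object:
#
#   db_name, database, sets = load_database_sets(config, "simple/1/double")
#   if database is not None:
#       for protocol_name, set_name, db_set in sets:
#           # each entry pairs a protocol/set name with its declaration dict
#           print(protocol_name, set_name, list(db_set["outputs"].keys()))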
# ----------------------------------------------------------
def start_db_container(
    configuration,
    cmd,
    host,
    db_name,
    protocol_name,
    set_name,
    database,
    db_set,
    excluded_outputs=None,
    uid=None,
    db_root=None,
):
    input_list = inputs.InputList()

    input_group = inputs.InputGroup(set_name, restricted_access=False)
    input_list.add(input_group)

    db_configuration = {"inputs": {}, "channel": set_name}

    if uid is None:
        uid = os.getuid()

    db_configuration["datasets_uid"] = uid

    if db_root is not None:
        db_configuration["datasets_root_path"] = db_root

    for output_name, dataformat_name in db_set["outputs"].items():
        if excluded_outputs is not None and output_name in excluded_outputs:
            continue

        dataset_hash = hashDataset(db_name, protocol_name, set_name)
        db_configuration["inputs"][output_name] = dict(
            database=db_name,
            protocol=protocol_name,
            set=set_name,
            output=output_name,
            channel=set_name,
            hash=dataset_hash,
            path=toPath(dataset_hash, ".db"),
        )

    db_tempdir = utils.temporary_directory()

    with open(os.path.join(db_tempdir, "configuration.json"), "wt") as f:
        simplejson.dump(db_configuration, f, indent=4)

    tmp_prefix = os.path.join(db_tempdir, "prefix")
    if not os.path.exists(tmp_prefix):
        os.makedirs(tmp_prefix)

    database.export(tmp_prefix)

    if db_root is None:
        json_path = os.path.join(tmp_prefix, "databases", db_name + ".json")

        with open(json_path, "r") as f:
            db_data = simplejson.load(f)

        database_path = db_data["root_folder"]
        db_data["root_folder"] = os.path.join("/databases", db_name)

        with open(json_path, "w") as f:
            simplejson.dump(db_data, f, indent=4)

    environment = database.environment

    if environment:
        environment_name = utils.build_env_name(environment)
        try:
            db_envkey = host.dbenv2docker(environment_name)
        except KeyError:
            raise RuntimeError(
                "Environment {} not found for the database '{}' "
                "- available environments are {}".format(
                    environment_name, db_name, ", ".join(host.db_environments.keys())
                )
            )
    else:
        try:
            db_envkey = host.db2docker([db_name])
        except Exception:
            raise RuntimeError(
                "No environment found for the database `%s' "
                "- available environments are %s"
                % (db_name, ", ".join(host.db_environments.keys()))
            )

    logger.info("Indexing using {}".format(db_envkey))

    # Creation of the container
    # Note: we only support one databases image loaded at the same time
    CONTAINER_PREFIX = "/beat/prefix"
    CONTAINER_CACHE = "/beat/cache"

    database_port = random.randint(51000, 60000)  # nosec just getting a free port

    if cmd == CMD_VIEW_OUTPUTS:
        db_cmd = [
            cmd,
            "0.0.0.0:{}".format(database_port),
            CONTAINER_PREFIX,
            CONTAINER_CACHE,
        ]
    else:
        db_cmd = [
            cmd,
            CONTAINER_PREFIX,
            CONTAINER_CACHE,
            db_name,
            protocol_name,
            set_name,
        ]

    databases_container = host.create_container(db_envkey, db_cmd)
    databases_container.uid = uid

    if cmd == CMD_VIEW_OUTPUTS:
        databases_container.add_port(
            database_port, database_port, host_address=host.ip
        )
        databases_container.add_volume(db_tempdir, "/beat/prefix")
        databases_container.add_volume(configuration.cache, "/beat/cache")
    else:
        databases_container.add_volume(tmp_prefix, "/beat/prefix")
        databases_container.add_volume(
            configuration.cache, "/beat/cache", read_only=False
        )

    # Specify the volumes to mount inside the container
    if "datasets_root_path" not in db_configuration:
        databases_container.add_volume(
            database_path, os.path.join("/databases", db_name)
        )
    else:
        databases_container.add_volume(
            db_configuration["datasets_root_path"],
            db_configuration["datasets_root_path"],
        )

    # Start the container
    host.start(databases_container)

    if cmd == CMD_VIEW_OUTPUTS:
        # Communicate with container
        zmq_context = zmq.Context()
        db_socket = zmq_context.socket(zmq.PAIR)

        db_address = "tcp://{}:{}".format(host.ip, database_port)
        db_socket.connect(db_address)

        for output_name, dataformat_name in db_set["outputs"].items():
            if excluded_outputs is not None and output_name in excluded_outputs:
                continue

            data_source = RemoteDataSource()
            data_source.setup(
                db_socket, output_name, dataformat_name, configuration.path
            )

            input_ = inputs.Input(
                output_name, database.dataformats[dataformat_name], data_source
            )
            input_group.add(input_)

        return (databases_container, db_socket, zmq_context, input_list)

    return databases_container
# ----------------------------------------------------------
def pull_impl(webapi, prefix, names, force, indentation, format_cache):
    """Copies databases (and required dataformats) from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the
        unique relative path of the objects to retrieve, or a list of
        usernames from which to retrieve objects. If the list is empty, then
        we pull all available objects of a given type. If no user is set, then
        pull all public objects of a given type.

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): A dictionary containing all dataformats already
        downloaded.


    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).

    """

    from .dataformats import pull_impl as dataformats_pull
    from .protocoltemplates import pull_impl as protocoltemplates_pull

    status, names = common.pull(
        webapi,
        prefix,
        "database",
        names,
        ["declaration", "code", "description"],
        force,
        indentation,
    )

    # A database object cannot be properly loaded if its protocol templates
    # are missing; therefore, we must use lower-level access.
    protocol_templates = set()
    for name in names:
        db = Storage(prefix, name)
        declaration, _, _ = db.load()
        declaration = simplejson.loads(declaration)
        version = declaration.get("schema_version", 1)
        if version > 1:
            for protocol in declaration["protocols"]:
                protocol_templates.add(protocol["template"])

    pt_status = protocoltemplates_pull(
        webapi, prefix, protocol_templates, force, indentation + 2, format_cache
    )

    # see which dataformats one needs to pull
    dataformats = []
    for name in names:
        obj = Database(prefix, name)
        dataformats.extend(obj.dataformats.keys())

    # downloads any formats we depend on
    df_status = dataformats_pull(
        webapi, prefix, dataformats, force, indentation + 2, format_cache
    )

    return status + df_status + pt_status
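# A minimal sketch of how the `pull` command below invokes this function; the
# empty names list pulls every accessible database, and `{}` starts with an
# empty dataformat cache (`config` is assumed, as above):
#
#   with common.make_webapi(config) as webapi:
#       status = pull_impl(webapi, config.path, [], force=False,
#                          indentation=0, format_cache={})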
# ----------------------------------------------------------
def index_outputs(configuration, names, uid=None, db_root=None, docker=False):

    names = common.make_up_local_list(configuration.path, "database", names)
    retcode = 0

    if docker:
        host = dock.Host(raise_on_errors=False)

    for database_name in names:
        logger.info("Indexing database %s...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue

        for protocol_name, set_name, db_set in sets:
            if not docker:
                try:
                    view = database.view(protocol_name, set_name)
                except SyntaxError as error:
                    logger.error("Failed to load the database `%s':", database_name)
                    logger.error("  * Syntax error: %s", error)
                    view = None

                if view is None:
                    retcode += 1
                    continue

                dataset_hash = hashDataset(db_name, protocol_name, set_name)

                try:
                    view.index(
                        os.path.join(
                            configuration.cache, toPath(dataset_hash, ".db")
                        )
                    )
                except RuntimeError as error:
                    logger.error("Failed to load the database `%s':", database_name)
                    logger.error("  * Runtime error %s", error)
                    retcode += 1
                    continue
            else:
                databases_container = start_db_container(
                    configuration,
                    CMD_DB_INDEX,
                    host,
                    db_name,
                    protocol_name,
                    set_name,
                    database,
                    db_set,
                    uid=uid,
                    db_root=db_root,
                )
                status = host.wait(databases_container)
                logs = host.logs(databases_container)
                host.rm(databases_container)

                if status != 0:
                    logger.error("Error occurred: %s", logs)
                    retcode += 1

    return retcode
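# Illustration of the resulting cache layout (a sketch with assumed values,
# not output captured from a real run): each set is indexed into a file whose
# location is fully determined by the dataset hash, so re-indexing the same
# set always targets the same file:
#
#   dataset_hash = hashDataset("simple/1", "double", "double")
#   index_file = os.path.join(config.cache, toPath(dataset_hash, ".db"))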
# ----------------------------------------------------------
def list_index_files(configuration, names):

    names = common.make_up_local_list(configuration.path, "database", names)

    retcode = 0

    for database_name in names:
        logger.info("Listing database %s indexes...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue

        for protocol_name, set_name, db_set in sets:
            dataset_hash = hashDataset(db_name, protocol_name, set_name)
            index_filename = toPath(dataset_hash)
            # look inside the cache, like delete_index_files() below
            basename = os.path.join(
                configuration.cache, os.path.splitext(index_filename)[0]
            )

            for g in glob.glob(basename + ".*"):
                logger.info(g)

    return retcode
# ----------------------------------------------------------
def delete_index_files(configuration, names):

    names = common.make_up_local_list(configuration.path, "database", names)

    retcode = 0

    for database_name in names:
        logger.info("Deleting database %s indexes...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue

        for protocol_name, set_name, db_set in sets:
            for output_name in db_set["outputs"].keys():
                dataset_hash = hashDataset(db_name, protocol_name, set_name)
                index_filename = toPath(dataset_hash)
                basename = os.path.join(
                    configuration.cache, os.path.splitext(index_filename)[0]
                )

                for g in glob.glob(basename + ".*"):
                    logger.info("removing `%s'...", g)
                    os.unlink(g)

                common.recursive_rmdir_if_empty(
                    os.path.dirname(basename), configuration.cache
                )

    return retcode
# ----------------------------------------------------------
def view_outputs(
    configuration,
    dataset_name,
    excluded_outputs=None,
    uid=None,
    db_root=None,
    docker=False,
):
    def data_to_json(data, indent):
        value = common.stringify(data.as_dict())

        value = (
            simplejson.dumps(value, indent=4, cls=NumpyJSONEncoder)
            .replace('"BEAT_LIST_DELIMITER[', "[")
            .replace(']BEAT_LIST_DELIMITER"', "]")
            .replace('"...",', "...")
            .replace('"BEAT_LIST_SIZE(', "(")
            .replace(')BEAT_LIST_SIZE"', ")")
        )

        return ("\n" + " " * indent).join(value.split("\n"))

    # Load the infos about the database set
    (db_name, database, sets) = load_database_sets(configuration, dataset_name)
    if (database is None) or (len(sets) != 1):
        return 1

    (protocol_name, set_name, db_set) = sets[0]

    if excluded_outputs is not None:
        # use a list: the membership test below runs once per output, which
        # would silently fail with an exhausted `map` iterator
        excluded_outputs = [x.strip() for x in excluded_outputs.split(",")]

    # Setup the view so the outputs can be used
    if not docker:
        view = database.view(protocol_name, set_name)

        if view is None:
            return 1

        dataset_hash = hashDataset(db_name, protocol_name, set_name)
        view.setup(
            os.path.join(configuration.cache, toPath(dataset_hash, ".db")),
            pack=False,
        )
        input_group = inputs.InputGroup(set_name, restricted_access=False)

        for output_name, dataformat_name in db_set["outputs"].items():
            if excluded_outputs is not None and output_name in excluded_outputs:
                continue

            input_ = inputs.Input(
                output_name,
                database.dataformats[dataformat_name],
                view.data_sources[output_name],
            )
            input_group.add(input_)
    else:
        host = dock.Host(raise_on_errors=False)

        (databases_container, db_socket, zmq_context, input_list) = start_db_container(
            configuration,
            CMD_VIEW_OUTPUTS,
            host,
            db_name,
            protocol_name,
            set_name,
            database,
            db_set,
            excluded_outputs=excluded_outputs,
            uid=uid,
            db_root=db_root,
        )

        input_group = input_list.group(set_name)

    retvalue = 0

    # Display the data
    try:
        previous_start = -1

        while input_group.hasMoreData():
            input_group.next()

            start = input_group.data_index
            end = input_group.data_index_end

            if start != previous_start:
                print(80 * "-")
                print("FROM %d TO %d" % (start, end))

                whole_inputs = [
                    input_
                    for input_ in input_group
                    if input_.data_index == start and input_.data_index_end == end
                ]

                for input_ in whole_inputs:
                    label = "  - " + str(input_.name) + ": "
                    print(label + data_to_json(input_.data, len(label)))

                previous_start = start

            selected_inputs = [
                input_
                for input_ in input_group
                if input_.data_index == input_group.first_data_index
                and (input_.data_index != start or input_.data_index_end != end)
            ]

            grouped_inputs = {}
            for input_ in selected_inputs:
                key = (input_.data_index, input_.data_index_end)
                if key not in grouped_inputs:
                    grouped_inputs[key] = []
                grouped_inputs[key].append(input_)

            sorted_keys = sorted(grouped_inputs.keys())

            for key in sorted_keys:
                print()
                print("  FROM %d TO %d" % key)

                for input_ in grouped_inputs[key]:
                    label = "    - " + str(input_.name) + ": "
                    print(label + data_to_json(input_.data, len(label)))

    except Exception as e:
        logger.error("Failed to retrieve the next data: %s", e)
        retvalue = 1

    if docker:
        host.kill(databases_container)

        status = host.wait(databases_container)
        logs = host.logs(databases_container)
        host.rm(databases_container)

        if status != 0:
            logger.error("Docker error: %s", logs)

    return retvalue
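# A hedged usage sketch matching the `view` command below; the set name is a
# hypothetical example and must resolve to exactly one set, otherwise the
# function returns 1:
#
#   retcode = view_outputs(config, "simple/1/double/double",
#                          excluded_outputs="a,b")  # comma-separated names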
# ----------------------------------------------------------
class DatabaseCommand(AssetCommand):
    asset_info = AssetInfo(
        asset_type="database",
        diff_fields=["declaration", "code", "description"],
        push_fields=["name", "declaration", "code", "description"],
    )
@click.group(cls=AliasedGroup)
@click.pass_context
def databases(ctx):
    """Database commands"""


CMD_LIST = [
    "list",
    "path",
    "edit",
    "check",
    "status",
    "create",
    "version",
    ("rm", "rm_local"),
    "diff",
    "push",
]

commands.initialise_asset_commands(databases, CMD_LIST, DatabaseCommand)


@databases.command()
@click.argument("db_names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.pass_context
@raise_on_error
def pull(ctx, db_names, force):
    """Downloads the specified databases from the server.

    $ beat databases pull [<name>]...

    <name>: Database name formatted as "<database>/<version>"
    """
    configuration = ctx.meta["config"]
    with common.make_webapi(configuration) as webapi:
        return pull_impl(webapi, configuration.path, db_names, force, 0, {})


@databases.command()
@click.argument("db_names", nargs=-1)
@click.option(
    "--list", help="List index files matching output if they exist", is_flag=True
)
@click.option(
    "--delete",
    help="Delete index files matching output if they "
    "exist (also, recursively deletes empty directories)",
    is_flag=True,
)
@click.option("--checksum", help="Checksums index files", is_flag=True, default=True)
@click.option("--uid", type=click.INT, default=None)
@click.option("--db-root", help="Database root")
@click.option("--docker", is_flag=True)
@click.pass_context
@raise_on_error
def index(ctx, db_names, list, delete, checksum, uid, db_root, docker):
    """Indexes all outputs (of all sets) of a database.

    To index the contents of a database

      $ beat databases index simple/1

    To index the contents of a protocol on a database

      $ beat databases index simple/1/double

    To index the contents of a set in a protocol on a database

      $ beat databases index simple/1/double/double
    """
    configuration = ctx.meta["config"]
    code = 1
    if list:
        code = list_index_files(configuration, db_names)
    elif delete:
        code = delete_index_files(configuration, db_names)
    elif checksum:
        code = index_outputs(
            configuration, db_names, uid=uid, db_root=db_root, docker=docker
        )
    return code


@databases.command()
@click.argument("set_name", nargs=1)
@click.option("--exclude", help="When viewing, excludes this output", default=None)
@click.option("--uid", type=click.INT, default=None)
@click.option("--db-root", help="Database root")
@click.option("--docker", is_flag=True)
@click.pass_context
@raise_on_error
def view(ctx, set_name, exclude, uid, db_root, docker):
    """View the data of the specified dataset.

    To view the contents of a specific set

      $ beat databases view simple/1/protocol/set
    """
    configuration = ctx.meta["config"]

    if exclude is not None:
        return view_outputs(
            configuration, set_name, exclude, uid=uid, db_root=db_root, docker=docker
        )
    return view_outputs(
        configuration, set_name, uid=uid, db_root=db_root, docker=docker
    )
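# Shell usage summary for the commands defined above (examples assume a
# database named "simple/1" exists locally; flags match the declared options):
#
#   $ beat databases pull simple/1 --force      # fetch from the server
#   $ beat databases index simple/1             # (re-)index all sets
#   $ beat databases index --list simple/1      # show existing index files
#   $ beat databases index --delete simple/1    # remove index files
#   $ beat databases view simple/1/double/double --exclude a,b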