Source code for beat.backend.python.execution.loop

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################


"""
========
executor
========

A class that can set up and execute loop algorithm blocks on the backend
"""

import json
import logging
import os

import zmq

from ..algorithm import Algorithm
from ..dataformat import DataFormat
from ..exceptions import RemoteException
from ..helpers import AccessMode
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
from .helpers import make_data_format

logger = logging.getLogger(__name__)


class LoopChannel(object):
    """Direct communication link between an algorithm using a loop and the
    loop itself
    """

    def __init__(self, socket):
        """Constructor

        Parameters:

          socket (:py:class:`zmq.Socket`): Socket used for the zmq
            communication with the loop
        """

        self.socket = socket

    def setup(self, algorithm, prefix):
        """Prepare the data formats used to talk to the loop

        Parameters:

          algorithm (:py:class:`.algorithm.Algorithm`): Algorithm for which the
            communication channel is set up. Its ``loop_map`` provides the
            names of the ``request`` and ``answer`` data formats.

          prefix (str): Folder where the prefix is located.
        """

        self.request_data_format = DataFormat(prefix, algorithm.loop_map["request"])
        self.answer_data_format = DataFormat(prefix, algorithm.loop_map["answer"])

    def validate(self, hypothesis):
        """Ask the loop to validate the given hypothesis

        Parameters:

          hypothesis (dict): Computed hypothesis that must be validated by the
            loop algorithm.

        Returns:

          tuple: ``(is_valid, answer)`` where ``is_valid`` is ``True`` only
          when the loop replied ``b"True"`` and ``answer`` is the reply payload
          unpacked with the answer data format.

        Raises:

          RemoteException: if the loop replied with ``b"err"``
        """

        request = make_data_format(hypothesis, self.request_data_format)

        channel = self.socket
        # Two-part message: the command first, then the packed request
        channel.send_string("val", zmq.SNDMORE)
        channel.send(request.pack())

        status = channel.recv()
        if status == b"err":
            # On error, two more frames follow: the error kind and its message
            kind = channel.recv()
            message = channel.recv()
            raise RemoteException(kind, message)

        payload = channel.recv()
        answer = self.answer_data_format.type().unpack(payload)

        return (status == b"True", answer)
class LoopExecutor(object):
    """Executor running the code of a loop given its execution block information

    Parameters:

      message_handler: Object handling the messages exchanged with this loop.
        The executor calls its ``setup(algorithm, prefix)``,
        ``set_executor(executor)``, ``start()`` and ``join()`` methods and
        reads its ``address`` attribute.

      directory (str): The path to a directory containing all the information
        required to run the user experiment. A ``configuration.json`` file and
        a ``prefix`` entry are looked up in it.

      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
        dataformat names to loaded dataformats. This parameter is optional and,
        if passed, may greatly speed-up database loading times as dataformats
        that are already loaded may be re-used. If you use this parameter, you
        must guarantee that the cache is refreshed as appropriate in case the
        underlying dataformats change.

      database_cache (:py:class:`dict`, Optional): A dictionary mapping
        database names to loaded databases. This parameter is optional and, if
        passed, may greatly speed-up database loading times as databases that
        are already loaded may be re-used. If you use this parameter, you must
        guarantee that the cache is refreshed as appropriate in case the
        underlying databases change. It is only used when no ``db_socket`` is
        given.

      library_cache (:py:class:`dict`, Optional): A dictionary mapping library
        names to loaded libraries. This parameter is optional and, if passed,
        may greatly speed-up library loading times as libraries that are
        already loaded may be re-used. If you use this parameter, you must
        guarantee that the cache is refreshed as appropriate in case the
        underlying libraries change.

      cache_root (:py:class:`str`, Optional): Cache location forwarded to the
        creation of the inputs and outputs. Defaults to ``/cache``.

      db_socket (Optional): When given, databases are accessed remotely
        through this socket; otherwise they are accessed locally.


    Attributes:

      errors (list): Errors collected while loading the algorithm. The
        executor is considered valid when this list is empty.
    """

    def __init__(
        self,
        message_handler,
        directory,
        dataformat_cache=None,
        database_cache=None,
        library_cache=None,
        cache_root="/cache",
        db_socket=None,
    ):

        self._runner = None
        self.algorithm = None
        self.output_list = None
        self.db_socket = db_socket
        # Backs the ``valid`` property; must exist before anything can fail
        self.errors = []

        self.configuration = os.path.join(directory, "configuration.json")
        with open(self.configuration, "r") as f:
            self.data = json.load(f)["loop"]

        self.prefix = os.path.join(directory, "prefix")

        # Temporary caches, if the user has not set them, for performance
        database_cache = database_cache if database_cache is not None else {}
        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
        library_cache = library_cache if library_cache is not None else {}

        # Load the algorithm
        self.algorithm = Algorithm(
            self.prefix, self.data["algorithm"], dataformat_cache, library_cache
        )

        if not self.algorithm.valid:
            # Keep the errors so that ``valid`` reflects the failure, and keep
            # going as before: a failed load is reported, not raised.
            self.errors.extend(self.algorithm.errors)
            logger.warning(
                "Failed to load algorithm:\n%s", "\n".join(self.algorithm.errors)
            )

        if db_socket:
            db_access_mode = AccessMode.REMOTE
            databases = None
        else:
            db_access_mode = AccessMode.LOCAL
            databases = database_cache

        self.input_list, self.data_loaders = create_inputs_from_configuration(
            self.data,
            self.algorithm,
            self.prefix,
            cache_root,
            cache_access=AccessMode.LOCAL,
            db_access=db_access_mode,
            socket=self.db_socket,
            databases=databases,
        )

        self.output_list, _ = create_outputs_from_configuration(
            self.data,
            self.algorithm,
            self.prefix,
            cache_root,
            input_list=self.input_list,
            data_loaders=self.data_loaders,
        )

        self.message_handler = message_handler
        self.message_handler.setup(self.algorithm, self.prefix)
        self.message_handler.set_executor(self)

    @property
    def runner(self):
        """Returns the algorithm runner

        This property allows for lazy loading of the runner
        """

        if self._runner is None:
            self._runner = self.algorithm.runner()
        return self._runner

    def setup(self):
        """Sets up the algorithm to start processing

        Returns:

          The value returned by the runner's ``setup`` when called with the
          ``parameters`` entry of the configuration.
        """

        retval = self.runner.setup(self.data["parameters"])
        logger.debug("User loop is setup: %s", retval)
        return retval

    def prepare(self):
        """Prepare the algorithm

        Returns:

          The value returned by the runner's ``prepare`` when called with the
          data loaders.
        """

        retval = self.runner.prepare(self.data_loaders)
        logger.debug("User loop is prepared: %s", retval)
        return retval

    def process(self):
        """Executes the user algorithm code using the current interpreter.

        This starts the message handler.
        """

        self.message_handler.start()

    def validate(self, hypothesis):
        """Executes the loop validation code

        Parameters:

          hypothesis: Hypothesis forwarded to the runner's ``validate``.

        Returns:

          tuple: ``(is_valid, answer)`` as returned by the runner.
        """

        is_valid, answer = self.runner.validate(hypothesis)
        logger.debug("User loop has validated: %s\n%s", is_valid, answer)
        return is_valid, answer

    def write(self, processor_output_name, end_data_index=None):
        """Write the loop output

        Parameters:

          processor_output_name: Forwarded as is to the runner's ``write``.

          end_data_index (Optional): Forwarded as is to the runner's ``write``.

        Returns:

          The value returned by the runner's ``write``.
        """

        retval = self.runner.write(
            self.output_list, processor_output_name, end_data_index
        )
        logger.debug("User loop wrote output: %s", retval)
        return retval

    def read(self):
        """Move input to next element and make it read

        Returns:

          The value returned by the runner's ``read`` for the input list.
        """

        main_group = self.input_list.main_group
        # Access restriction is lifted only for the time needed to advance
        main_group.restricted_access = False
        main_group.next()
        main_group.restricted_access = True
        return self.runner.read(self.input_list)

    @property
    def address(self):
        """Address of the message handler"""

        return self.message_handler.address

    @property
    def valid(self):
        """A boolean that indicates if this executor is valid or not"""

        return not bool(self.errors)

    def wait(self):
        """Wait for the message handler to finish

        The handler is released afterwards. Calling this method again once the
        handler has been released does nothing.
        """

        if self.message_handler is None:
            # Already waited for: nothing left to join
            return

        try:
            self.message_handler.join()
        except RuntimeError:
            # tried to join the handler before it has started.
            pass

        self.message_handler = None

    def close(self):
        """Close all outputs"""

        for output in self.output_list:
            output.close()