Source code for beat.core.toolchain
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###################################################################################
# #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, this #
# list of conditions and the following disclaimer. #
# #
# 2. Redistributions in binary form must reproduce the above copyright notice, #
# this list of conditions and the following disclaimer in the documentation #
# and/or other materials provided with the distribution. #
# #
# 3. Neither the name of the copyright holder nor the names of its contributors #
# may be used to endorse or promote products derived from this software without #
# specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# #
###################################################################################
"""
=========
toolchain
=========
Validation for toolchains
"""
import collections
import simplejson as json
from . import prototypes
from . import schema
from . import utils
class Storage(utils.Storage):
    """Resolves on-disk paths for toolchain declarations

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      name (str): The name of the toolchain object in the format
        ``<user>/<name>/<version>``.

    Raises:

      RuntimeError: if ``name`` does not contain exactly three
        ``/``-separated components.

    """

    asset_type = "toolchain"
    asset_folder = "toolchains"

    def __init__(self, prefix, name):
        components = name.split("/")
        # exactly two slashes <=> exactly three components
        if len(components) != 3:
            raise RuntimeError("invalid toolchain name: `%s'" % name)

        self.username, self.name, self.version = components
        self.fullname = name
        self.prefix = prefix

        extension = ".json"
        declaration = utils.hashed_or_simple(
            self.prefix, self.asset_folder, name, suffix=extension
        )
        # hand the base class the path without the ".json" extension
        base_path = declaration[: -len(extension)]
        super(Storage, self).__init__(base_path)
class Toolchain(object):
    """Toolchains define the dataflow in an experiment.

    Parameters:

      prefix (str): Establishes the prefix of your installation.

      data (:py:class:`object`, Optional): The piece of data representing the
        toolchain. It must validate against the schema defined for toolchains.
        If it is not a dictionary (e.g. a string), it is taken as the toolchain
        name in the format ``<user>/<name>/<version>`` and resolved inside the
        designated prefix area. If ``None`` is passed, loads our default
        prototype for toolchains.

    Attributes:

      storage (object): A simple object that provides information about file
        paths for this toolchain

      errors (list): A list containing errors found while loading this
        toolchain.

      data (dict): The original data for this toolchain, as loaded by our JSON
        decoder.

    """

    def __init__(self, prefix, data):
        self._name = None
        self.storage = None
        self.prefix = prefix
        self.errors = []
        self.data = data
        self._load(data)

    def _load(self, data):
        """Loads ``data`` (prototype, name or dict), validates it and runs
        the cross-reference consistency checks, filling ``self.errors``."""

        self._name = None
        self.storage = None

        if data is None:  # loads prototype and validates it
            self.data, self.errors = prototypes.load("toolchain")
            assert not self.errors, "\n * %s" % "\n *".join(self.errors)  # nosec

        else:
            if not isinstance(data, dict):  # user has a file pointer
                self._name = data
                self.storage = Storage(self.prefix, self._name)
                if not self.storage.exists():
                    self.errors.append(
                        "Toolchain declaration file not found: %s" % data
                    )
                    return

                data = self.storage.json.path

            # this runs basic validation, including JSON loading if required
            self.data, self.errors = schema.validate("toolchain", data)
            if self.errors:
                return  # don't proceed with the rest of validation

        # these will be filled by the following methods
        channels = []
        inputs = []
        outputs = []
        names = {}
        connections = []
        loop_connections = []

        self._check_datasets(channels, outputs, names)
        self._check_blocks(channels, inputs, outputs, names)
        self._check_loops(channels, inputs, outputs, names)
        self._check_analyzers(channels, inputs, names)
        self._check_connections(channels, inputs, outputs, connections)
        self._check_representation(channels, names, connections, loop_connections)

    def _check_datasets(self, channels, outputs, names):
        """Checks all datasets

        Registers each dataset name in ``names``, each dataset as a channel in
        ``channels`` and its ``<dataset>.<output>`` endpoints in ``outputs``
        (all mutated in place). Duplicated names are reported in
        ``self.errors``.
        """

        for i, dataset in enumerate(self.data["datasets"]):
            if dataset["name"] in names:
                self.errors.append(
                    "/datasets/[#%d]/name: duplicated name, first "
                    "occurance of '%s' happened at '%s'"
                    % (i, dataset["name"], names[dataset["name"]])
                )
            else:
                names[dataset["name"]] = "/datasets/%s[#%d]" % (dataset["name"], i)
                channels.append(dataset["name"])
                outputs += ["%s.%s" % (dataset["name"], k) for k in dataset["outputs"]]

        return channels, outputs, names

    def _check_blocks(self, channels, inputs, outputs, names):
        """Checks all blocks

        Registers block names and their ``<block>.<port>`` inputs/outputs
        (mutating the lists in place), and reports duplicated names and
        synchronization channels that are not declared datasets.
        """

        for i, block in enumerate(self.data["blocks"]):
            if block["name"] in names:
                self.errors.append(
                    "/blocks/[#%d]/name: duplicated name, first "
                    "occurance of '%s' happened at '%s'"
                    % (i, block["name"], names[block["name"]])
                )
            else:
                names[block["name"]] = "/blocks/%s[#%d]" % (block["name"], i)

            inputs += ["%s.%s" % (block["name"], k) for k in block["inputs"]]
            outputs += ["%s.%s" % (block["name"], k) for k in block["outputs"]]

            if block["synchronized_channel"] not in channels:
                self.errors.append(
                    "/blocks/%s[#%d]/synchronized_channel: invalid "
                    "synchronization channel '%s'"
                    % (block["name"], i, block["synchronized_channel"])
                )

        return channels, inputs, outputs, names

    def _check_loops(self, channels, inputs, outputs, names):
        """Check all loops

        Loops are optional. Both the ``processor_`` and ``evaluator_`` ports
        of each loop are registered as ``<loop>.<port>`` endpoints.
        """

        if "loops" in self.data:
            for i, loop in enumerate(self.data["loops"]):
                loop_name = loop["name"]
                if loop_name in names:
                    self.errors.append(
                        "/loops/[#%d]/name: duplicated name, first "
                        "occurance of '%s' happened at '%s'"
                        % (i, loop_name, names[loop_name])
                    )
                else:
                    names[loop_name] = "/loops/%s[#%d]" % (loop_name, i)

                for prefix in ["processor_", "evaluator_"]:
                    inputs += [
                        "%s.%s" % (loop_name, k) for k in loop[prefix + "inputs"]
                    ]
                    outputs += [
                        "%s.%s" % (loop_name, k) for k in loop[prefix + "outputs"]
                    ]

                if loop["synchronized_channel"] not in channels:
                    self.errors.append(
                        "/loops/%s[#%d]/synchronized_channel: "
                        "invalid synchronization channel '%s'"
                        % (loop_name, i, loop["synchronized_channel"])
                    )

        return channels, inputs, outputs, names

    def _check_analyzers(self, channels, inputs, names):
        """Checks all analyzers

        Analyzers only contribute inputs (they have no ``outputs``).
        """

        for i, analyzer in enumerate(self.data["analyzers"]):
            if analyzer["name"] in names:
                self.errors.append(
                    "/analyzers/[#%d]/name: duplicated name, first "
                    "occurance of '%s' happened at '%s'"
                    % (i, analyzer["name"], names[analyzer["name"]])
                )
            else:
                names[analyzer["name"]] = "/analyzers/%s[#%d]" % (analyzer["name"], i)

            inputs += ["%s.%s" % (analyzer["name"], k) for k in analyzer["inputs"]]

            if analyzer["synchronized_channel"] not in channels:
                self.errors.append(
                    "/analyzers/%s[#%d]/synchronized_channel: "
                    "invalid synchronization channel '%s'"
                    % (analyzer["name"], i, analyzer["synchronized_channel"])
                )

    def _check_connections(self, channels, inputs, outputs, connections):
        """Checks connection consistency

        Reports connections that share a destination input, use unknown
        endpoints or channels, and inputs left without any connection.
        Appends ``"<from>/<to>"`` for every connection to ``connections``.
        """

        input_endpoints = dict()
        unconnected_inputs = set(inputs)

        for i, connection in enumerate(self.data["connections"]):
            # checks no 2 connections arrive at the same input
            if connection["to"] in input_endpoints:
                connected = input_endpoints[connection["to"]]
                self.errors.append(
                    "/connection/%s->%s[#%d]/: ending on the same "
                    "input as /connection/%s->%s[#%d] is unsupported"
                    % (
                        connection["from"],
                        connection["to"],
                        i,
                        connected["from"],
                        connection["to"],
                        connected["position"],
                    )
                )
            else:
                input_endpoints[connection["to"]] = {
                    "from": connection["from"],
                    "position": i,
                }

            if connection["from"] not in outputs:
                self.errors.append(
                    "/connections/%s->%s[#%d]/: invalid output endpoint '%s'"
                    % (connection["from"], connection["to"], i, connection["from"])
                )

            if connection["to"] not in inputs:
                self.errors.append(
                    "/connections/%s->%s[#%d]/: invalid input "
                    "endpoint '%s'"
                    % (connection["from"], connection["to"], i, connection["to"])
                )
            else:
                # we now know this input is connected at least once
                if connection["to"] in unconnected_inputs:
                    unconnected_inputs.remove(connection["to"])

            if connection["channel"] not in channels:
                # report the connection index like the other messages do
                self.errors.append(
                    "/connections/%s->%s[#%d]/channel: invalid "
                    "synchronization channel '%s'"
                    % (connection["from"], connection["to"], i, connection["channel"])
                )

            connections.append("%s/%s" % (connection["from"], connection["to"]))

        if len(unconnected_inputs) != 0:
            self.errors.append(
                "input(s) `%s' remain unconnected" % (", ".join(unconnected_inputs),)
            )

    def _check_representation(self, channels, names, connections, loop_connections):
        """Checks the representation for this toolchain

        Every connection, block and channel color listed under
        ``/representation`` must refer to something declared elsewhere.
        ``loop_connections`` is currently unused.
        """

        # all connections must exist
        for connection in self.data["representation"]["connections"]:
            if connection not in connections:
                self.errors.append(
                    "/representation/connections/%s: not listed "
                    "on /connections" % connection
                )

        # all blocks must exist
        for block in self.data["representation"]["blocks"]:
            if block not in names:
                self.errors.append(
                    "/representation/blocks/%s: not listed on "
                    "/datasets, /blocks or /analyzers" % block
                )

        # all channel colors must be a valid dataset name
        for channel in self.data["representation"]["channel_colors"]:
            if channel not in channels:
                self.errors.append(
                    "/representation/channel_colors/%s: not a "
                    "dataset listed on /datasets" % channel
                )

    @property
    def schema_version(self):
        """Returns the schema version"""
        return self.data.get("schema_version", 1)

    @property
    def name(self):
        """The name of this object"""
        return self._name or "__unnamed_toolchain__"

    @name.setter
    def name(self, value):
        self._name = value
        self.storage = Storage(self.prefix, value)

    @property
    def datasets(self):
        """All declared datasets, keyed by name"""
        data = self.data["datasets"]
        return dict(zip([k["name"] for k in data], data))

    @property
    def blocks(self):
        """All declared blocks, keyed by name"""
        data = self.data["blocks"]
        return dict(zip([k["name"] for k in data], data))

    @property
    def loops(self):
        """All declared loops, keyed by name (empty if none are declared)"""
        data = self.data.get("loops", [])
        return dict(zip([k["name"] for k in data], data))

    @property
    def analyzers(self):
        """All declared analyzers, keyed by name"""
        data = self.data["analyzers"]
        return dict(zip([k["name"] for k in data], data))

    def algorithm_item(self, name):
        """Returns a block, loop or analyzer matching the name given

        Blocks are searched first, then loops, then analyzers. Returns
        ``None`` if nothing matches.
        """

        for algo_items in [self.blocks, self.loops, self.analyzers]:
            if name in algo_items:
                return algo_items[name]
        return None

    @property
    def connections(self):
        """All declared connections"""
        return self.data["connections"]

    def dependencies(self, name):
        """Returns the block dependencies for a given block/analyzer in a set

        The calculation uses all declared connections for that block/analyzer.
        Dataset connections are ignored.
        """

        dependencies = set()
        datasets = self.datasets  # property - does some work nevertheless
        for conn in self.data["connections"]:
            from_ = conn["from"].split(".", 1)[0]
            to_ = conn["to"].split(".", 1)[0]
            if to_ == name and from_ not in datasets:
                dependencies.add(from_)
        return dependencies

    def execution_order(self):
        """Returns the execution order in an ordered dictionary with block
        deps.

        Items (blocks, loops and analyzers) are released in successive waves:
        an item is inserted once all of its dependencies are already queued.

        Returns:

          collections.OrderedDict: maps each item name to its set of
            dependencies, in execution order.

        Raises:

          RuntimeError: if the order cannot be resolved, e.g. because of a
            dependency cycle or a connection coming from an undeclared item.
            (Previously this situation made the method loop forever.)

        """

        items = [
            k["name"]
            for k in self.data["blocks"]
            + self.data.get("loops", [])
            + self.data["analyzers"]
        ]
        # keyed by name, so duplicated names collapse into a single entry
        deps = dict(zip(items, [self.dependencies(k) for k in items]))

        queue = collections.OrderedDict()
        while len(queue) != len(deps):  # while there are blocks/analyzers to treat
            insert = collections.OrderedDict()
            for k in items:
                # if block has no executed deps
                if k not in queue and deps[k].issubset(queue.keys()):
                    insert[k] = deps[k]  # insert into queue
            if not insert:
                # no progress is possible: bail out instead of spinning forever
                pending = sorted(set(deps) - set(queue))
                raise RuntimeError(
                    "cannot determine execution order: unresolvable "
                    "dependencies for %s" % ", ".join(pending)
                )
            queue.update(insert)

        return queue

    def dot_diagram(
        self,
        title=None,
        label_callback=None,
        edge_callback=None,
        result_callback=None,
        is_layout=False,
    ):
        """Returns a dot diagram representation of the toolchain

        Parameters:

          title (str): A title for the generated drawing. If ``None`` is given,
            then prints out the toolchain name.

          label_callback (:std:term:`function`): A python function that is
            called back each time a label needs to be inserted into a block.
            The prototype of this function is ``label_callback(type, name)``.
            ``type`` may be one of ``dataset``, ``block`` or ``analyzer``. This
            callback is used by the experiment class to complement diagram
            information before plotting. If ``None``, the plain name is used
            when ``is_layout`` is set, otherwise ``<b><u>name</u></b>``.

          edge_callback (:std:term:`function`): A python function that is
            called back each time an edge needs to be inserted into the graph.
            The prototype of this function is ``edge_callback(start)``.
            ``start`` is the name of the starting point for the connection, it
            should determine the dataformat for the connection. Its return
            value is used as the edge label; it is not used when ``is_layout``
            is set.

          result_callback (:std:term:`function`): A function to draw ports on
            analyzer blocks. It is called as ``result_callback(name)`` for each
            analyzer and must return a list of port names, which is passed to
            the label builder in place of the analyzer's (absent) outputs.
            Defaults to no ports.

          is_layout (bool): If set, use the layout label builder, straight
            ``line`` splines, a wider ``nodesep`` and raw ``:e -> :w`` edge
            statements without labels.

        Returns

          graphviz.Digraph: With the graph ready for show-time.

        Raises

          KeyError: if a dataset has no entry in
            ``/representation/channel_colors``.

        """

        # the representation for channel colors must be complete
        channel_colors = self.data["representation"]["channel_colors"]
        all_colors = set(channel_colors.keys())
        channels = set([k["name"] for k in self.data["datasets"]])
        missing = channels - all_colors
        if missing:
            raise KeyError(
                "/representation/channel_colors/%s: is missing "
                "from object descriptor - fix it before drawing"
                % ",".join(sorted(missing))
            )

        # only fall back to a default when the caller did not provide one
        if is_layout:
            label_callback = label_callback or (lambda x, y: "%s" % y)
        else:
            label_callback = label_callback or (lambda x, y: "<b><u>%s</u></b>" % y)
        edge_callback = edge_callback or (lambda x: "")
        result_callback = result_callback or (lambda x: [])

        title = title or "Toolchain: %s" % self.name

        from graphviz import Digraph
        from .drawing import make_label as make_drawing_label
        from .drawing import make_layout_label

        fontname = "Helvetica"
        fontsize = "12"
        make_label = make_layout_label if is_layout else make_drawing_label

        root = Digraph(self.name)
        splineType = "line" if is_layout else "polyline"
        # default is 0.25, but it seems 0.5 is needed to keep everything separated
        # when the layout is parsed by beat.editor
        nodesep = "0.5" if is_layout else "0.25"
        root.attr(
            "graph",
            rankdir="LR",
            compound="true",
            splines=splineType,
            labelloc="t",
            label=title,
            fontname=fontname,
            fontsize=str(3 * int(fontsize)),
            nodesep=nodesep,
        )

        datasets = Digraph("dataset_cluster")
        datasets.attr("graph", rank="same", label="datasets")
        for d, info in self.datasets.items():
            datasets.node(
                d,
                label=make_label(
                    [], label_callback("dataset", d), info["outputs"], channel_colors[d]
                ),
                shape="none",
                fontsize=fontsize,
                fontname=fontname,
            )
        root.subgraph(datasets)

        def _draw_block(graph, n, info):
            """Adds node ``n`` to ``graph`` and its incoming edges to root"""
            color = channel_colors[info["synchronized_channel"]]
            if "outputs" in info:
                label = make_label(
                    info["inputs"], label_callback("block", n), info["outputs"], color
                )
            else:
                label = make_label(
                    info["inputs"],
                    label_callback("analyzer", n),
                    result_callback(n),
                    color,
                )
            # put the node in the graph we were given (root for blocks, the
            # analyzer subgraph for analyzers)
            graph.node(
                n, label=label, shape="none", fontsize=fontsize, fontname=fontname
            )

            for c in [k for k in self.connections if k["to"].startswith(n + ".")]:
                edge_color = channel_colors[c["channel"]]
                if is_layout:
                    root.body.append(
                        '\t%s:e -> %s:w [color="%s"]'
                        % (
                            c["from"].replace(".", ":output_"),
                            c["to"].replace(".", ":input_"),
                            edge_color,
                        )
                    )
                else:
                    root.edge(
                        c["from"].replace(".", ":output_"),
                        c["to"].replace(".", ":input_"),
                        color=edge_color,
                        label=edge_callback(c["from"]),
                        fontcolor=color,
                        fontsize=fontsize,
                        fontname=fontname,
                    )

        # NOTE(review): loops (self.loops) are not drawn by this method
        for name, info in self.blocks.items():
            _draw_block(root, name, info)

        analyzers = Digraph("analyzer_cluster")
        analyzers.attr("graph", rank="same", label="analyzers")
        for name, info in self.analyzers.items():
            _draw_block(analyzers, name, info)
        # must come after drawing: the subgraph contents are copied here
        root.subgraph(analyzers)

        return root

    @property
    def valid(self):
        """A boolean that indicates if this toolchain is valid or not"""
        return not bool(self.errors)

    @property
    def description(self):
        """The short description for this object"""
        return self.data.get("description", None)

    @description.setter
    def description(self, value):
        """Sets the short description for this object"""
        self.data["description"] = value

    @property
    def documentation(self):
        """The full-length description for this object

        Raises:

          RuntimeError: if the toolchain has no name (hence no storage).
        """

        if not self._name:
            raise RuntimeError("toolchain has no name")

        if self.storage.doc.exists():
            return self.storage.doc.load()
        return None

    @documentation.setter
    def documentation(self, value):
        """Sets the full-length description for this object

        Accepts either a string or a file-like object (anything with
        ``read``).
        """

        if not self._name:
            raise RuntimeError("toolchain has no name")

        if hasattr(value, "read"):
            self.storage.doc.save(value.read())
        else:
            self.storage.doc.save(value)

    def hash(self):
        """Returns the hexadecimal hash for its declaration"""

        if not self._name:
            raise RuntimeError("toolchain has no name")

        return self.storage.hash()

    def json_dumps(self, indent=4):
        """Dumps the JSON declaration of this object in a string

        Parameters:

          indent (int): The number of indentation spaces at every indentation
            level

        Returns:

          str: The JSON representation for this object

        """

        return json.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder)

    def __str__(self):
        return self.json_dumps()

    def write(self, storage=None):
        """Writes contents to prefix location

        Parameters:

          storage (:py:class:`.Storage`, Optional): If you pass a new storage,
            then this object will be written to that storage point rather than
            its default.

        Raises:

          RuntimeError: if no storage is given and the toolchain has no name.

        """

        if storage is None:
            if not self._name:
                raise RuntimeError("toolchain has no name")
            storage = self.storage  # overwrite

        storage.save(str(self), self.description)