Source code for beat.web.databases.api

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.web module of the BEAT platform.              #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import json
import logging
import os
from pathlib import PurePath

from rest_framework import exceptions as drf_exceptions
from rest_framework import permissions as drf_permissions
from rest_framework import views
from rest_framework.response import Response

from ..common import is_true
from ..common import permissions as beat_permissions
from ..common.api import ListCreateBaseView
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer
from .models import Database
from .models import DatabaseSetTemplate
from .serializers import DatabaseCreationSerializer
from .serializers import DatabaseSerializer

# Module-level logger, named after this module's import path
logger = logging.getLogger(__name__)


# ----------------------------------------------------------


def database_to_json(database, request_user, fields_to_return):
    """Build a plain dictionary describing ``database``.

    Only the entries whose key appears in ``fields_to_return`` are computed
    and included; names that are not handled here are silently ignored. The
    keys of the returned dictionary always follow the fixed order below,
    whatever the order of ``fields_to_return``:

    ``name``, ``version``, ``last_version``, ``short_description``,
    ``description``, ``previous_version``, ``creation_date``, ``hash``,
    ``accessibility``.

    Parameters:

      database: The database object to describe

      request_user: The user on whose behalf the description is built. It is
        passed to ``Database.objects.for_user()`` (for ``last_version``) and
        to ``database.accessibility_for()`` (for ``accessibility``)

      fields_to_return: Container of the field names to include

    Returns:

      dict: The selected fields and their values
    """

    def _is_last_version():
        # Compare against the highest version among the same-named databases
        # returned by ``for_user(request_user, True)``
        newest = (
            Database.objects.for_user(request_user, True)
            .filter(name=database.name)
            .order_by("-version")[:1]
            .first()
        )
        return database.version == newest.version

    def _previous_version():
        if database.previous_version is None:
            return None
        return database.previous_version.fullname()

    # Ordered (key, lazy getter) pairs: a getter only runs when its key was
    # requested, so unrequested fields cost nothing (no database query, no
    # attribute access).
    getters = (
        ("name", lambda: database.fullname()),
        ("version", lambda: database.version),
        ("last_version", _is_last_version),
        ("short_description", lambda: database.short_description),
        ("description", lambda: database.description),
        ("previous_version", _previous_version),
        ("creation_date", lambda: database.creation_date.isoformat(" ")),
        ("hash", lambda: database.hash),
        ("accessibility", lambda: database.accessibility_for(request_user)),
    )

    return {key: getter() for key, getter in getters if key in fields_to_return}
def clean_paths(declaration):
    """Replace real filesystem locations in a database declaration.

    The ``root_folder`` entry and every absolute ``annotations`` parameter
    found in the protocols are rewritten so that they start with the pseudo
    folder ``/path_to_db_folder``:

    * ``root_folder`` keeps only its last path component
    * absolute ``annotations`` paths keep at most their last two components
    * relative ``annotations`` paths are left untouched

    Parameters:

      declaration (dict): The database declaration. It must contain the
        ``root_folder`` and ``protocols`` keys. Each protocol holds either a
        ``sets`` list (V1 databases) or a ``views`` dictionary

    Returns:

      dict: The very same ``declaration`` object, which is modified in place
    """

    pseudo_path = "/path_to_db_folder"

    def _clean_path(item):
        parameters = item.get("parameters", {})
        if "annotations" not in parameters:
            return

        ppath = PurePath(parameters["annotations"])
        if not ppath.is_absolute():
            return

        # The first part of an absolute path is its anchor ("/" on POSIX).
        # It must be dropped before keeping the trailing components:
        # os.path.join() discards everything that precedes an absolute
        # component, so a short path such as "/annotations" would otherwise
        # be returned unchanged instead of being moved under the pseudo path.
        cleaned_folder = ppath.parts[1:][-2:]
        parameters["annotations"] = os.path.join(pseudo_path, *cleaned_folder)

    root_folder = declaration["root_folder"]
    cleaned_folder = os.path.basename(os.path.normpath(root_folder))
    declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)

    for protocol in declaration["protocols"]:
        # sets is a key only available in the V1 version of databases
        if "sets" in protocol:
            items = protocol["sets"]
        else:
            items = protocol["views"].values()

        for item in items:
            _clean_path(item)

    return declaration
# ----------------------------------------------------------
class ListCreateDatabasesView(ListCreateBaseView):
    """
    Read/Write end point that lists the databases available to a user and
    allows the creation of new databases only to platform administrators
    """

    model = Database
    permission_classes = [beat_permissions.IsAdminOrReadOnly]
    serializer_class = DatabaseSerializer
    writing_serializer_class = DatabaseCreationSerializer
    namespace = "api_databases"

    def get_queryset(self):
        """Return the databases selected for the requesting user"""

        return self.model.objects.for_user(self.request.user, True)

    def get(self, request, *args, **kwargs):
        """List the databases, optionally limited to their latest versions

        The ``latest_versions`` query parameter, interpreted by ``is_true()``,
        enables the restriction to the latest version of each database.
        """

        fields = self.get_serializer_fields(request)
        latest_only = is_true(request.query_params.get("latest_versions", False))

        databases = self.get_queryset().order_by("name")

        if latest_only:
            # NOTE(review): the in-place sort() below implies that
            # filter_latest_versions() returns a list - confirm in the model
            databases = self.model.filter_latest_versions(databases)
            databases.sort(key=lambda database: database.fullname())

        serializer = self.get_serializer(databases, many=True, fields=fields)
        return Response(serializer.data)
# ----------------------------------------------------------
class ListTemplatesView(views.APIView):
    """
    List all templates available
    """

    permission_classes = [drf_permissions.AllowAny]

    def get(self, request):
        """Return the set templates of the latest database versions

        The response is a dictionary indexed by database template name (the
        part of the set template name before ``__``). Each entry contains:

        * ``templates``: dictionary mapping a dataset name (the part after
          ``__``) to the list of its output names, ordered by name
        * ``sets``: list of ``{"name": ..., "template": ...}`` dictionaries,
          ordered by the identifier of the corresponding database set
        """

        result = {}

        # Retrieve all the protocols available to user
        databases = Database.objects.for_user(request.user, True)
        databases = Database.filter_latest_versions(databases)

        set_templates = (
            DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases)
            .distinct()
            .order_by("name")
        )

        for set_template in set_templates:
            (db_template, dataset) = set_template.name.split("__")

            entry = result.setdefault(db_template, {"templates": {}, "sets": []})

            # Build real lists: on Python 3, map() returns a one-shot
            # iterator, which is not a proper JSON-serializable value
            entry["templates"][dataset] = [
                output.name for output in set_template.outputs.order_by("name")
            ]

            # Names are de-duplicated per set template only
            known_sets = set()

            for db_set in set_template.sets.iterator():
                if db_set.name in known_sets:
                    continue

                entry["sets"].append(
                    {"name": db_set.name, "template": dataset, "id": db_set.id}
                )
                known_sets.add(db_set.name)

        # The identifiers are only needed to order the sets, drop them from
        # the final answer
        for entry in result.values():
            entry["sets"] = [
                {"name": item["name"], "template": item["template"]}
                for item in sorted(entry["sets"], key=lambda item: item["id"])
            ]

        return Response(result)
# ----------------------------------------------------------
class RetrieveDatabaseView(views.APIView):
    """
    Returns the given database details
    """

    model = Database
    permission_classes = [drf_permissions.AllowAny]

    # Fields returned when the query string does not contain "fields"
    DEFAULT_FIELDS = (
        "name",
        "version",
        "last_version",
        "short_description",
        "description",
        "fork_of",
        "previous_version",
        "is_owner",
        "accessibility",
        "sharing",
        "opensource",
        "hash",
        "creation_date",
        "declaration",
        "code",
    )

    def get_object(self):
        """Return the database designated by the URL parameters

        The lookup uses the ``database_name`` (case-insensitive) and
        ``version`` URL keyword arguments, among the databases returned by
        ``for_user()`` for the requesting user.

        Raises:

          rest_framework.exceptions.NotFound: If no such database exists
        """

        version = self.kwargs["version"]
        database_name = self.kwargs["database_name"]
        user = self.request.user

        try:
            obj = self.model.objects.for_user(user, True).get(
                name__iexact=database_name, version=version
            )
        except self.model.DoesNotExist:
            raise drf_exceptions.NotFound()

        return obj

    @staticmethod
    def _serialize_accessible_dataformats(dataformats, user):
        """Serialize the data formats of ``dataformats`` that ``user`` can access"""

        accessible = []
        for dataformat in dataformats:
            # accessibility_for() returns a (has_access, accessibility) pair
            has_access, _ = dataformat.accessibility_for(user)
            if has_access:
                accessible.append(dataformat)

        # many=True is required: a list of objects is serialized here, not a
        # single instance
        serializer = ReferencedDataFormatSerializer(accessible, many=True)
        return serializer.data

    def get(self, request, database_name, version):
        """Return the details of one database

        The optional ``fields`` query parameter is a comma-separated list of
        the fields to return. On top of the fields handled by
        ``database_to_json()``, the following ones are supported:
        ``declaration``, ``code``, ``html_description``,
        ``referenced_dataformats`` and ``needed_dataformats``.
        """

        # Retrieve the database
        database = self.get_object()

        self.check_object_permissions(request, database)

        # Process the query string
        if "fields" in request.GET:
            fields_to_return = request.GET["fields"].split(",")
        else:
            fields_to_return = list(self.DEFAULT_FIELDS)

        # Prepare the response
        result = database_to_json(database, request.user, fields_to_return)

        # Retrieve the declaration, with the real paths hidden
        if "declaration" in fields_to_return:
            cleaned_declaration = clean_paths(database.declaration)
            result["declaration"] = json.dumps(cleaned_declaration)

        # Retrieve the source code
        if "code" in fields_to_return:
            result["code"] = database.source_code

        # Retrieve the description in HTML format
        if "html_description" in fields_to_return:
            description = database.description
            # An empty (or missing) description is returned as an empty string
            result["html_description"] = (
                ensure_html(description) if description else ""
            )

        # Retrieve the referenced data formats
        if "referenced_dataformats" in fields_to_return:
            result["referenced_dataformats"] = self._serialize_accessible_dataformats(
                database.all_referenced_dataformats(), request.user
            )

        # Retrieve the needed data formats
        if "needed_dataformats" in fields_to_return:
            result["needed_dataformats"] = self._serialize_accessible_dataformats(
                database.all_needed_dataformats(), request.user
            )

        # Return the result
        return Response(result)