#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.web module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import json
import logging
import os
from pathlib import PurePath
from rest_framework import exceptions as drf_exceptions
from rest_framework import permissions as drf_permissions
from rest_framework import views
from rest_framework.response import Response
from ..common import is_true
from ..common import permissions as beat_permissions
from ..common.api import ListCreateBaseView
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer
from .models import Database
from .models import DatabaseSetTemplate
from .serializers import DatabaseCreationSerializer
from .serializers import DatabaseSerializer
logger = logging.getLogger(__name__)
# ----------------------------------------------------------
def database_to_json(database, request_user, fields_to_return):
    """Build a plain dictionary describing ``database``.

    Only the entries whose key appears in ``fields_to_return`` are computed
    and included; names this function does not know about are silently
    ignored.

    Args:
        database: Object exposing ``fullname()``, ``name``, ``version``,
            ``short_description``, ``description``, ``previous_version``,
            ``creation_date``, ``hash`` and ``accessibility_for(user)``.
        request_user: User the answer is computed for; it is used to look up
            the versions visible to that user and is passed to
            ``accessibility_for``.
        fields_to_return: Container of field names to include in the result.

    Returns:
        dict: Mapping from each requested (and supported) field name to its
        value.
    """
    result = {}

    if "name" in fields_to_return:
        result["name"] = database.fullname()

    if "version" in fields_to_return:
        result["version"] = database.version

    if "last_version" in fields_to_return:
        latest = (
            Database.objects.for_user(request_user, True)
            .filter(name=database.name)
            .order_by("-version")
            .first()
        )
        # ``first()`` yields None when the user can see no version of this
        # database at all; in that case nothing newer is visible, so report
        # this version as the latest instead of failing on ``None.version``.
        result["last_version"] = (
            latest is None or database.version == latest.version
        )

    if "short_description" in fields_to_return:
        result["short_description"] = database.short_description

    if "description" in fields_to_return:
        result["description"] = database.description

    if "previous_version" in fields_to_return:
        previous = database.previous_version
        result["previous_version"] = (
            previous.fullname() if previous is not None else None
        )

    if "creation_date" in fields_to_return:
        # A space (rather than "T") separates the date and time parts.
        result["creation_date"] = database.creation_date.isoformat(" ")

    if "hash" in fields_to_return:
        result["hash"] = database.hash

    if "accessibility" in fields_to_return:
        result["accessibility"] = database.accessibility_for(request_user)

    return result
def clean_paths(declaration):
    """Replace real filesystem locations in a database declaration.

    The ``root_folder`` entry is rewritten to ``/path_to_db_folder/<name>``
    where ``<name>`` is the last component of the original root folder.
    Every absolute ``annotations`` parameter found in the protocol sets
    (V1 declarations) or protocol views (other declarations) is rewritten to
    ``/path_to_db_folder/`` followed by at most the last two components of
    the original path. Relative ``annotations`` values are left untouched.

    Args:
        declaration: Dictionary holding at least the ``root_folder`` and
            ``protocols`` keys. It is modified in place.

    Returns:
        dict: The same ``declaration`` object, after modification.

    Raises:
        KeyError: If ``root_folder`` or ``protocols`` is missing, or if a
            protocol has neither ``sets`` nor ``views``.
    """
    pseudo_path = "/path_to_db_folder"

    def _clean_path(item):
        """Anonymise the absolute ``annotations`` parameter of ``item``."""
        parameters = item.get("parameters", {})
        if "annotations" not in parameters:
            return

        ppath = PurePath(parameters["annotations"])
        if not ppath.is_absolute():
            return

        # Strip the anchor ("/") before taking the trailing components:
        # otherwise a shallow path such as "/annot" would keep "/" as one of
        # its parts and os.path.join() would restart from it, returning the
        # original path unchanged.
        trailing = ppath.relative_to(ppath.anchor).parts[-2:]
        parameters["annotations"] = os.path.join(pseudo_path, *trailing)

    root_folder = declaration["root_folder"]
    cleaned_folder = os.path.basename(os.path.normpath(root_folder))
    declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)

    for protocol in declaration["protocols"]:
        # sets is a key only available in the V1 version of databases
        if "sets" in protocol:
            for set_ in protocol["sets"]:
                _clean_path(set_)
        else:
            for view in protocol["views"].values():
                _clean_path(view)

    return declaration
# ----------------------------------------------------------
class ListCreateDatabasesView(ListCreateBaseView):
    """
    Read/write endpoint: lists the databases available to the requesting
    user; creating new databases is restricted to platform administrators.
    """

    model = Database
    permission_classes = [beat_permissions.IsAdminOrReadOnly]
    serializer_class = DatabaseSerializer
    writing_serializer_class = DatabaseCreationSerializer
    namespace = "api_databases"

    def get_queryset(self):
        """Return the databases selected by ``for_user`` for the requester."""
        return self.model.objects.for_user(self.request.user, True)

    def get(self, request, *args, **kwargs):
        """Serialize the databases, ordered by name.

        When the ``latest_versions`` query parameter is truthy (according to
        ``is_true``), the databases are first passed through
        ``filter_latest_versions`` and then sorted by full name.
        """
        requested_fields = self.get_serializer_fields(request)
        latest_only = is_true(request.query_params.get("latest_versions", False))

        databases = self.get_queryset().order_by("name")
        if latest_only:
            databases = self.model.filter_latest_versions(databases)
            databases.sort(key=lambda entry: entry.fullname())

        serializer = self.get_serializer(
            databases, many=True, fields=requested_fields
        )
        return Response(serializer.data)
# ----------------------------------------------------------
class ListTemplatesView(views.APIView):
    """
    List all templates available
    """

    permission_classes = [drf_permissions.AllowAny]

    def get(self, request):
        """Return the set templates of the databases visible to the user.

        Template names are expected to have the form
        ``<db_template>__<dataset>``. The response maps each ``db_template``
        to a dictionary with two entries:

        * ``templates``: maps each dataset name to the list of its output
          names, ordered by name;
        * ``sets``: list of ``{"name", "template"}`` dictionaries, ordered by
          the id of the underlying database set.

        Raises:
            ValueError: If a template name does not split into exactly two
                parts around ``"__"``.
        """
        result = {}

        # Retrieve the latest versions of the databases available to the user
        databases = Database.objects.for_user(request.user, True)
        databases = Database.filter_latest_versions(databases)

        set_templates = (
            DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases)
            .distinct()
            .order_by("name")
        )

        for set_template in set_templates:
            db_template, dataset = set_template.name.split("__")
            entry = result.setdefault(db_template, {"templates": {}, "sets": []})

            # Build real lists: on Python 3, map() would leave a lazy,
            # single-use iterator in the response payload.
            entry["templates"][dataset] = [
                output.name for output in set_template.outputs.order_by("name")
            ]

            # Set names are de-duplicated per template only.
            known_sets = set()
            for db_set in set_template.sets.iterator():
                if db_set.name in known_sets:
                    continue
                known_sets.add(db_set.name)
                entry["sets"].append(
                    {"name": db_set.name, "template": dataset, "id": db_set.id}
                )

        for entry in result.values():
            # The id is only needed to order the sets; it is not returned.
            entry["sets"] = [
                {"name": item["name"], "template": item["template"]}
                for item in sorted(entry["sets"], key=lambda item: item["id"])
            ]

        return Response(result)
# ----------------------------------------------------------
class RetrieveDatabaseView(views.APIView):
    """
    Returns the given database details
    """

    model = Database
    permission_classes = [drf_permissions.AllowAny]

    def get_object(self):
        """Return the database designated by the URL keyword arguments.

        The lookup uses the ``database_name`` (case-insensitive) and
        ``version`` URL arguments, restricted to what ``for_user`` returns
        for the requesting user.

        Raises:
            rest_framework.exceptions.NotFound: If no such database exists.
        """
        version = self.kwargs["version"]
        database_name = self.kwargs["database_name"]
        user = self.request.user

        try:
            return self.model.objects.for_user(user, True).get(
                name__iexact=database_name, version=version
            )
        except self.model.DoesNotExist:
            raise drf_exceptions.NotFound()

    @staticmethod
    def _serialize_accessible_dataformats(dataformats, user):
        """Serialize the data formats of ``dataformats`` that ``user`` can access."""
        accessible = []
        for dataformat in dataformats:
            has_access, _ = dataformat.accessibility_for(user)
            if has_access:
                accessible.append(dataformat)

        # A list of objects is serialized, hence ``many=True``.
        # NOTE(review): assumes ReferencedDataFormatSerializer is a regular
        # DRF serializer describing a single data format - confirm.
        return ReferencedDataFormatSerializer(accessible, many=True).data

    def get(self, request, database_name, version):
        """Return the requested details of a database.

        The optional ``fields`` query parameter is a comma-separated list of
        the entries to return; when absent, a default list is used. Field
        names that are not handled are ignored.
        """
        # Retrieve the database
        database = self.get_object()
        self.check_object_permissions(request, database)

        # Process the query string
        if "fields" in request.GET:
            fields_to_return = request.GET["fields"].split(",")
        else:
            fields_to_return = [
                "name",
                "version",
                "last_version",
                "short_description",
                "description",
                "fork_of",
                "previous_version",
                "is_owner",
                "accessibility",
                "sharing",
                "opensource",
                "hash",
                "creation_date",
                "declaration",
                "code",
            ]

        # Prepare the response
        result = database_to_json(database, request.user, fields_to_return)

        # Retrieve the declaration, with its filesystem paths cleaned
        if "declaration" in fields_to_return:
            cleaned_declaration = clean_paths(database.declaration)
            result["declaration"] = json.dumps(cleaned_declaration)

        # Retrieve the source code
        if "code" in fields_to_return:
            result["code"] = database.source_code

        # Retrieve the description in HTML format
        if "html_description" in fields_to_return:
            description = database.description
            # Truthiness also covers a missing (None) description.
            result["html_description"] = (
                ensure_html(description) if description else ""
            )

        # Retrieve the referenced data formats
        if "referenced_dataformats" in fields_to_return:
            result["referenced_dataformats"] = self._serialize_accessible_dataformats(
                database.all_referenced_dataformats(), request.user
            )

        # Retrieve the needed data formats
        if "needed_dataformats" in fields_to_return:
            result["needed_dataformats"] = self._serialize_accessible_dataformats(
                database.all_needed_dataformats(), request.user
            )

        # Return the result
        return Response(result)