Source code for beat.web.databases.api

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.web module of the BEAT platform.              #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import os
import json

from django.http import HttpResponse
from django.core.urlresolvers import reverse

from rest_framework.response import Response
from rest_framework import permissions
from rest_framework import views
from rest_framework import status
from rest_framework import generics


from .models import Database
from .models import DatabaseSetTemplate

from .serializers import DatabaseSerializer, DatabaseCreationSerializer
from .exceptions import DatabaseCreationError

from ..common import is_true
from ..common.mixins import IsAdminOrReadOnlyMixin
from ..common.api import ListCreateBaseView
from ..common.responses import BadRequestResponse
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer

import logging
import traceback
logger = logging.getLogger(__name__)


#----------------------------------------------------------

def database_to_json(database, request_user, fields_to_return, last_version=None):
    """
    Build a dict holding the requested subset of a database's attributes.

    Only the field names listed below are recognised; any other name found
    in ``fields_to_return`` is silently ignored. Keys appear in the result in
    the fixed order of that list, not in the order they were requested.

    Recognised fields: ``name``, ``version``, ``last_version``,
    ``short_description``, ``description``, ``previous_version``,
    ``creation_date``, ``hash``, ``accessibility``.

    Parameters:
        database: the database object to describe
        request_user: user passed to ``database.accessibility_for()``
        fields_to_return: container of the field names to include
        last_version: value copied verbatim into the ``last_version`` field

    Returns:
        dict: field name -> value, for the requested fields only
    """
    # Each extractor is only invoked when its field was requested, so
    # attributes of unrequested fields are never touched.
    extractors = (
        ('name', lambda: database.fullname()),
        ('version', lambda: database.version),
        ('last_version', lambda: last_version),
        ('short_description', lambda: database.short_description),
        ('description', lambda: database.description),
        ('previous_version',
         lambda: (database.previous_version.fullname()
                  if database.previous_version is not None else None)),
        # Space-separated ISO 8601 timestamp
        ('creation_date', lambda: database.creation_date.isoformat(' ')),
        ('hash', lambda: database.hash),
        ('accessibility', lambda: database.accessibility_for(request_user)),
    )

    return {
        field: extract()
        for field, extract in extractors
        if field in fields_to_return
    }
def clean_paths(declaration):
    """
    Replace the folders of a database declaration by pseudo paths.

    ``root_folder`` becomes ``/path_to_db_folder/<last path component>`` and
    every set's ``parameters['annotations']`` entry (when present) becomes
    ``/path_to_db_folder/<last two path components>``. Presumably this keeps
    the real server-side locations out of what is returned to API clients.

    Parameters:
        declaration (dict): the database declaration. It is modified **in
            place**. It must contain ``root_folder`` and ``protocols``, and
            each protocol must contain a ``sets`` list.

    Returns:
        dict: the same ``declaration`` object, for convenience.
    """
    pseudo_path = '/path_to_db_folder'

    root_folder = declaration['root_folder']
    cleaned_folder = os.path.basename(os.path.normpath(root_folder))
    declaration['root_folder'] = os.path.join(pseudo_path, cleaned_folder)

    for protocol in declaration['protocols']:
        for set_ in protocol['sets']:
            if 'parameters' in set_ and 'annotations' in set_['parameters']:
                annotations_folder = set_['parameters']['annotations']
                # Normalize first (as done for root_folder above) and drop
                # empty components. A trailing or doubled '/' then no longer
                # yields '' as one of the "last two" components.
                components = [
                    part
                    for part in os.path.normpath(annotations_folder).split('/')
                    if part not in ('', '.')
                ]
                set_['parameters']['annotations'] = \
                    os.path.join(pseudo_path, *components[-2:])

    return declaration
#----------------------------------------------------------
class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView):
    """
    Read/Write end point that lists the databases available to a user and
    allows only platform administrators to create new databases
    """
    model = Database
    serializer_class = DatabaseSerializer
    writing_serializer_class = DatabaseCreationSerializer
    namespace = 'api_databases'

    def get_queryset(self):
        """Return the databases visible to the requesting user."""
        return self.model.objects.for_user(self.request.user, True)

    def get(self, request, *args, **kwargs):
        """
        List the databases visible to the user.

        With a true ``latest_versions`` query parameter, only the latest
        version of each database is kept. That list is then ordered by full
        name; otherwise the queryset is ordered by name.
        """
        fields = self.get_serializer_fields(request)
        only_latest = is_true(request.query_params.get('latest_versions', False))

        databases = self.get_queryset().order_by('name')
        if only_latest:
            databases = sorted(
                self.model.filter_latest_versions(databases),
                key=lambda db: db.fullname(),
            )

        serializer = self.get_serializer(databases, many=True, fields=fields)
        return Response(serializer.data)
#----------------------------------------------------------
class ListTemplatesView(views.APIView):
    """
    List all templates available
    """
    permission_classes = [permissions.AllowAny]

    def get(self, request):
        """
        List the set templates used by the latest database versions
        visible to the user.

        The response maps each database-template name to a dict with:

        - ``templates``: set-template name -> sorted list of output names
        - ``sets``: list of ``{'name', 'template'}`` dicts, ordered by set id
        """
        result = {}

        # Retrieve all the protocols available to user
        databases = Database.objects.for_user(request.user, True)
        databases = Database.filter_latest_versions(databases)

        set_templates = DatabaseSetTemplate.objects.filter(
            sets__protocol__database__in=databases).distinct().order_by('name')

        for set_template in set_templates:
            # Names are expected as '<database template>__<set template>'.
            # Any other number of '__' separators raises ValueError here.
            (db_template, dataset) = set_template.name.split('__')

            entry = result.setdefault(db_template, {'templates': {}, 'sets': []})

            # Materialize as a list. Under Python 3, map() returns a lazy,
            # single-use iterator.
            entry['templates'][dataset] = [
                output.name for output in set_template.outputs.order_by('name')
            ]

            # NOTE(review): this iterates over *all* sets of the template,
            # not only those of the databases selected above. Confirm that
            # this is intended.
            known_sets = set()
            for db_set in set_template.sets.iterator():
                if db_set.name not in known_sets:
                    entry['sets'].append({
                        'name': db_set.name,
                        'template': dataset,
                        'id': db_set.id,
                    })
                    known_sets.add(db_set.name)

        # Order the sets by id, then drop the id from the public output
        for entry in result.values():
            entry['sets'].sort(key=lambda x: x['id'])
            entry['sets'] = [
                {'name': x['name'], 'template': x['template']}
                for x in entry['sets']
            ]

        return Response(result)
#----------------------------------------------------------
class RetrieveDatabaseView(views.APIView):
    """
    Returns the given database details
    """
    permission_classes = [permissions.AllowAny]

    @staticmethod
    def _accessible_dataformats(dataformats, user):
        """Return, in order, the data formats of ``dataformats`` that ``user`` can access."""
        accessible = []
        for dataformat in dataformats:
            (has_access, _) = dataformat.accessibility_for(user)
            if has_access:
                accessible.append(dataformat)
        return accessible

    def get(self, request, database_name, version=None):
        """
        Return the details of a database.

        Parameters:
            request: the incoming request. An optional ``fields`` query
                parameter (comma-separated) selects the fields to return;
                otherwise a default set is used.
            database_name (str): database name, matched case-insensitively
            version: requested version (converted with ``int()``). When
                omitted, the highest visible version is returned.

        Returns a 404 response when the version is not an integer or no
        matching database is visible to the user. Returns a 500 response
        when gathering any of the details fails.
        """
        # Retrieve the database
        try:
            if version is not None:
                version = int(version)
                # Every visible version >= the requested one, oldest first.
                # It is evaluated once, so the lookup and the "last version"
                # check share a single query.
                databases = list(
                    Database.objects.for_user(request.user, True)
                    .filter(name__iexact=database_name, version__gte=version)
                    .order_by('version')
                )
                database = databases[0]
                if database.version != version:
                    return HttpResponse(status=status.HTTP_404_NOT_FOUND)
                last_version = (len(databases) == 1)
            else:
                database = Database.objects.for_user(request.user, True) \
                    .filter(name__iexact=database_name) \
                    .order_by('-version')[0]
                last_version = True
        except (ValueError, IndexError):
            # Non-integer version, or no matching database visible to the user
            return HttpResponse(status=status.HTTP_404_NOT_FOUND)

        # Process the query string
        if 'fields' in request.GET:
            fields_to_return = request.GET['fields'].split(',')
        else:
            fields_to_return = [
                'name', 'version', 'last_version', 'short_description',
                'description', 'fork_of', 'previous_version', 'is_owner',
                'accessibility', 'sharing', 'opensource', 'hash',
                'creation_date', 'declaration', 'code',
            ]

        try:
            # Prepare the response
            result = database_to_json(database, request.user, fields_to_return,
                                      last_version=last_version)

            # Retrieve the declaration, with its folders replaced by pseudo paths
            if 'declaration' in fields_to_return:
                try:
                    declaration = database.declaration
                except Exception:
                    logger.exception("Cannot retrieve the declaration of database '%s'",
                                     database_name)
                    return HttpResponse(status=500)
                # NOTE(review): clean_paths() modifies its argument in place.
                # That is only harmless if `database.declaration` returns a
                # fresh object on every access. Confirm.
                result['declaration'] = json.dumps(clean_paths(declaration))

            # Retrieve the source code
            if 'code' in fields_to_return:
                try:
                    result['code'] = database.source_code
                except Exception:
                    logger.exception("Cannot retrieve the source code of database '%s'",
                                     database_name)
                    return HttpResponse(status=500)

            # Retrieve the description in HTML format. An empty or missing
            # description yields ''.
            if 'html_description' in fields_to_return:
                description = database.description
                result['html_description'] = ensure_html(description) if description else ''

            # NOTE(review): the serializers below receive a plain list without
            # `many=True`. This relies on ReferencedDataFormatSerializer
            # accepting a list. Confirm.

            # Retrieve the referenced data formats the user can access
            if 'referenced_dataformats' in fields_to_return:
                serializer = ReferencedDataFormatSerializer(
                    self._accessible_dataformats(
                        database.all_referenced_dataformats(), request.user))
                result['referenced_dataformats'] = serializer.data

            # Retrieve the needed data formats the user can access
            if 'needed_dataformats' in fields_to_return:
                serializer = ReferencedDataFormatSerializer(
                    self._accessible_dataformats(
                        database.all_needed_dataformats(), request.user))
                result['needed_dataformats'] = serializer.data

            # Return the result
            return Response(result)
        except Exception:
            logger.exception("Cannot build the details of database '%s'", database_name)
            return HttpResponse(status=500)