#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.web module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import os
import json
from django.http import HttpResponse
from django.core.urlresolvers import reverse
from rest_framework.response import Response
from rest_framework import permissions
from rest_framework import views
from rest_framework import status
from rest_framework import generics
from .models import Database
from .models import DatabaseSetTemplate
from .serializers import DatabaseSerializer, DatabaseCreationSerializer
from .exceptions import DatabaseCreationError
from ..common import is_true
from ..common.mixins import IsAdminOrReadOnlyMixin
from ..common.api import ListCreateBaseView
from ..common.responses import BadRequestResponse
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer
import logging
import traceback
logger = logging.getLogger(__name__)
#----------------------------------------------------------
def database_to_json(database, request_user, fields_to_return,
                     last_version=None):
    """Describe ``database`` as a plain dictionary limited to selected fields.

    :param database: the database model instance to describe
    :param request_user: user passed to ``database.accessibility_for`` when
        the ``accessibility`` field is requested
    :param fields_to_return: container of field names to include; names that
        are not known here are silently ignored
    :param last_version: value reported verbatim under ``last_version``
    :returns: a dict whose keys are the requested known fields, inserted in
        a fixed order (name, version, last_version, short_description,
        description, previous_version, creation_date, hash, accessibility)
    """

    def _previous_version():
        # Report the full name of the previous version, or None if there is none
        previous = database.previous_version
        return previous.fullname() if previous is not None else None

    # Each getter is only invoked when its field was asked for, so fields
    # that were not requested never touch the model.
    getters = (
        ('name', lambda: database.fullname()),
        ('version', lambda: database.version),
        ('last_version', lambda: last_version),
        ('short_description', lambda: database.short_description),
        ('description', lambda: database.description),
        ('previous_version', _previous_version),
        ('creation_date', lambda: database.creation_date.isoformat(' ')),
        ('hash', lambda: database.hash),
        ('accessibility', lambda: database.accessibility_for(request_user)),
    )

    return {key: getter() for key, getter in getters
            if key in fields_to_return}
def clean_paths(declaration):
    """Replace filesystem locations in a database declaration with placeholders.

    ``root_folder`` is reduced to its last path component. The
    ``parameters['annotations']`` entry of every set, when present, is
    reduced to its last two components. Both are re-rooted under
    ``/path_to_db_folder``, so only a placeholder path is reported.

    The declaration is modified in place and the same object is returned.

    :param declaration: database declaration dictionary; it must contain
        ``root_folder`` and ``protocols``, each protocol having a ``sets`` list
    :returns: ``declaration`` itself, after the rewrite
    :raises KeyError: if ``root_folder``, ``protocols`` or a protocol's
        ``sets`` entry is missing
    """
    pseudo_path = '/path_to_db_folder'

    root_folder = declaration['root_folder']
    cleaned_folder = os.path.basename(os.path.normpath(root_folder))
    declaration['root_folder'] = os.path.join(pseudo_path, cleaned_folder)

    for protocol in declaration['protocols']:
        for set_ in protocol['sets']:
            if 'parameters' in set_ and 'annotations' in set_['parameters']:
                # Normalise first, as done for root_folder above. Otherwise a
                # trailing slash (or '//' / '..' segments) shifts which two
                # components are kept, e.g. '/x/db/ann/' -> ['ann', ''].
                annotations_folder = os.path.normpath(
                    set_['parameters']['annotations'])
                cleaned_parts = annotations_folder.split('/')[-2:]
                set_['parameters']['annotations'] = os.path.join(
                    pseudo_path, *cleaned_parts)

    return declaration
#----------------------------------------------------------
class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView):
    """
    Read/write endpoint listing the databases visible to the requesting
    user. Creating new databases is restricted to platform administrators.
    """

    model = Database
    serializer_class = DatabaseSerializer
    writing_serializer_class = DatabaseCreationSerializer
    namespace = 'api_databases'

    def get_queryset(self):
        """Return the databases available to the user issuing the request."""
        return self.model.objects.for_user(self.request.user, True)

    def get(self, request, *args, **kwargs):
        """List the visible databases, serialized with the requested fields.

        With a true ``latest_versions`` query parameter, only the latest
        version of each database is kept and the result is sorted by full
        name; otherwise the databases are ordered by name.
        """
        requested_fields = self.get_serializer_fields(request)
        only_latest = is_true(request.query_params.get('latest_versions', False))

        databases = self.get_queryset().order_by('name')
        if only_latest:
            databases = self.model.filter_latest_versions(databases)
            databases.sort(key=lambda db: db.fullname())

        serializer = self.get_serializer(databases, many=True,
                                         fields=requested_fields)
        return Response(serializer.data)
#----------------------------------------------------------
class ListTemplatesView(views.APIView):
    """
    List all templates available
    """

    permission_classes = [permissions.AllowAny]

    def get(self, request):
        """Return the set templates of the latest visible database versions.

        The response maps each database template name to a dict with:

        * ``templates``: dataset name -> list of output names (sorted by name)
        * ``sets``: list of ``{'name': ..., 'template': ...}`` entries,
          ordered by set id. Within one set template, sets sharing a name
          are reported once.
        """
        result = {}

        # Only the latest version of each database visible to the user counts
        databases = Database.objects.for_user(request.user, True)
        databases = Database.filter_latest_versions(databases)

        set_templates = DatabaseSetTemplate.objects.filter(
            sets__protocol__database__in=databases).distinct().order_by('name')

        for set_template in set_templates:
            # Names have the form '<db_template>__<dataset>'
            (db_template, dataset) = set_template.name.split('__')

            entry = result.setdefault(db_template, {
                'templates': {},
                'sets': [],
            })

            # Build a real list: on Python 3, ``map`` returns a lazy iterator
            # that the JSON renderer cannot serialize.
            entry['templates'][dataset] = [
                output.name for output in set_template.outputs.order_by('name')
            ]

            known_sets = set()
            for db_set in set_template.sets.iterator():
                if db_set.name not in known_sets:
                    entry['sets'].append({
                        'name': db_set.name,
                        'template': dataset,
                        'id': db_set.id,
                    })
                    known_sets.add(db_set.name)

        # Order the sets by id, then drop the id from the public output
        for entry in result.values():
            entry['sets'].sort(key=lambda x: x['id'])
            entry['sets'] = [
                {'name': x['name'], 'template': x['template']}
                for x in entry['sets']
            ]

        return Response(result)
#----------------------------------------------------------
class RetrieveDatabaseView(views.APIView):
    """
    Returns the given database details
    """

    permission_classes = [permissions.AllowAny]

    # Fields returned when the query string does not contain ``fields``
    _DEFAULT_FIELDS = (
        'name', 'version', 'last_version',
        'short_description', 'description', 'fork_of',
        'previous_version', 'is_owner', 'accessibility', 'sharing',
        'opensource', 'hash', 'creation_date',
        'declaration', 'code',
    )

    @staticmethod
    def _find_database(user, database_name, version):
        """Look up the requested database among those visible to ``user``.

        :param version: requested version (anything ``int()`` accepts), or
            ``None`` for the most recent version
        :returns: a ``(database, last_version)`` tuple, or ``None`` when no
            database with that name/version exists (or ``version`` is not an
            integer)
        """
        queryset = Database.objects.for_user(user, True).filter(
            name__iexact=database_name)
        try:
            if version is None:
                return (queryset.order_by('-version')[0], True)

            version = int(version)
            # Evaluate once: both the first element and the count are used
            candidates = list(
                queryset.filter(version__gte=version).order_by('version'))
            database = candidates[0]
        except (ValueError, IndexError):
            return None

        if database.version != version:
            return None

        # It is the last version when no newer version is visible
        return (database, len(candidates) == 1)

    @staticmethod
    def _serialize_accessible_dataformats(dataformats, user):
        """Serialize the data formats of ``dataformats`` that ``user`` can access."""
        accessible = []
        for dataformat in dataformats:
            (has_access, _) = dataformat.accessibility_for(user)
            if has_access:
                accessible.append(dataformat)
        # ``many=True`` is required to serialize a list of instances
        return ReferencedDataFormatSerializer(accessible, many=True).data

    def get(self, request, database_name, version=None):
        """Return the details of a database, limited to the requested fields.

        :param database_name: database name (matched case-insensitively)
        :param version: requested version, or ``None`` for the latest one
        :returns: 404 if the database/version is not found, 500 if building
            the response fails, otherwise the JSON details
        """
        found = self._find_database(request.user, database_name, version)
        if found is None:
            return HttpResponse(status=status.HTTP_404_NOT_FOUND)
        (database, last_version) = found

        # Process the query string
        if 'fields' in request.GET:
            fields_to_return = request.GET['fields'].split(',')
        else:
            fields_to_return = self._DEFAULT_FIELDS

        try:
            result = database_to_json(database, request.user, fields_to_return,
                                      last_version=last_version)

            if 'declaration' in fields_to_return:
                # clean_paths swaps the real folders for placeholder paths
                result['declaration'] = json.dumps(
                    clean_paths(database.declaration))

            if 'code' in fields_to_return:
                result['code'] = database.source_code

            if 'html_description' in fields_to_return:
                description = database.description
                result['html_description'] = \
                    ensure_html(description) if description else ''

            if 'referenced_dataformats' in fields_to_return:
                result['referenced_dataformats'] = \
                    self._serialize_accessible_dataformats(
                        database.all_referenced_dataformats(), request.user)

            if 'needed_dataformats' in fields_to_return:
                result['needed_dataformats'] = \
                    self._serialize_accessible_dataformats(
                        database.all_needed_dataformats(), request.user)

            return Response(result)
        except Exception:
            # Top-level boundary: log the traceback and report a server error
            logger.exception("Failed to retrieve database '%s'", database_name)
            return HttpResponse(status=status.HTTP_500_INTERNAL_SERVER_ERROR)