Source code for beat.web.common.serializers

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.web module of the BEAT platform.              #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import copy
import difflib

import simplejson as json
from django.conf import settings
from django.contrib.auth.models import User
from rest_framework import exceptions as drf_exceptions
from rest_framework import serializers

from ..common import fields as beat_fields
from ..common.utils import annotate_full_name
from ..common.utils import ensure_html
from ..common.utils import validate_restructuredtext
from ..team.models import Team
from . import fields as serializer_fields
from .models import Contribution
from .models import Shareable
from .models import Versionable

# ----------------------------------------------------------


class DiffSerializer(serializers.Serializer):
    """Expose a line-oriented diff between the declarations of two objects.

    The serialized value is indexed with ``"object1"`` and ``"object2"``;
    each entry must have a ``declaration`` attribute that can be dumped to
    JSON.
    """

    diff = serializers.SerializerMethodField()

    def get_diff(self, obj):
        """Return an ``ndiff`` of both declarations, dumped with 4-space indent.

        The ``?`` hint lines that :func:`difflib.ndiff` emits are dropped.
        """
        left_lines = json.dumps(obj["object1"].declaration, indent=4).splitlines()
        right_lines = json.dumps(obj["object2"].declaration, indent=4).splitlines()
        kept_lines = [
            line
            for line in difflib.ndiff(left_lines, right_lines)
            if line[0] != "?"
        ]
        return "\n".join(kept_lines)
# ----------------------------------------------------------
class CheckNameSerializer(serializers.Serializer):
    """Tell whether a name is already used by a given author.

    The serializer context must provide ``model`` (the model class to query)
    and ``user`` (the author to filter on).
    """

    name = serializers.CharField()
    used = serializers.SerializerMethodField()

    def validate_name(self, name):
        """Return *name* passed through ``Contribution.sanitize_name``."""
        return Contribution.sanitize_name(name)

    def get_used(self, obj):
        """Return True if ``model`` has an object with this author and name."""
        queryset = self.context.get("model").objects.filter(
            author=self.context.get("user"), name=obj.get("name")
        )
        return queryset.exists()

    def create(self, validated_data):
        """Return the validated data as-is; nothing is stored."""
        return validated_data
# ----------------------------------------------------------
class SharingSerializer(serializers.Serializer):
    """Validate the users and teams an object is about to be shared with.

    Both fields are optional lists of strings. A team is written either as
    ``owner/name`` or as a bare ``name``; a bare name is looked up among the
    teams owned by the ``user`` found in the serializer context.
    """

    users = serializer_fields.StringListField(required=False)
    teams = serializer_fields.StringListField(required=False)

    def validate_users(self, users):
        """Ensure every username in *users* matches an existing account.

        Returns:
            The unchanged list of usernames.

        Raises:
            serializers.ValidationError: if one or more usernames are unknown.
        """
        # One query; the rows are reused for the membership test below.
        accounts = list(
            User.objects.filter(username__in=users).values_list(
                "username", flat=True
            )
        )
        if len(users) != len(accounts):
            known_usernames = set(accounts)
            unknown_users = [user for user in users if user not in known_usernames]
            # The count can also differ because *users* contains duplicates;
            # only complain when some name really has no matching account
            # (indexing an empty list here used to raise IndexError).
            if len(unknown_users) > 1:
                raise serializers.ValidationError(
                    ["Unknown usernames: " + ", ".join(unknown_users)]
                )
            elif unknown_users:
                raise serializers.ValidationError(
                    ["Unknown username: " + unknown_users[0]]
                )
        return users

    def validate_teams(self, teams):
        """Ensure every team in *teams* exists.

        Entries containing more than one ``/`` are always reported as
        unknown. A bare name is resolved against the context user's teams.

        Returns:
            The unchanged list of team names.

        Raises:
            serializers.ValidationError: if one or more teams are unknown.
        """
        unknown_teams = []
        user = self.context.get("user")
        for team_name in teams:
            parts = team_name.split("/")
            if len(parts) > 2:
                unknown_teams.append(team_name)
                continue
            if len(parts) == 1:
                parts = [user.username, team_name]
            # exists() avoids fetching the matching rows just to count them.
            if not Team.objects.filter(
                owner__username=parts[0], name=parts[1]
            ).exists():
                unknown_teams.append(team_name)

        if len(unknown_teams) > 1:
            raise serializers.ValidationError(
                "Unknown teams: " + ", ".join(unknown_teams)
            )
        elif len(unknown_teams) == 1:
            raise serializers.ValidationError("Unknown team: " + unknown_teams[0])
        return teams
# ----------------------------------------------------------
class DynamicFieldsSerializer(serializers.ModelSerializer):
    """Model serializer whose output fields can be selected per instance.

    Pass ``fields=[...]`` to the constructor to keep only those fields; when
    it is omitted, ``Meta.default_fields`` is used instead. Every other
    declared field is removed from ``self.fields``.
    """

    class Meta:
        # Fields kept when no ``fields`` argument is given; subclasses
        # extend this list.
        default_fields = []

    def __init__(self, *args, **kwargs):
        # ``fields`` is consumed here; it must not reach the parent constructor.
        requested = kwargs.pop("fields", self.Meta.default_fields)

        super(DynamicFieldsSerializer, self).__init__(*args, **kwargs)

        keep = set(requested)
        to_drop = [name for name in self.fields.keys() if name not in keep]
        for name in to_drop:
            self.fields.pop(name)
# ----------------------------------------------------------
class ShareableSerializer(DynamicFieldsSerializer):
    """Serializer for shareable objects, exposing their sharing state.

    The requesting user is read from ``self.context["user"]``.
    """

    accessibility = serializers.SerializerMethodField()
    sharing = serializers.SerializerMethodField()
    modifiable = serializers.BooleanField()
    deletable = serializers.BooleanField()
    is_owner = serializers.SerializerMethodField()

    class Meta(DynamicFieldsSerializer.Meta):
        model = Shareable
        default_fields = DynamicFieldsSerializer.Meta.default_fields + [
            "is_owner",
            "modifiable",
            "deletable",
            "sharing",
        ]

    def get_accessibility(self, obj):
        """Map ``obj.sharing`` to ``"public"``, ``"confidential"`` or ``"private"``.

        ``SHARED`` and ``USABLE`` both map to ``"confidential"``; anything
        other than ``PUBLIC`` and those two maps to ``"private"``.
        """
        if obj.sharing == Versionable.PUBLIC:
            return "public"
        elif obj.sharing == Versionable.SHARED or obj.sharing == Versionable.USABLE:
            return "confidential"
        else:
            return "private"

    def get_sharing(self, obj):
        """Return sharing details, but only to the object's author.

        Returns ``None`` unless ``obj`` has an ``author`` equal to the context
        user. Otherwise returns a dict with ``status`` (the lowercased
        ``get_sharing_display()``) and, when non-empty, ``shared_with``
        (usernames) and ``shared_with_team`` (team full names).
        """
        user = self.context.get("user")
        if not (hasattr(obj, "author") and user == obj.author):
            return None

        sharing = {"status": obj.get_sharing_display().lower()}

        # Evaluate each relation once rather than count() followed by all().
        shared_users = [account.username for account in obj.shared_with.all()]
        if shared_users:
            sharing["shared_with"] = shared_users

        shared_teams = [team.fullname() for team in obj.shared_with_team.all()]
        if shared_teams:
            sharing["shared_with_team"] = shared_teams

        return sharing

    def get_is_owner(self, obj):
        """Return True if ``obj`` has an author equal to the context user."""
        if hasattr(obj, "author"):
            return obj.author == self.context.get("user")
        return False
# ----------------------------------------------------------
class VersionableSerializer(ShareableSerializer):
    """Serializer for versioned objects: name, versions, fork origin, history.

    Links to other objects (fork origin, previous version) are only reported
    when the context ``user`` can access them.
    """

    name = serializers.CharField(source="fullname")
    fork_of = serializers.SerializerMethodField()
    last_version = serializers.SerializerMethodField()
    previous_version = serializers.SerializerMethodField()
    history = serializers.SerializerMethodField()
    short_description = serializers.CharField(max_length=100)
    # NOTE(review): no ``get_code`` is defined here; presumably subclasses
    # provide it when "code" is requested -- confirm.
    code = serializers.SerializerMethodField()

    class Meta(ShareableSerializer.Meta):
        model = Versionable
        default_fields = ShareableSerializer.Meta.default_fields + [
            "name",
            "version",
            "last_version",
            "short_description",
            "fork_of",
            "previous_version",
            "accessibility",
            "hash",
            "creation_date",
        ]

    def get_fork_of(self, obj):
        """Return the fork origin's full name if the user can access it, else None."""
        if not obj.fork_of:
            return None
        accessibility_infos = obj.fork_of.accessibility_for(self.context.get("user"))
        # Element 0 of the accessibility info is used as the "has access" flag.
        return obj.fork_of.fullname() if accessibility_infos[0] else None

    def get_last_version(self, obj):
        """Return the model manager's ``is_last_version`` answer for *obj*."""
        return self.Meta.model.objects.is_last_version(obj)

    def get_previous_version(self, obj):
        """Return the full name of the nearest accessible ancestor version.

        Walks back the ``previous_version`` chain and stops at the first
        version the context user can access. Returns None if there is none.
        """
        user = self.context.get("user")
        candidate = obj.previous_version
        while candidate is not None:
            if candidate.accessibility_for(user)[0]:
                # Report the version actually found accessible. Returning
                # obj.previous_version here would expose an ancestor the
                # user may not be allowed to see.
                return candidate.fullname()
            candidate = candidate.previous_version
        return None

    def get_history(self, obj):
        """Return ``obj.api_history`` for the context user."""
        return obj.api_history(self.context.get("user"))
# ----------------------------------------------------------
class ContributionSerializer(VersionableSerializer):
    """Serializer for contributions, adding their description and declaration.

    The context key ``object_format`` set to ``"string"`` makes the
    declaration come out as an indented JSON string instead of a raw object.
    """

    description = serializers.SerializerMethodField()
    declaration = serializers.SerializerMethodField()
    html_description = serializers.SerializerMethodField()

    class Meta(VersionableSerializer.Meta):
        model = Contribution
        extra_fields = ["description", "declaration"]
        exclude = ["description_file", "declaration_file"]

    def get_description(self, obj):
        """Return the raw description."""
        return obj.description

    def get_declaration(self, obj):
        """Return the declaration, JSON-dumped when ``object_format`` is "string"."""
        if self.context.get("object_format") != "string":
            return obj.declaration
        return json.dumps(obj.declaration, indent=4)

    def get_html_description(self, obj):
        """Return the description rendered by ``ensure_html``, or "" if empty."""
        description = obj.description
        return ensure_html(description) if len(description) > 0 else ""
# ----------------------------------------------------------
class ContributionModSerializer(serializers.ModelSerializer):
    """Base serializer used to modify a contribution.

    ``Meta.model`` is referenced but not set here, and ``Meta.beat_core_class``
    is None; presumably subclasses set both. ``beat_core_class`` is
    instantiated with the submitted declaration to validate it.
    """

    declaration = beat_fields.JSONField(required=False)
    description = serializers.CharField(required=False, allow_blank=True)

    class Meta:
        fields = ["short_description", "description", "declaration"]
        # Validator class for declarations; expected to be set by subclasses.
        beat_core_class = None

    def validate_description(self, description):
        """Unescape backslash sequences, then check the text is valid reST."""
        if "\\" in description:
            # Was escaped, unescape. ``str.decode("string_escape")`` only
            # existed on Python 2 (it raises AttributeError on Python 3).
            # Going through latin-1 with "backslashreplace" keeps non-ASCII
            # characters intact across the "unicode_escape" round trip.
            try:
                description = description.encode(
                    "latin-1", "backslashreplace"
                ).decode("unicode_escape")
            except UnicodeDecodeError:
                # Malformed escape (e.g. a trailing backslash): keep the text
                # as sent and let the reST validation judge it.
                pass
        validate_restructuredtext(description)
        return description

    def validate_declaration(self, declaration):
        """Validate *declaration* with ``Meta.beat_core_class``.

        A deep copy is handed to the validator so that the original value,
        which is what gets returned, is not modified by it.

        Raises:
            drf_exceptions.ValidationError: with the validator's errors.
        """
        decl = copy.deepcopy(declaration)
        obj = self.Meta.beat_core_class(prefix=settings.PREFIX, data=decl)
        if not obj.valid:
            raise drf_exceptions.ValidationError(obj.errors)
        return declaration

    def update(self, instance, validated_data):
        """Update *instance*, refusing declaration changes once it is frozen.

        Raises:
            drf_exceptions.PermissionDenied: if a declaration is submitted
                while ``instance.modifiable()`` is False.
        """
        declaration = validated_data.get("declaration")
        if declaration is not None and not instance.modifiable():
            raise drf_exceptions.PermissionDenied(
                "The {} isn't modifiable anymore (either shared with someone else, or needed by an attestation)".format(
                    self.Meta.model.__name__.lower()
                )
            )
        return super().update(instance, validated_data)

    def filter_representation(self, representation, instance=None):
        """Filter out fields if given in query parameters.

        With ``?fields=a,b`` on the request, only the listed keys are kept.
        If ``html_description`` is listed, it is computed from the
        instance's description and added.

        Args:
            representation: mapping produced by ``to_representation``.
            instance: object being represented; defaults to ``self.instance``
                (which is not the represented object for ``many=True``).

        Returns:
            The same, possibly filtered, representation.
        """
        request = self.context.get("request")
        if request is None:
            # No request in the context (e.g. internal use): nothing to filter.
            return representation

        fields = request.query_params.get("fields", None)
        if fields is None:
            return representation

        fields = fields.split(",")
        to_remove = [key for key in representation.keys() if key not in fields]
        for key in to_remove:
            representation.pop(key)

        # Retrieve the description in HTML format
        if "html_description" in fields:
            if instance is None:
                instance = self.instance
            description = instance.description
            if len(description) > 0:
                representation["html_description"] = ensure_html(description)
            else:
                representation["html_description"] = ""
        return representation

    def to_representation(self, instance):
        """Serialize *instance*, then apply the ``fields`` query filter."""
        representation = super().to_representation(instance)
        return self.filter_representation(representation, instance)
# ----------------------------------------------------------
class ContributionCreationSerializer(ContributionModSerializer):
    """Validate and create a new contribution: new object, new version, or fork.

    ``fork_of`` and ``previous_version`` are mutually exclusive. Both are
    posted as full names and resolved to model instances in ``validate``.
    """

    fork_of = serializers.CharField(required=False)
    previous_version = serializers.CharField(required=False)
    version = serializers.IntegerField(min_value=1)

    class Meta(ContributionModSerializer.Meta):
        fields = ContributionModSerializer.Meta.fields + [
            "name",
            "previous_version",
            "fork_of",
            "version",
        ]

    def validate_fork_of(self, fork_of):
        """Reject ``fork_of`` when ``previous_version`` was also posted."""
        if "previous_version" in self.initial_data:
            raise serializers.ValidationError(
                "fork_of and previous_version cannot appear together"
            )
        return fork_of

    def validate_previous_version(self, previous_version):
        """Reject ``previous_version`` when ``fork_of`` was also posted."""
        if "fork_of" in self.initial_data:
            raise serializers.ValidationError(
                "previous_version and fork_of cannot appear together"
            )
        return previous_version

    def validate_version(self, version):
        """Require ``previous_version`` for any version above 1."""
        # If version is not one then it's necessarily a new version
        # forks start at one
        if version > 1 and "previous_version" not in self.initial_data:
            # .get(): "name" may be missing; its own field reports that error.
            name = self.initial_data.get("name")
            raise serializers.ValidationError(
                "{} {} version {} incomplete history data posted".format(
                    self.Meta.model.__name__.lower(), name, version
                )
            )
        return version

    def validate(self, data):
        """Cross-field validation.

        Steps:
        - Sanitize the name.
        - Reject duplicates (scoped by author when the model has one).
        - Resolve ``previous_version`` / ``fork_of`` full names to objects
          the context user can access.
        - Check that the version follows ``previous_version`` by exactly 1,
          and that a fork starts at 1.

        Raises:
            serializers.ValidationError: on any of the checks above.
        """
        user = self.context.get("user")
        name = self.Meta.model.sanitize_name(data["name"])
        data["name"] = name
        version = data.get("version")

        kwargs = {
            "name": name,
            "version": version,
        }
        if hasattr(self.Meta.model, "author"):
            kwargs["author"] = user

        if self.Meta.model.objects.filter(**kwargs).exists():
            raise serializers.ValidationError(
                "{} {} version {} already exists".format(
                    self.Meta.model.__name__.lower(), name, version
                )
            )

        previous_version = data.get("previous_version")
        fork_of = data.get("fork_of")

        if previous_version is not None:
            try:
                previous_object = annotate_full_name(self.Meta.model.objects).get(
                    full_name=previous_version
                )
            except self.Meta.model.DoesNotExist:
                raise serializers.ValidationError(
                    "{} '{}' not found".format(
                        self.Meta.model.__name__, previous_version
                    )
                ) from None

            accessibility_infos = previous_object.accessibility_for(user)
            if not accessibility_infos.has_access:
                raise serializers.ValidationError("No access allowed")

            if version - previous_object.version != 1:
                # The trailing space matters: these literals are concatenated.
                raise serializers.ValidationError(
                    "The requested version ({}) for this {} does not match "
                    "the standard increment with {}".format(
                        version, self.Meta.model.__name__, previous_object.version
                    )
                )
            data["previous_version"] = previous_object

        elif fork_of is not None:
            if version > 1:
                raise serializers.ValidationError("A fork starts at 1")

            try:
                forked_of_object = annotate_full_name(self.Meta.model.objects).get(
                    full_name=fork_of
                )
            except self.Meta.model.DoesNotExist:
                raise serializers.ValidationError(
                    "{} '{}' fork origin not found".format(
                        self.Meta.model.__name__, fork_of
                    )
                ) from None

            accessibility_infos = forked_of_object.accessibility_for(user)
            if not accessibility_infos.has_access:
                raise serializers.ValidationError("No access allowed")
            data["fork_of"] = forked_of_object

        return data

    def create(self, validated_data):
        """Create the object through the manager's ``create_object``.

        Raises:
            drf_exceptions.APIException: if ``create_object`` reports an error.
        """
        (db_object, error) = self.Meta.model.objects.create_object(**validated_data)
        if error:
            raise drf_exceptions.APIException(
                "{} '{}' creation failed: {}".format(
                    self.Meta.model.__name__, validated_data["name"], error
                )
            )
        return db_object