#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.web module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import copy
import difflib
import simplejson as json
from django.conf import settings
from django.contrib.auth.models import User
from rest_framework import exceptions as drf_exceptions
from rest_framework import serializers
from ..common import fields as beat_fields
from ..common.utils import annotate_full_name
from ..common.utils import ensure_html
from ..common.utils import validate_restructuredtext
from ..team.models import Team
from . import fields as serializer_fields
from .models import Contribution
from .models import Shareable
from .models import Versionable
# ----------------------------------------------------------
class DiffSerializer(serializers.Serializer):
    """Serializer exposing a textual diff between two object declarations."""

    diff = serializers.SerializerMethodField()

    def get_diff(self, obj):
        """Return a line-by-line diff of the two declarations held by ``obj``.

        ``obj`` is indexed with the keys ``"object1"`` and ``"object2"``; each
        entry must expose a ``declaration`` attribute, which is dumped as JSON
        indented by four spaces before being compared.

        The ``?`` hint lines emitted by :func:`difflib.ndiff` are dropped and
        the remaining lines are joined with newlines.
        """
        before, after = (
            json.dumps(obj[key].declaration, indent=4).splitlines()
            for key in ("object1", "object2")
        )
        kept_lines = [
            line for line in difflib.ndiff(before, after) if not line.startswith("?")
        ]
        return "\n".join(kept_lines)
# ----------------------------------------------------------
class CheckNameSerializer(serializers.Serializer):
    """Serializer reporting whether a contribution name is already in use.

    The serializer context is read for a ``"model"`` entry (the model class
    that is queried) and a ``"user"`` entry (matched against ``author``).
    """

    name = serializers.CharField()
    used = serializers.SerializerMethodField()

    def validate_name(self, name):
        """Return ``name`` after passing it through ``Contribution.sanitize_name``."""
        sanitized = Contribution.sanitize_name(name)
        return sanitized

    def get_used(self, obj):
        """Tell whether the context user already authored an object with this name."""
        candidate = obj.get("name")
        context = self.context
        matches = context.get("model").objects.filter(
            author=context.get("user"), name=candidate
        )
        return matches.exists()

    def create(self, validated_data):
        """Nothing is persisted: the validated data is handed back untouched."""
        return validated_data
# ----------------------------------------------------------
class SharingSerializer(serializers.Serializer):
    """Validate the users and teams an object is about to be shared with.

    The serializer context is read for a ``"user"`` entry, used as the default
    team owner when a team is given without an ``owner/`` prefix.
    """

    users = serializer_fields.StringListField(required=False)
    teams = serializer_fields.StringListField(required=False)

    def validate_users(self, users):
        """Ensure every entry of ``users`` is the username of an existing account.

        Returns:
            The ``users`` list, unchanged.

        Raises:
            serializers.ValidationError: if at least one username is unknown.
        """
        # Evaluate the query once; the set is used for both the size check
        # and the membership tests below.
        known_usernames = set(
            User.objects.filter(username__in=users).values_list("username", flat=True)
        )

        if len(users) == len(known_usernames):
            return users

        # A size mismatch may also be caused by the same known username being
        # listed twice, so only fail when a name is really unknown (indexing
        # an empty list here used to raise IndexError).
        unknown_users = [
            username
            for username in dict.fromkeys(users)  # de-duplicate, keep order
            if username not in known_usernames
        ]

        if len(unknown_users) > 1:
            raise serializers.ValidationError(
                ["Unknown usernames: " + ", ".join(unknown_users)]
            )
        if unknown_users:
            raise serializers.ValidationError(
                ["Unknown username: " + unknown_users[0]]
            )

        return users

    def validate_teams(self, teams):
        """Ensure every entry of ``teams`` designates an existing team.

        A team is written either ``owner/name`` or just ``name``; in the
        latter case the owner defaults to the user found in the serializer
        context. Entries containing more than one ``/`` are reported as
        unknown.

        Returns:
            The ``teams`` list, unchanged.

        Raises:
            serializers.ValidationError: if at least one team is unknown.
        """
        unknown_teams = []
        user = self.context.get("user")

        for team_name in teams:
            parts = team_name.split("/")

            if len(parts) > 2:
                unknown_teams.append(team_name)
                continue

            if len(parts) == 1:
                # NOTE(review): assumes the context always carries a "user"
                # when unqualified team names are posted -- confirm.
                parts = [user.username, team_name]

            owner_name, name = parts
            if not Team.objects.filter(owner__username=owner_name, name=name).exists():
                unknown_teams.append(team_name)

        if len(unknown_teams) > 1:
            raise serializers.ValidationError(
                "Unknown teams: " + ", ".join(unknown_teams)
            )
        if unknown_teams:
            raise serializers.ValidationError("Unknown team: " + unknown_teams[0])

        return teams
# ----------------------------------------------------------
class DynamicFieldsSerializer(serializers.ModelSerializer):
    """Model serializer whose set of fields can be narrowed per instance.

    An optional ``fields`` keyword argument lists the field names to keep;
    when it is absent, ``Meta.default_fields`` is used instead.
    """

    def __init__(self, *args, **kwargs):
        # The parent constructor must not receive the ``fields`` keyword.
        requested = kwargs.pop("fields", self.Meta.default_fields)

        super().__init__(*args, **kwargs)

        # Remove every declared field that was not requested.
        keep = set(requested)
        declared = set(self.fields.keys())
        for unwanted in declared.difference(keep):
            self.fields.pop(unwanted)
# ----------------------------------------------------------
class ShareableSerializer(DynamicFieldsSerializer):
    """Serializer adding sharing-related information to an object.

    The current user is read from the ``"user"`` entry of the serializer
    context.
    """

    accessibility = serializers.SerializerMethodField()
    sharing = serializers.SerializerMethodField()
    modifiable = serializers.BooleanField()
    deletable = serializers.BooleanField()
    is_owner = serializers.SerializerMethodField()

    def get_accessibility(self, obj):
        """Map ``obj.sharing`` to ``"public"``, ``"confidential"`` or ``"private"``."""
        status = obj.sharing
        if status == Versionable.PUBLIC:
            return "public"
        if status == Versionable.SHARED or status == Versionable.USABLE:
            return "confidential"
        return "private"

    def get_sharing(self, obj):
        """Return the sharing details of ``obj``, or ``None`` for non-authors.

        For the author, the result always contains ``"status"``; the
        ``"shared_with"`` and ``"shared_with_team"`` lists are only added when
        they are not empty.
        """
        viewer = self.context.get("user")

        # Sharing details are only disclosed to the author of the object.
        if not (hasattr(obj, "author") and viewer == obj.author):
            return None

        details = {"status": obj.get_sharing_display().lower()}

        if obj.shared_with.count() > 0:
            details["shared_with"] = [
                account.username for account in obj.shared_with.all()
            ]

        if obj.shared_with_team.count() > 0:
            details["shared_with_team"] = [
                group.fullname() for group in obj.shared_with_team.all()
            ]

        return details

    def get_is_owner(self, obj):
        """Tell whether the context user is the author of ``obj``."""
        if not hasattr(obj, "author"):
            return False
        return obj.author == self.context.get("user")
# ----------------------------------------------------------
class VersionableSerializer(ShareableSerializer):
    """Serializer adding versioning information (fork, versions, history).

    The current user is read from the ``"user"`` entry of the serializer
    context. ``Meta.model`` must be provided by concrete subclasses.
    """

    name = serializers.CharField(source="fullname")
    fork_of = serializers.SerializerMethodField()
    last_version = serializers.SerializerMethodField()
    previous_version = serializers.SerializerMethodField()
    history = serializers.SerializerMethodField()
    short_description = serializers.CharField(max_length=100)
    # No ``get_code`` is defined in this class.
    # NOTE(review): presumably implemented by subclasses -- confirm.
    code = serializers.SerializerMethodField()

    def get_fork_of(self, obj):
        """Return the full name of the fork origin, if the user may access it.

        ``None`` is returned when ``obj`` is not a fork or when the origin is
        not accessible to the context user.
        """
        if not (obj.fork_of):
            return None

        accessibility_infos = obj.fork_of.accessibility_for(self.context.get("user"))
        return obj.fork_of.fullname() if accessibility_infos[0] else None

    def get_last_version(self, obj):
        """Return what ``Meta.model.objects.is_last_version`` reports for ``obj``."""
        return self.Meta.model.objects.is_last_version(obj)

    def get_previous_version(self, obj):
        """Return the full name of the closest accessible previous version.

        The chain of ``previous_version`` links is followed until a version
        accessible to the context user is found. ``None`` is returned when
        there is no previous version or none of them is accessible.
        """
        if not (obj.previous_version):
            return None

        user = self.context.get("user")

        ancestor = obj.previous_version
        while ancestor is not None:
            accessibility_infos = ancestor.accessibility_for(user)
            if accessibility_infos[0]:
                # Report the version that was actually found accessible, not
                # the direct predecessor (which may be inaccessible).
                return ancestor.fullname()
            ancestor = ancestor.previous_version

        return None

    def get_history(self, obj):
        """Return ``obj.api_history`` computed for the context user."""
        return obj.api_history(self.context.get("user"))
# ----------------------------------------------------------
class ContributionSerializer(VersionableSerializer):
    """Serializer exposing the description and declaration of a contribution."""

    description = serializers.SerializerMethodField()
    declaration = serializers.SerializerMethodField()
    html_description = serializers.SerializerMethodField()

    def get_description(self, obj):
        """Return the description exactly as stored on ``obj``."""
        return obj.description

    def get_declaration(self, obj):
        """Return the declaration of ``obj``.

        When the context entry ``"object_format"`` equals ``"string"``, the
        declaration is dumped as JSON indented by four spaces; otherwise it
        is returned as is.
        """
        wants_string = self.context.get("object_format") == "string"
        if not wants_string:
            return obj.declaration
        return json.dumps(obj.declaration, indent=4)

    def get_html_description(self, obj):
        """Return the description passed through ``ensure_html``.

        An empty description yields an empty string.
        """
        text = obj.description
        return ensure_html(text) if len(text) > 0 else ""
# ----------------------------------------------------------
class ContributionModSerializer(serializers.ModelSerializer):
    """Base serializer for modifying a contribution.

    Concrete subclasses must provide ``Meta.model`` and
    ``Meta.beat_core_class`` (used to validate declarations). The serializer
    context must contain a ``"request"`` entry.
    """

    declaration = beat_fields.JSONField(required=False)
    description = serializers.CharField(required=False, allow_blank=True)

    def validate_description(self, description):
        """Unescape ``description`` if needed and check it with
        ``validate_restructuredtext``.

        Returns:
            The (possibly unescaped) description.

        Raises:
            serializers.ValidationError: if the description contains a
                malformed escape sequence.
        """
        if "\\" in description:  # was escaped, unescape
            # ``str.decode("string_escape")`` only exists on Python 2. On
            # Python 3 the text goes through bytes: non latin-1 characters
            # are turned into ``\uXXXX`` escapes by ``backslashreplace`` and
            # restored by the ``unicode_escape`` decoder, together with the
            # escapes that were already present in the text.
            try:
                description = description.encode(
                    "latin-1", "backslashreplace"
                ).decode("unicode_escape")
            except UnicodeDecodeError as error:
                raise serializers.ValidationError(
                    "Invalid escape sequence in description: {}".format(error.reason)
                ) from error

        validate_restructuredtext(description)
        return description

    def validate_declaration(self, declaration):
        """Validate ``declaration`` by building a ``Meta.beat_core_class`` object.

        Returns:
            The declaration, unchanged.

        Raises:
            drf_exceptions.ValidationError: with the errors reported by the
                object when it is not valid.
        """
        # The object is built from a deep copy, so the declaration that is
        # returned is the one that was received.
        decl = copy.deepcopy(declaration)
        obj = self.Meta.beat_core_class(prefix=settings.PREFIX, data=decl)
        if not obj.valid:
            raise drf_exceptions.ValidationError(obj.errors)
        return declaration

    def update(self, instance, validated_data):
        """Update ``instance``, refusing declaration changes when it is not
        modifiable.

        Raises:
            drf_exceptions.PermissionDenied: if a declaration is provided
                while ``instance.modifiable()`` is false.
        """
        declaration = validated_data.get("declaration")
        if declaration is not None and not instance.modifiable():
            raise drf_exceptions.PermissionDenied(
                "The {} isn't modifiable anymore (either shared with someone else, or needed by an attestation)".format(
                    self.Meta.model.__name__.lower()
                )
            )
        return super().update(instance, validated_data)

    def filter_representation(self, representation):
        """Filter out fields if given in query parameters.

        When the request carries a ``fields`` query parameter (a
        comma-separated list of names), every other key is removed from
        ``representation``. If ``html_description`` is requested, it is
        computed from the description of ``self.instance``.

        Returns:
            The ``representation`` mapping, modified in place.
        """
        request = self.context["request"]
        fields = request.query_params.get("fields", None)

        if fields is not None:
            fields = fields.split(",")
            to_remove = [key for key in representation.keys() if key not in fields]
            for key in to_remove:
                representation.pop(key)

            # Retrieve the description in HTML format
            if "html_description" in fields:
                description = self.instance.description
                if len(description) > 0:
                    representation["html_description"] = ensure_html(description)
                else:
                    representation["html_description"] = ""

        return representation

    def to_representation(self, instance):
        """Return the default representation, filtered by
        ``filter_representation``."""
        representation = super().to_representation(instance)
        return self.filter_representation(representation)
# ----------------------------------------------------------
class ContributionCreationSerializer(ContributionModSerializer):
    """Serializer creating a contribution: new object, new version or fork.

    Concrete subclasses must provide ``Meta.model``. The current user is read
    from the ``"user"`` entry of the serializer context.
    """

    fork_of = serializers.CharField(required=False)
    previous_version = serializers.CharField(required=False)
    version = serializers.IntegerField(min_value=1)

    def validate_fork_of(self, fork_of):
        """Reject ``fork_of`` when ``previous_version`` was posted as well."""
        if "previous_version" in self.initial_data:
            raise serializers.ValidationError(
                "fork_of and previous_version cannot appear together"
            )
        return fork_of

    def validate_previous_version(self, previous_version):
        """Reject ``previous_version`` when ``fork_of`` was posted as well."""
        if "fork_of" in self.initial_data:
            raise serializers.ValidationError(
                "previous_version and fork_of cannot appear together"
            )
        return previous_version

    def validate_version(self, version):
        """Require ``previous_version`` to be posted for any version above 1."""
        # If version is not one then it's necessarily a new version
        # forks start at one
        if version > 1 and "previous_version" not in self.initial_data:
            # ``name`` may be missing from the payload; that must not turn
            # this validation error into a KeyError.
            name = self.initial_data.get("name", "<unnamed>")
            raise serializers.ValidationError(
                "{} {} version {} incomplete history data posted".format(
                    self.Meta.model.__name__.lower(), name, version
                )
            )
        return version

    def _get_accessible_object(self, full_name, user, not_found_message):
        """Return the object called ``full_name`` if ``user`` can access it.

        ``not_found_message`` is a format string receiving the model class
        name and ``full_name``.

        Raises:
            serializers.ValidationError: if the object does not exist or is
                not accessible to ``user``.
        """
        model = self.Meta.model
        try:
            found = annotate_full_name(model.objects).get(full_name=full_name)
        except model.DoesNotExist:
            raise serializers.ValidationError(
                not_found_message.format(model.__name__, full_name)
            )

        if not found.accessibility_for(user).has_access:
            raise serializers.ValidationError("No access allowed")

        return found

    def validate(self, data):
        """Cross-field validation of a creation request.

        The name is sanitized, the (name, version[, author]) combination must
        not exist yet, and ``previous_version`` / ``fork_of`` are replaced in
        ``data`` by the objects they designate.

        Returns:
            The ``data`` mapping, updated in place.

        Raises:
            serializers.ValidationError: if the object already exists, if the
                referenced object is missing or inaccessible, or if the
                version number is inconsistent.
        """
        user = self.context.get("user")
        model = self.Meta.model

        name = model.sanitize_name(data["name"])
        data["name"] = name
        version = data.get("version")

        # The author only takes part in the uniqueness check for models that
        # have an ``author`` attribute.
        lookup = {"name": name, "version": version}
        if hasattr(model, "author"):
            lookup["author"] = user

        if model.objects.filter(**lookup).exists():
            raise serializers.ValidationError(
                "{} {} version {} already exists".format(
                    model.__name__.lower(), name, version
                )
            )

        previous_version = data.get("previous_version")
        fork_of = data.get("fork_of")

        if previous_version is not None:
            previous_object = self._get_accessible_object(
                previous_version, user, "{} '{}' not found"
            )

            if version - previous_object.version != 1:
                raise serializers.ValidationError(
                    "The requested version ({}) for this {} does not match "
                    "the standard increment with {}".format(
                        version, model.__name__, previous_object.version
                    )
                )
            data["previous_version"] = previous_object

        elif fork_of is not None:
            if version > 1:
                raise serializers.ValidationError("A fork starts at 1")

            data["fork_of"] = self._get_accessible_object(
                fork_of, user, "{} '{}' fork origin not found"
            )

        return data

    def create(self, validated_data):
        """Create the object through ``Meta.model.objects.create_object``.

        Raises:
            drf_exceptions.APIException: if ``create_object`` reports an error.
        """
        (db_object, error) = self.Meta.model.objects.create_object(**validated_data)
        if error:
            raise drf_exceptions.APIException(
                "{} '{}' creation failed: {}".format(
                    self.Meta.model.__name__, validated_data["name"], error
                )
            )
        return db_object