You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

461 lines
18 KiB
Python

# Copyright (C) 2021-2023 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import os.path as osp
import functools
import hashlib
from django.utils.functional import SimpleLazyObject
from django.http import Http404, HttpResponseBadRequest, HttpResponseRedirect
from rest_framework import views, serializers
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.permissions import AllowAny
from rest_framework.decorators import api_view, permission_classes
from django.conf import settings
from django.http import HttpResponse
from django.views.decorators.http import etag as django_etag
from rest_framework.response import Response
from dj_rest_auth.registration.views import RegisterView, SocialLoginView
from dj_rest_auth.views import LoginView
from dj_rest_auth.utils import import_callable
from allauth.account import app_settings as allauth_settings
from allauth.account.views import ConfirmEmailView
from allauth.account.utils import has_verified_email, send_email_confirmation
from allauth.socialaccount.models import SocialLogin
from allauth.socialaccount.providers.oauth2.views import OAuth2CallbackView, OAuth2LoginView
from allauth.socialaccount.providers.oauth2.client import OAuth2Client
from allauth.utils import get_request_param
from furl import furl
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, OpenApiParameter, extend_schema, inline_serializer, extend_schema_view
from drf_spectacular.contrib.rest_auth import get_token_serializer_class
from .authentication import Signer
from cvat.apps.iam.serializers import SocialLoginSerializerEx, SocialAuthMethodSerializer
# Social-account adapters are only resolved when social authentication is
# switched on; otherwise the names still exist but are bound to None.
if settings.USE_ALLAUTH_SOCIAL_ACCOUNTS:
    GitHubAdapter = import_callable(settings.SOCIALACCOUNT_GITHUB_ADAPTER)
    GoogleAdapter = import_callable(settings.SOCIALACCOUNT_GOOGLE_ADAPTER)
    AmazonCognitoAdapter = import_callable(settings.SOCIALACCOUNT_AMAZON_COGNITO_ADAPTER)
else:
    GitHubAdapter = None
    GoogleAdapter = None
    AmazonCognitoAdapter = None
def get_context(request):
    """Build the IAM context dictionary for ``request``.

    The organization is selected by the ``org`` query parameter, the
    ``org_id`` query parameter or the ``X-Organization`` HTTP header.

    Returns a dict with the keys:
      * ``privilege``    - the user's group with the lowest index in
                           ``settings.IAM_ROLES`` (or ``None``);
      * ``membership``   - the user's active membership in the selected
                           organization (or ``None``);
      * ``organization`` - the selected organization (or ``None``);
      * ``visibility``   - ``{'organization': <id>}`` when an organization is
                           selected, ``{'organization': None}`` when an empty
                           slug was passed explicitly, otherwise ``None``.

    Raises:
        ValidationError: conflicting organization selectors were supplied or
            ``org_id`` is not an integer.
        NotFound: the requested organization does not exist.
    """
    # Imported lazily, as in the original code (presumably to avoid a circular
    # import at module load time).
    from cvat.apps.organizations.models import Organization, Membership

    role_priority = {role: priority for priority, role in enumerate(settings.IAM_ROLES)}
    groups = sorted(
        request.user.groups.filter(name__in=list(role_priority)),
        key=lambda group: role_priority[group.name],
    )

    org_slug = request.GET.get('org')
    org_id = request.GET.get('org_id')
    org_header = request.headers.get('X-Organization')

    if org_id is not None and (org_slug is not None or org_header is not None):
        raise ValidationError('You cannot specify "org_id" query parameter with '
            '"org" query parameter or "X-Organization" HTTP header at the same time.')
    if org_slug is not None and org_header is not None and org_slug != org_header:
        raise ValidationError('You cannot specify "org" query parameter and '
            '"X-Organization" HTTP header with different values.')
    org_slug = org_slug if org_slug is not None else org_header

    organization = None
    membership = None
    org_filter = None
    try:
        if org_slug:
            organization = Organization.objects.get(slug=org_slug)
        elif org_id:
            try:
                organization_pk = int(org_id)
            except ValueError:
                # A malformed id is a client error, not a server failure.
                raise ValidationError('"org_id" query parameter must be an integer.') from None
            organization = Organization.objects.get(id=organization_pk)
        elif org_slug is not None:
            # An explicitly empty slug restricts visibility to objects
            # that do not belong to any organization.
            org_filter = {'organization': None}
    except Organization.DoesNotExist:
        raise NotFound(f'{org_slug or org_id} organization does not exist.')

    if organization is not None:
        membership = Membership.objects.filter(
            organization=organization, user=request.user).first()
        org_filter = {'organization': organization.id}

    # Inactive memberships grant nothing, so they are treated as absent.
    if membership and not membership.is_active:
        membership = None

    return {
        "privilege": groups[0] if groups else None,
        "membership": membership,
        "organization": organization,
        "visibility": org_filter,
    }
class ContextMiddleware:
    """Middleware that attaches a lazily evaluated IAM context to each request."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        def load_context():
            return get_context(request)

        # The context is resolved on first access rather than here, see
        # https://stackoverflow.com/questions/26240832/django-and-middleware-which-uses-request-user-is-always-anonymous
        request.iam_context = SimpleLazyObject(load_context)
        response = self.get_response(request)
        return response
@extend_schema(tags=['auth'])
@extend_schema_view(post=extend_schema(
    summary='This method signs URL for access to the server',
    description='Signed URL contains a token which authenticates a user on the server. '
                'Signed URL is valid during 30 seconds since signing.',
    request=inline_serializer(
        name='Signing',
        fields={
            'url': serializers.CharField(),
        }
    ),
    responses={'200': OpenApiResponse(response=OpenApiTypes.STR, description='text URL')}))
class SigningView(views.APIView):
    # Signs a URL on behalf of the requesting user and returns the signed URL.

    def post(self, request):
        # Reads `url` from the request body; raises ValidationError when it is
        # missing or empty.
        url = request.data.get('url')
        if not url:
            raise ValidationError('Please provide `url` parameter')

        # Make the URL absolute first so that the signature covers the full address.
        absolute_url = request.build_absolute_uri(url)
        signature = Signer().sign(request.user, absolute_url)

        # The signature travels as an extra query parameter of the URL.
        signed_url = furl(absolute_url).add({Signer.QUERY_PARAM: signature}).url
        return Response(signed_url)
class LoginViewEx(LoginView):
    """
    Check the credentials and return the REST Token
    if the credentials are valid and authenticated.
    If email verification is enabled and the user has the unverified email,
    an email with a confirmation link will be sent.
    Calls Django Auth login method to register User ID
    in Django session framework.
    Accept the following POST parameters: username, email, password
    Return the REST Framework Token Object's key.
    """
    @extend_schema(responses=get_token_serializer_class())
    def post(self, request, *args, **kwargs):
        self.request = request
        self.serializer = self.get_serializer(data=self.request.data)
        try:
            self.serializer.is_valid(raise_exception=True)
        except ValidationError:
            # Validation failed: find out whether the credentials themselves
            # identify a user, so the "unverified email" case can be reported.
            user = self.serializer.get_auth_user(
                self.serializer.data.get('username'),
                self.serializer.data.get('email'),
                self.serializer.data.get('password')
            )

            # Check that user's email is verified.
            # If not, send a verification email.
            if user and not has_verified_email(user):
                send_email_confirmation(request, user)
                # we cannot use redirect to ACCOUNT_EMAIL_VERIFICATION_SENT_REDIRECT_URL here
                # because redirect will make a POST request and we'll get a 404 code
                # (although in the browser request method will be displayed like GET)
                return HttpResponseBadRequest('Unverified email')

            # Either the credentials are wrong or validation failed for another
            # reason. login() needs successfully validated data, so report the
            # original validation error instead of continuing.
            raise

        # Any other exception from validation is left to propagate: swallowing
        # it would only postpone the failure to login() and hide the cause.
        self.login()
        return self.get_response()
class RegisterViewEx(RegisterView):
    # Registration view whose response also reports whether the e-mail address
    # must be verified and, when it need not be, the user's token key.

    def get_response_data(self, user):
        """Return the serialized user plus 'email_verification_required' and 'key'."""
        token_available = (
            allauth_settings.EMAIL_VERIFICATION
            != allauth_settings.EmailVerificationMethod.MANDATORY
        )

        payload = self.get_serializer(user).data
        payload['email_verification_required'] = not token_available
        # The key is only handed out when verification is not mandatory.
        payload['key'] = user.auth_token.key if token_available else None
        return payload
def _etag(etag_func):
"""
Decorator to support conditional retrieval (or change)
for a Django Rest Framework's ViewSet.
It calls Django's original decorator but pass correct request object to it.
Django's original decorator doesn't work with DRF request object.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(obj_self, request, *args, **kwargs):
drf_request = request
wsgi_request = request._request
@django_etag(etag_func=etag_func)
def patched_viewset_method(*_args, **_kwargs):
"""Call original viewset method with correct type of request"""
return func(obj_self, drf_request, *args, **kwargs)
return patched_viewset_method(wsgi_request, *args, **kwargs)
return wrapper
return decorator
class RulesView(views.APIView):
    # Serves the file located at settings.IAM_OPA_BUNDLE_PATH as a tar archive.
    # The endpoint is public: no authentication and no permission checks.
    serializer_class = None
    permission_classes = [AllowAny]
    authentication_classes = []
    iam_organization_field = None

    @staticmethod
    def _get_bundle_path():
        # Location of the bundle file, taken from the Django settings.
        return settings.IAM_OPA_BUNDLE_PATH

    @staticmethod
    def _etag_func(file_path):
        # The ETag is the BLAKE2b hex digest of the file content.
        with open(file_path, 'rb') as f:
            return hashlib.blake2b(f.read()).hexdigest()

    @_etag(lambda _: RulesView._etag_func(RulesView._get_bundle_path()))
    def get(self, request):
        # Read the bundle inside a context manager so the file handle is
        # always closed here, instead of handing an open file to the response.
        with open(self._get_bundle_path(), 'rb') as bundle_file:
            bundle = bundle_file.read()
        return HttpResponse(bundle, content_type='application/x-tar')
class OAuth2CallbackViewEx(OAuth2CallbackView):
    # OAuth2 callback that forwards the provider's answer to the client
    # application via a redirect instead of completing the login itself.

    def dispatch(self, request, *args, **kwargs):
        # Returns a redirect to the cancellation URL, a 400 response when
        # `code` is missing, or a redirect to the client login page.
        # Raises ValidationError for any provider error other than cancellation.
        from urllib.parse import urlencode

        # Distinguish cancel from error
        if (auth_error := request.GET.get('error', None)):
            if auth_error == self.adapter.login_cancelled_error:
                return HttpResponseRedirect(settings.SOCIALACCOUNT_CALLBACK_CANCELLED_URL)
            # unknown error
            raise ValidationError(auth_error)

        code = request.GET.get('code')

        # verify request state
        # (done before the `code` check, exactly as before, so the stashed
        # state is consumed even when the request turns out to be invalid)
        if self.adapter.supports_state:
            state = SocialLogin.verify_and_unstash_state(
                request, get_request_param(request, 'state')
            )
        else:
            state = SocialLogin.unstash_state(request)

        if not code:
            return HttpResponseBadRequest('Parameter code not found in request')

        provider = self.adapter.provider_id.replace('_', '-')

        # `code` comes from the request and the state values from the login
        # flow, so they are URL-encoded: reserved characters such as '&' or '='
        # must not be able to add or override query parameters of the redirect.
        # Missing state values are still rendered as the string "None".
        query = urlencode({
            'provider': provider,
            'code': code,
            'auth_params': state.get('auth_params'),
            'process': state.get('process'),
            'scope': state.get('scope'),
        })
        return HttpResponseRedirect(f'{settings.SOCIAL_APP_LOGIN_REDIRECT_URL}/?{query}')
@extend_schema(
    summary="Redirects to Github authentication page",
    description="Redirects to the Github authentication page. "
                "After successful authentication on the provider side, "
                "a redirect to the callback endpoint is performed",
)
@api_view(["GET"])
@permission_classes([AllowAny])
def github_oauth2_login(*args, **kwargs):
    # Delegates to allauth's OAuth2 login view bound to the GitHub adapter.
    return OAuth2LoginView.adapter_view(GitHubAdapter)(*args, **kwargs)
@extend_schema(
    summary="Checks the authentication response from Github, redirects to the CVAT client if successful.",
    description="Accepts a request from Github with code and state query parameters. "
                "In case of successful authentication on the provider side, it will "
                "redirect to the CVAT client",
    parameters=[
        OpenApiParameter(
            'code',
            description='Returned by github',
            location=OpenApiParameter.QUERY,
            type=OpenApiTypes.STR,
        ),
        OpenApiParameter(
            'state',
            description='Returned by github',
            location=OpenApiParameter.QUERY,
            type=OpenApiTypes.STR,
        ),
    ],
)
@api_view(["GET"])
@permission_classes([AllowAny])
def github_oauth2_callback(*args, **kwargs):
    # Hand the request over to the extended callback view bound to the GitHub adapter.
    callback_view = OAuth2CallbackViewEx.adapter_view(GitHubAdapter)
    return callback_view(*args, **kwargs)
@extend_schema(
    summary="Redirects to Google authentication page",
    description="Redirects to the Google authentication page. "
                "After successful authentication on the provider side, "
                "a redirect to the callback endpoint is performed.",
)
@api_view(["GET"])
@permission_classes([AllowAny])
def google_oauth2_login(*args, **kwargs):
    # Hand the request over to allauth's OAuth2 login view bound to the Google adapter.
    login_view = OAuth2LoginView.adapter_view(GoogleAdapter)
    return login_view(*args, **kwargs)
@extend_schema(
    summary="Redirects to Amazon Cognito authentication page",
    description="Redirects to the Amazon Cognito authentication page. "
                "After successful authentication on the provider side, "
                "a redirect to the callback endpoint is performed.",
)
@api_view(["GET"])
@permission_classes([AllowAny])
def amazon_cognito_oauth2_login(*args, **kwargs):
    # Hand the request over to allauth's OAuth2 login view bound to the
    # Amazon Cognito adapter.
    login_view = OAuth2LoginView.adapter_view(AmazonCognitoAdapter)
    return login_view(*args, **kwargs)
@extend_schema(
    summary="Checks the authentication response from Google, redirects to the CVAT client if successful.",
    description="Accepts a request from Google with code and state query parameters. "
                "In case of successful authentication on the provider side, it will "
                "redirect to the CVAT client",
    parameters=[
        OpenApiParameter(
            'code',
            description='Returned by google',
            location=OpenApiParameter.QUERY,
            type=OpenApiTypes.STR,
        ),
        OpenApiParameter(
            'state',
            description='Returned by google',
            location=OpenApiParameter.QUERY,
            type=OpenApiTypes.STR,
        ),
    ],
)
@api_view(["GET"])
@permission_classes([AllowAny])
def google_oauth2_callback(*args, **kwargs):
    # Hand the request over to the extended callback view bound to the Google adapter.
    callback_view = OAuth2CallbackViewEx.adapter_view(GoogleAdapter)
    return callback_view(*args, **kwargs)
@extend_schema(
    summary="Checks the authentication response from Amazon Cognito, redirects to the CVAT client if successful.",
    description="Accepts a request from Amazon Cognito with code and state query parameters. "
                "In case of successful authentication on the provider side, it will "
                "redirect to the CVAT client",
    parameters=[
        # The descriptions used to say "Returned by google", a copy-paste
        # leftover from the Google callback.
        OpenApiParameter('code', description='Returned by Amazon Cognito',
            location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
        OpenApiParameter('state', description='Returned by Amazon Cognito',
            location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
    ],
)
@api_view(["GET"])
@permission_classes([AllowAny])
def amazon_cognito_oauth2_callback(*args, **kwargs):
    # Delegates to the extended callback view bound to the Amazon Cognito adapter.
    return OAuth2CallbackViewEx.adapter_view(AmazonCognitoAdapter)(*args, **kwargs)
class ConfirmEmailViewEx(ConfirmEmailView):
    # E-mail confirmation view that redirects to
    # settings.INCORRECT_EMAIL_CONFIRMATION_URL instead of answering 404.
    template_name = 'account/email/email_confirmation_signup_message.html'

    def get(self, *args, **kwargs):
        try:
            # With CONFIRM_EMAIL_ON_GET enabled a GET request is handled as a POST.
            if allauth_settings.CONFIRM_EMAIL_ON_GET:
                return self.post(*args, **kwargs)
            return super().get(*args, **kwargs)
        except Http404:
            return HttpResponseRedirect(settings.INCORRECT_EMAIL_CONFIRMATION_URL)
@extend_schema(
    methods=['POST'],
    summary='Method returns an authentication token based on code parameter',
    description="After successful authentication on the provider side, "
                "the provider returns the 'code' parameter used to receive "
                "an authentication token required for CVAT authentication.",
    parameters=[
        OpenApiParameter('auth_params', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
        OpenApiParameter('process', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
        OpenApiParameter('scope', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
    ],
    responses=get_token_serializer_class()
)
class SocialLoginViewEx(SocialLoginView):
    serializer_class = SocialLoginSerializerEx

    def post(self, request, *args, **kwargs):
        # we have to re-implement this method because
        # there is one case not covered by dj_rest_auth but covered by allauth
        # user can be logged in with social account and "unverified" email
        # (e.g. the provider doesn't provide information about email verification)
        self.request = request
        self.serializer = self.get_serializer(data=request.data)
        self.serializer.is_valid(raise_exception=True)

        verification_is_mandatory = (
            allauth_settings.EMAIL_VERIFICATION
            == allauth_settings.EmailVerificationMethod.MANDATORY
        )
        # The verification status is only looked up when it is mandatory.
        if verification_is_mandatory:
            account = self.serializer.validated_data.get('user')
            if not has_verified_email(account):
                return HttpResponseBadRequest('Unverified email')

        self.login()
        return self.get_response()
class GitHubLogin(SocialLoginViewEx):
    # Social login endpoint configured for GitHub.
    # The callback URL is optional in settings and falls back to None.
    callback_url = getattr(settings, 'GITHUB_CALLBACK_URL', None)
    client_class = OAuth2Client
    adapter_class = GitHubAdapter
class GoogleLogin(SocialLoginViewEx):
    # Social login endpoint configured for Google.
    # The callback URL is optional in settings and falls back to None.
    callback_url = getattr(settings, 'GOOGLE_CALLBACK_URL', None)
    client_class = OAuth2Client
    adapter_class = GoogleAdapter
class CognitoLogin(SocialLoginViewEx):
    # Social login endpoint configured for Amazon Cognito.
    # The callback URL comes from the optional AMAZON_COGNITO_REDIRECT_URI
    # setting and falls back to None.
    callback_url = getattr(settings, 'AMAZON_COGNITO_REDIRECT_URI', None)
    client_class = OAuth2Client
    adapter_class = AmazonCognitoAdapter
@extend_schema_view(
    get=extend_schema(
        summary='Method provides a list with integrated social accounts authentication.',
        responses={
            '200': OpenApiResponse(response=inline_serializer(
                name="SocialAuthMethodsSerializer",
                fields={
                    'google': SocialAuthMethodSerializer(),
                    'github': SocialAuthMethodSerializer(),
                    'amazon-cognito': SocialAuthMethodSerializer(),
                }
            )),
        }
    )
)
class SocialAuthMethods(views.APIView):
    # Public endpoint (no authentication, no permission checks) that lists the
    # social authentication providers configured in settings.
    serializer_class = SocialAuthMethodSerializer
    permission_classes = [AllowAny]
    authentication_classes = []
    iam_organization_field = None

    def get(self, request, *args, **kwargs):
        # Returns a mapping {provider-name: {is_enabled, icon, public_name}},
        # with underscores in provider names replaced by dashes.
        # The mapping is empty when social accounts are switched off.
        # Raises ValidationError if an entry does not pass the serializer.
        providers = (
            settings.SOCIALACCOUNT_PROVIDERS
            if settings.USE_ALLAUTH_SOCIAL_ACCOUNTS
            else {}
        )

        response = {}
        for provider, provider_settings in providers.items():
            settings_prefix = f'SOCIAL_AUTH_{provider.upper()}'
            public_id = provider.replace("_", "-")

            # A provider counts as enabled only when both credentials are set.
            is_enabled = bool(
                getattr(settings, f'{settings_prefix}_CLIENT_ID', None)
                and getattr(settings, f'{settings_prefix}_CLIENT_SECRET', None)
            )

            icon = None
            icon_path = osp.join(
                settings.STATIC_ROOT, 'social_authentication', f'social-{public_id}-logo.svg')
            if is_enabled and osp.exists(icon_path):
                # Decode explicitly as UTF-8 so the result does not depend on
                # the locale of the server process.
                with open(icon_path, 'r', encoding='utf-8') as icon_file:
                    icon = icon_file.read()

            serializer = SocialAuthMethodSerializer(data={
                'is_enabled': is_enabled,
                'icon': icon,
                'public_name': provider_settings.get('PUBLIC_NAME', provider.title()),
            })
            serializer.is_valid(raise_exception=True)

            response[public_id] = serializer.validated_data

        return Response(response)