You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
218 lines
7.3 KiB
Python
218 lines
7.3 KiB
Python
# Copyright (C) 2023 CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import string
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
from random import choice, random, sample
|
|
from urllib.parse import parse_qsl, urlparse
|
|
|
|
|
|
class CommonRequestHandlerClass(BaseHTTPRequestHandler, ABC):
|
|
def _set_headers(self):
|
|
self.send_response(406)
|
|
self.send_header("Content-type", "text/html")
|
|
self.end_headers()
|
|
self.wfile.write(f"Unsupported request. Path: {self.path}".encode("utf8"))
|
|
|
|
def get_profile(self, token=None):
|
|
if not token:
|
|
self.send_response(403)
|
|
self.end_headers()
|
|
return
|
|
self.send_response(200)
|
|
self.end_headers()
|
|
|
|
self.wfile.write(json.dumps(self.PROFILE).encode("utf-8"))
|
|
|
|
@abstractmethod
|
|
def authorize(self, query_params):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def generate_access_token(self):
|
|
pass
|
|
|
|
def check_query(self, query_params):
|
|
supported_response_type = "code"
|
|
if not "client_id" in query_params:
|
|
self.send_response(400)
|
|
self.wfile.write("Client id not found in query params".encode("utf8"))
|
|
return
|
|
if not "redirect_uri" in query_params:
|
|
self.send_response(400)
|
|
self.wfile.write("Redirect uri not found".encode("utf8"))
|
|
return
|
|
if query_params.get("response_type", "code") != supported_response_type:
|
|
self.send_response(400)
|
|
self.wfile.write(
|
|
"Only code response type is supported by dummy auth server".encode("utf8")
|
|
)
|
|
return
|
|
|
|
def do_GET(self):
|
|
u = urlparse(self.path)
|
|
if u.path == self.AUTHORIZE_PATH:
|
|
return self.authorize(dict(parse_qsl(u.query)))
|
|
elif u.path == self.PROFILE_PATH:
|
|
token = self.headers.get("Authorization") or dict(parse_qsl(u.query)).get(
|
|
"access_token"
|
|
)
|
|
return self.get_profile(token)
|
|
self._set_headers()
|
|
|
|
def do_POST(self):
|
|
u = urlparse(self.path)
|
|
if u.path == self.TOKEN_PATH:
|
|
return self.generate_access_token()
|
|
self._set_headers()
|
|
|
|
|
|
class GithubRequestHandlerClass(CommonRequestHandlerClass):
|
|
AUTHORIZE_PATH = "/login/oauth/authorize"
|
|
PROFILE_PATH = "/user"
|
|
TOKEN_PATH = "/login/oauth/access_token"
|
|
|
|
CODE_LENGTH = 20
|
|
AUTH_TOKEN_LENGTH = 40
|
|
|
|
LOGIN = "test-user"
|
|
UID = int(random() * 100)
|
|
|
|
# demo profile not including all information returned by github
|
|
PROFILE = {
|
|
"login": LOGIN,
|
|
"id": UID,
|
|
"avatar_url": f"https://avatars.github.com/u/{UID}",
|
|
"url": f"https://api.github.com/users/{LOGIN}",
|
|
"html_url": f"https://github.com/{LOGIN}",
|
|
"type": "User",
|
|
"site_admin": False,
|
|
"name": "Test User",
|
|
"location": "Germany, Munich",
|
|
"email": "github.user@test.com",
|
|
"hireable": None,
|
|
"created_at": str(datetime.now()),
|
|
"updated_at": str(datetime.now()),
|
|
"two_factor_authentication": False,
|
|
}
|
|
|
|
def authorize(self, query_params):
|
|
super().check_query(query_params)
|
|
self.send_response(302)
|
|
redirect_to = query_params["redirect_uri"]
|
|
generated_code = "".join(sample(string.ascii_lowercase + string.digits, self.CODE_LENGTH))
|
|
|
|
# add query params
|
|
new_query = (
|
|
f"?code={generated_code}&state={query_params['state']}&"
|
|
f"scope={query_params['scope']}&promt=none"
|
|
)
|
|
redirect_to += new_query
|
|
self.send_header("Location", redirect_to)
|
|
self.send_header("Content-type", "text/html")
|
|
self.end_headers()
|
|
|
|
def generate_access_token(self):
|
|
self.send_response(200)
|
|
self.send_header("Content-type", "application/x-www-form-urlencoded; charset=utf-8")
|
|
generated_token = "".join(
|
|
sample(string.ascii_letters + string.digits, self.AUTH_TOKEN_LENGTH)
|
|
)
|
|
scope = "read:user,user:email"
|
|
content = f"access_token={generated_token}&scope={scope}&token_type=bearer".encode("utf-8")
|
|
self.end_headers()
|
|
self.wfile.write(content)
|
|
|
|
|
|
class GoogleRequestHandlerClass(CommonRequestHandlerClass):
|
|
AUTHORIZE_PATH = "/o/oauth2/auth"
|
|
PROFILE_PATH = "/oauth2/v1/userinfo"
|
|
TOKEN_PATH = "/o/oauth2/token"
|
|
|
|
CODE_LENGTH = 70 # in real case 256 bytes
|
|
AUTH_TOKEN_LENGTH = 100 # in real case 2048 bytes
|
|
|
|
UID = int(random() * 100)
|
|
|
|
# demo profile not including all information returned by google
|
|
PROFILE = {
|
|
"id": UID,
|
|
"email": "google.user@gmail.com",
|
|
"verified_email": True,
|
|
"name": "Test User",
|
|
"given_name": "Test",
|
|
"family_name": "User",
|
|
"picture": f"https://avatars.google.com/u/{UID}",
|
|
"locale": "en",
|
|
}
|
|
|
|
def authorize(self, query_params):
|
|
super().check_query(query_params)
|
|
self.send_response(302)
|
|
redirect_to = query_params["redirect_uri"]
|
|
symbols = string.ascii_letters + string.digits
|
|
generated_code = "".join([choice(symbols) for i in range(self.CODE_LENGTH)])
|
|
|
|
# add query params
|
|
new_query = (
|
|
f"?code={generated_code}&state={query_params['state']}&"
|
|
f"scope={query_params['scope']}&promt=none"
|
|
)
|
|
redirect_to += new_query
|
|
self.send_header("Location", redirect_to)
|
|
self.send_header("Content-type", "text/html")
|
|
self.end_headers()
|
|
|
|
def generate_access_token(self):
|
|
self.send_response(200)
|
|
self.send_header("Content-type", "application/json; charset=utf-8")
|
|
symbols = string.ascii_letters + string.digits + string.punctuation
|
|
generated_token = "".join([choice(symbols) for i in range(self.AUTH_TOKEN_LENGTH)])
|
|
id_token = "".join([choice(symbols) for i in range(self.AUTH_TOKEN_LENGTH)])
|
|
scope = "https://www.googleapis.com/auth/userinfo.profile openid https://www.googleapis.com/auth/userinfo.email"
|
|
content = {
|
|
"access_token": generated_token,
|
|
"expires_in": 3600, # 1 h
|
|
"scope": scope,
|
|
"token_type": "Bearer",
|
|
"id_token": id_token,
|
|
}
|
|
self.end_headers()
|
|
self.wfile.write(json.dumps(content).encode("utf-8"))
|
|
|
|
|
|
class AuthServer:
|
|
SERVER_HOST = "0.0.0.0"
|
|
|
|
def run(self):
|
|
print(f"Starting dummy authentication server on {self.SERVER_HOST}, {self.SERVER_PORT}")
|
|
HTTPServer((self.SERVER_HOST, self.SERVER_PORT), self.REQUEST_HANDLER_CLASS).serve_forever()
|
|
|
|
|
|
class GoogleAuthServer(AuthServer):
|
|
SERVER_PORT = int(os.environ.get("GOOGLE_SERVER_PORT", "4320"))
|
|
REQUEST_HANDLER_CLASS = GoogleRequestHandlerClass
|
|
|
|
|
|
class GithubAuthServer(AuthServer):
|
|
SERVER_PORT = int(os.environ.get("GITHUB_SERVER_PORT", "4321"))
|
|
REQUEST_HANDLER_CLASS = GithubRequestHandlerClass
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--server", choices=["google", "github"], type=str, default="google")
|
|
server = parser.parse_args().server
|
|
auth_servers = {
|
|
"google": GoogleAuthServer,
|
|
"github": GithubAuthServer,
|
|
}
|
|
server_class = auth_servers[server]
|
|
server_class().run()
|