summaryrefslogtreecommitdiffstats
path: root/api/authlib_sqla.py
diff options
context:
space:
mode:
Diffstat (limited to 'api/authlib_sqla.py')
-rw-r--r--api/authlib_sqla.py335
1 files changed, 335 insertions, 0 deletions
diff --git a/api/authlib_sqla.py b/api/authlib_sqla.py
new file mode 100644
index 00000000..4b3300f2
--- /dev/null
+++ b/api/authlib_sqla.py
@@ -0,0 +1,335 @@
+import time
+import json
+from sqlalchemy import Column, String, Boolean, Text, Integer
+from sqlalchemy.ext.hybrid import hybrid_property
+from authlib.oauth2.rfc6749 import (
+ ClientMixin,
+ TokenMixin,
+ AuthorizationCodeMixin,
+)
+from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope
+from authlib.oidc.core import AuthorizationCodeMixin as OIDCCodeMixin
+
+
+class OAuth2ClientMixin(ClientMixin):
+ client_id = Column(String(48), index=True)
+ client_secret = Column(String(120))
+ issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
+ expires_at = Column(Integer, nullable=False, default=0)
+
+ redirect_uri = Column(Text)
+ token_endpoint_auth_method = Column(String(48), default="client_secret_basic")
+ grant_type = Column(Text, nullable=False, default="")
+ response_type = Column(Text, nullable=False, default="")
+ scope = Column(Text, nullable=False, default="")
+
+ client_name = Column(String(100))
+ client_uri = Column(Text)
+ logo_uri = Column(Text)
+ contact = Column(Text)
+ tos_uri = Column(Text)
+ policy_uri = Column(Text)
+ jwks_uri = Column(Text)
+ jwks_text = Column(Text)
+ i18n_metadata = Column(Text)
+
+ software_id = Column(String(36))
+ software_version = Column(String(48))
+
+ def __repr__(self):
+ return "<Client: {}>".format(self.client_id)
+
+ @hybrid_property
+ def redirect_uris(self):
+ if self.redirect_uri:
+ return self.redirect_uri.splitlines()
+ return []
+
+ @redirect_uris.setter
+ def redirect_uris(self, value):
+ self.redirect_uri = "\n".join(value)
+
+ @hybrid_property
+ def grant_types(self):
+ if self.grant_type:
+ return self.grant_type.splitlines()
+ return []
+
+ @grant_types.setter
+ def grant_types(self, value):
+ self.grant_type = "\n".join(value)
+
+ @hybrid_property
+ def response_types(self):
+ if self.response_type:
+ return self.response_type.splitlines()
+ return []
+
+ @response_types.setter
+ def response_types(self, value):
+ self.response_type = "\n".join(value)
+
+ @hybrid_property
+ def contacts(self):
+ if self.contact:
+ return json.loads(self.contact)
+ return []
+
+ @contacts.setter
+ def contacts(self, value):
+ self.contact = json.dumps(value)
+
+ @hybrid_property
+ def jwks(self):
+ if self.jwks_text:
+ return json.loads(self.jwks_text)
+ return None
+
+ @jwks.setter
+ def jwks(self, value):
+ self.jwks_text = json.dumps(value)
+
+ @hybrid_property
+ def client_metadata(self):
+ """Implementation for Client Metadata in OAuth 2.0 Dynamic Client
+ Registration Protocol via `Section 2`_.
+
+ .. _`Section 2`: https://tools.ietf.org/html/rfc7591#section-2
+ """
+ keys = [
+ "redirect_uris",
+ "token_endpoint_auth_method",
+ "grant_types",
+ "response_types",
+ "client_name",
+ "client_uri",
+ "logo_uri",
+ "scope",
+ "contacts",
+ "tos_uri",
+ "policy_uri",
+ "jwks_uri",
+ "jwks",
+ ]
+ metadata = {k: getattr(self, k) for k in keys}
+ if self.i18n_metadata:
+ metadata.update(json.loads(self.i18n_metadata))
+ return metadata
+
+ @client_metadata.setter
+ def client_metadata(self, value):
+ i18n_metadata = {}
+ for k in value:
+ if hasattr(self, k):
+ setattr(self, k, value[k])
+ elif "#" in k:
+ i18n_metadata[k] = value[k]
+
+ self.i18n_metadata = json.dumps(i18n_metadata)
+
+ @property
+ def client_info(self):
+ """Implementation for Client Info in OAuth 2.0 Dynamic Client
+ Registration Protocol via `Section 3.2.1`_.
+
+ .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
+ """
+ return dict(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ client_id_issued_at=self.issued_at,
+ client_secret_expires_at=self.expires_at,
+ )
+
+ def get_client_id(self):
+ return self.client_id
+
+ def get_default_redirect_uri(self):
+ if self.redirect_uris:
+ return self.redirect_uris[0]
+
+ def get_allowed_scope(self, scope):
+ if not scope:
+ return ""
+ allowed = set(self.scope.split())
+ scopes = scope_to_list(scope)
+ return list_to_scope([s for s in scopes if s in allowed])
+
+ def check_redirect_uri(self, redirect_uri):
+ return redirect_uri in self.redirect_uris
+
+ def has_client_secret(self):
+ return bool(self.client_secret)
+
+ def check_client_secret(self, client_secret):
+ return self.client_secret == client_secret
+
+ def check_token_endpoint_auth_method(self, method):
+ return self.token_endpoint_auth_method == method
+
+ def check_response_type(self, response_type):
+ if self.response_type:
+ return response_type in self.response_types
+ return False
+
+ def check_grant_type(self, grant_type):
+ if self.grant_type:
+ return grant_type in self.grant_types
+ return False
+
+
+class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
+ code = Column(String(120), unique=True, nullable=False)
+ client_id = Column(String(48))
+ redirect_uri = Column(Text, default="")
+ response_type = Column(Text, default="")
+ scope = Column(Text, default="")
+ auth_time = Column(Integer, nullable=False, default=lambda: int(time.time()))
+
+ def is_expired(self):
+ return self.auth_time + 300 < time.time()
+
+ def get_redirect_uri(self):
+ return self.redirect_uri
+
+ def get_scope(self):
+ return self.scope
+
+ def get_auth_time(self):
+ return self.auth_time
+
+
+class OIDCAuthorizationCodeMixin(OAuth2AuthorizationCodeMixin, OIDCCodeMixin):
+ nonce = Column(Text)
+
+ def get_nonce(self):
+ return self.nonce
+
+
+class OAuth2TokenMixin(TokenMixin):
+ client_id = Column(String(48))
+ token_type = Column(String(40))
+ access_token = Column(String(255), unique=True, nullable=False)
+ refresh_token = Column(String(255), index=True)
+ scope = Column(Text, default="")
+ revoked = Column(Boolean, default=False)
+ issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
+ expires_in = Column(Integer, nullable=False, default=0)
+
+ def get_client_id(self):
+ return self.client_id
+
+ def get_scope(self):
+ return self.scope
+
+ def get_expires_in(self):
+ return self.expires_in
+
+ def get_expires_at(self):
+ return self.issued_at + self.expires_in
+
+
+def create_query_client_func(session, client_model):
+ """Create an ``query_client`` function that can be used in authorization
+ server.
+
+ :param session: SQLAlchemy session
+ :param client_model: Client model class
+ """
+
+ def query_client(client_id):
+ q = session.query(client_model)
+ return q.filter_by(client_id=client_id).first()
+
+ return query_client
+
+
+def create_save_token_func(session, token_model):
+ """Create an ``save_token`` function that can be used in authorization
+ server.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+
+ def save_token(token, request):
+ if request.user:
+ user_id = request.user.get_user_id()
+ else:
+ user_id = None
+ client = request.client
+ item = token_model(client_id=client.client_id, user_id=user_id, **token)
+ session.add(item)
+ session.commit()
+
+ return save_token
+
+
+def create_query_token_func(session, token_model):
+ """Create an ``query_token`` function for revocation, introspection
+ token endpoints.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+
+ def query_token(token, token_type_hint, client):
+ q = session.query(token_model)
+ q = q.filter_by(client_id=client.client_id, revoked=False)
+ if token_type_hint == "access_token":
+ return q.filter_by(access_token=token).first()
+ elif token_type_hint == "refresh_token":
+ return q.filter_by(refresh_token=token).first()
+ # without token_type_hint
+ item = q.filter_by(access_token=token).first()
+ if item:
+ return item
+ return q.filter_by(refresh_token=token).first()
+
+ return query_token
+
+
+def create_revocation_endpoint(session, token_model):
+ """Create a revocation endpoint class with SQLAlchemy session
+ and token model.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+ from authlib.oauth2.rfc7009 import RevocationEndpoint
+
+ query_token = create_query_token_func(session, token_model)
+
+ class _RevocationEndpoint(RevocationEndpoint):
+ def query_token(self, token, token_type_hint, client):
+ return query_token(token, token_type_hint, client)
+
+ def revoke_token(self, token):
+ token.revoked = True
+ session.add(token)
+ session.commit()
+
+ return _RevocationEndpoint
+
+
+def create_bearer_token_validator(session, token_model):
+ """Create an bearer token validator class with SQLAlchemy session
+ and token model.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+ from authlib.oauth2.rfc6750 import BearerTokenValidator
+
+ class _BearerTokenValidator(BearerTokenValidator):
+ def authenticate_token(self, token_string):
+ q = session.query(token_model)
+ return q.filter_by(access_token=token_string).first()
+
+ def request_invalid(self, request):
+ return False
+
+ def token_revoked(self, token):
+ return token.revoked
+
+ return _BearerTokenValidator