diff options
author | Dashie <dashie@sigpipe.me> | 2021-02-25 11:05:12 +0100 |
---|---|---|
committer | Dashie <dashie@sigpipe.me> | 2021-02-25 11:05:12 +0100 |
commit | a2f864ff2e49dd198f5eadd147af360841644518 (patch) | |
tree | 3973e763c4f3b4aa7c0412765a83ec7aa0225346 | |
parent | 37e0f874a1c6a108a8940329e2e1e534d8d328b9 (diff) |
Import Authlib SQLAlchemy backend from before integrations split; Fix migrations
-rw-r--r-- | api/authlib_sqla.py | 335 | ||||
-rw-r--r-- | api/controllers/api/v1/auth.py | 23 | ||||
-rw-r--r-- | api/migrations/versions/67_7df5c87e5fef_.py | 98 | ||||
-rw-r--r-- | api/migrations/versions/68_dff4edfb26b6_.py | 41 | ||||
-rw-r--r-- | api/migrations/versions/f1993296be9e_.py | 35 | ||||
-rw-r--r-- | api/models.py | 2 |
6 files changed, 381 insertions, 153 deletions
diff --git a/api/authlib_sqla.py b/api/authlib_sqla.py new file mode 100644 index 00000000..4b3300f2 --- /dev/null +++ b/api/authlib_sqla.py @@ -0,0 +1,335 @@ +import time +import json +from sqlalchemy import Column, String, Boolean, Text, Integer +from sqlalchemy.ext.hybrid import hybrid_property +from authlib.oauth2.rfc6749 import ( + ClientMixin, + TokenMixin, + AuthorizationCodeMixin, +) +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oidc.core import AuthorizationCodeMixin as OIDCCodeMixin + + +class OAuth2ClientMixin(ClientMixin): + client_id = Column(String(48), index=True) + client_secret = Column(String(120)) + issued_at = Column(Integer, nullable=False, default=lambda: int(time.time())) + expires_at = Column(Integer, nullable=False, default=0) + + redirect_uri = Column(Text) + token_endpoint_auth_method = Column(String(48), default="client_secret_basic") + grant_type = Column(Text, nullable=False, default="") + response_type = Column(Text, nullable=False, default="") + scope = Column(Text, nullable=False, default="") + + client_name = Column(String(100)) + client_uri = Column(Text) + logo_uri = Column(Text) + contact = Column(Text) + tos_uri = Column(Text) + policy_uri = Column(Text) + jwks_uri = Column(Text) + jwks_text = Column(Text) + i18n_metadata = Column(Text) + + software_id = Column(String(36)) + software_version = Column(String(48)) + + def __repr__(self): + return "<Client: {}>".format(self.client_id) + + @hybrid_property + def redirect_uris(self): + if self.redirect_uri: + return self.redirect_uri.splitlines() + return [] + + @redirect_uris.setter + def redirect_uris(self, value): + self.redirect_uri = "\n".join(value) + + @hybrid_property + def grant_types(self): + if self.grant_type: + return self.grant_type.splitlines() + return [] + + @grant_types.setter + def grant_types(self, value): + self.grant_type = "\n".join(value) + + @hybrid_property + def response_types(self): + if self.response_type: + return self.response_type.splitlines() + return [] + + @response_types.setter + def response_types(self, value): + self.response_type = "\n".join(value) + + @hybrid_property + def contacts(self): + if self.contact: + return json.loads(self.contact) + return [] + + @contacts.setter + def contacts(self, value): + self.contact = json.dumps(value) + + @hybrid_property + def jwks(self): + if self.jwks_text: + return json.loads(self.jwks_text) + return None + + @jwks.setter + def jwks(self, value): + self.jwks_text = json.dumps(value) + + @hybrid_property + def client_metadata(self): + """Implementation for Client Metadata in OAuth 2.0 Dynamic Client + Registration Protocol via `Section 2`_. + + .. _`Section 2`: https://tools.ietf.org/html/rfc7591#section-2 + """ + keys = [ + "redirect_uris", + "token_endpoint_auth_method", + "grant_types", + "response_types", + "client_name", + "client_uri", + "logo_uri", + "scope", + "contacts", + "tos_uri", + "policy_uri", + "jwks_uri", + "jwks", + ] + metadata = {k: getattr(self, k) for k in keys} + if self.i18n_metadata: + metadata.update(json.loads(self.i18n_metadata)) + return metadata + + @client_metadata.setter + def client_metadata(self, value): + i18n_metadata = {} + for k in value: + if hasattr(self, k): + setattr(self, k, value[k]) + elif "#" in k: + i18n_metadata[k] = value[k] + + self.i18n_metadata = json.dumps(i18n_metadata) + + @property + def client_info(self): + """Implementation for Client Info in OAuth 2.0 Dynamic Client + Registration Protocol via `Section 3.2.1`_. + + .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1 + """ + return dict( + client_id=self.client_id, + client_secret=self.client_secret, + client_id_issued_at=self.issued_at, + client_secret_expires_at=self.expires_at, + ) + + def get_client_id(self): + return self.client_id + + def get_default_redirect_uri(self): + if self.redirect_uris: + return self.redirect_uris[0] + + def get_allowed_scope(self, scope): + if not scope: + return "" + allowed = set(self.scope.split()) + scopes = scope_to_list(scope) + return list_to_scope([s for s in scopes if s in allowed]) + + def check_redirect_uri(self, redirect_uri): + return redirect_uri in self.redirect_uris + + def has_client_secret(self): + return bool(self.client_secret) + + def check_client_secret(self, client_secret): + return self.client_secret == client_secret + + def check_token_endpoint_auth_method(self, method): + return self.token_endpoint_auth_method == method + + def check_response_type(self, response_type): + if self.response_type: + return response_type in self.response_types + return False + + def check_grant_type(self, grant_type): + if self.grant_type: + return grant_type in self.grant_types + return False + + +class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin): + code = Column(String(120), unique=True, nullable=False) + client_id = Column(String(48)) + redirect_uri = Column(Text, default="") + response_type = Column(Text, default="") + scope = Column(Text, default="") + auth_time = Column(Integer, nullable=False, default=lambda: int(time.time())) + + def is_expired(self): + return self.auth_time + 300 < time.time() + + def get_redirect_uri(self): + return self.redirect_uri + + def get_scope(self): + return self.scope + + def get_auth_time(self): + return self.auth_time + + +class OIDCAuthorizationCodeMixin(OAuth2AuthorizationCodeMixin, OIDCCodeMixin): + nonce = Column(Text) + + def get_nonce(self): + return self.nonce + + +class OAuth2TokenMixin(TokenMixin): + client_id = Column(String(48)) + token_type = Column(String(40)) + access_token = Column(String(255), unique=True, nullable=False) + refresh_token = Column(String(255), index=True) + scope = Column(Text, default="") + revoked = Column(Boolean, default=False) + issued_at = Column(Integer, nullable=False, default=lambda: int(time.time())) + expires_in = Column(Integer, nullable=False, default=0) + + def get_client_id(self): + return self.client_id + + def get_scope(self): + return self.scope + + def get_expires_in(self): + return self.expires_in + + def get_expires_at(self): + return self.issued_at + self.expires_in + + +def create_query_client_func(session, client_model): + """Create an ``query_client`` function that can be used in authorization + server. + + :param session: SQLAlchemy session + :param client_model: Client model class + """ + + def query_client(client_id): + q = session.query(client_model) + return q.filter_by(client_id=client_id).first() + + return query_client + + +def create_save_token_func(session, token_model): + """Create an ``save_token`` function that can be used in authorization + server. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + + def save_token(token, request): + if request.user: + user_id = request.user.get_user_id() + else: + user_id = None + client = request.client + item = token_model(client_id=client.client_id, user_id=user_id, **token) + session.add(item) + session.commit() + + return save_token + + +def create_query_token_func(session, token_model): + """Create an ``query_token`` function for revocation, introspection + token endpoints. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + + def query_token(token, token_type_hint, client): + q = session.query(token_model) + q = q.filter_by(client_id=client.client_id, revoked=False) + if token_type_hint == "access_token": + return q.filter_by(access_token=token).first() + elif token_type_hint == "refresh_token": + return q.filter_by(refresh_token=token).first() + # without token_type_hint + item = q.filter_by(access_token=token).first() + if item: + return item + return q.filter_by(refresh_token=token).first() + + return query_token + + +def create_revocation_endpoint(session, token_model): + """Create a revocation endpoint class with SQLAlchemy session + and token model. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + from authlib.oauth2.rfc7009 import RevocationEndpoint + + query_token = create_query_token_func(session, token_model) + + class _RevocationEndpoint(RevocationEndpoint): + def query_token(self, token, token_type_hint, client): + return query_token(token, token_type_hint, client) + + def revoke_token(self, token): + token.revoked = True + session.add(token) + session.commit() + + return _RevocationEndpoint + + +def create_bearer_token_validator(session, token_model): + """Create an bearer token validator class with SQLAlchemy session + and token model. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + from authlib.oauth2.rfc6750 import BearerTokenValidator + + class _BearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + q = session.query(token_model) + return q.filter_by(access_token=token_string).first() + + def request_invalid(self, request): + return False + + def token_revoked(self, token): + return token.revoked + + return _BearerTokenValidator diff --git a/api/controllers/api/v1/auth.py b/api/controllers/api/v1/auth.py index 10261675..865bc5a6 100644 --- a/api/controllers/api/v1/auth.py +++ b/api/controllers/api/v1/auth.py @@ -41,23 +41,20 @@ def create_client(): return response client = OAuth2Client() + client.client_name = req.get("client_name") + client.client_uri = req.get("website", None) + client.redirect_uri = req.get("redirect_uris") + client.scope = req.get("scopes") client.client_id = gen_salt(24) - metadatas = { - "client_name": req.get("client_name"), - "client_uri": req.get("website", None), - "redirect_uris": [req.get("redirect_uris")], # TODO, should make a check or something - "scope": req.get("scopes"), - # this needs to be hardcoded for whatever reason - "response_type": "code", - "grant_types": ["authorization_code", "client_credentials", "password"], - "token_endpoint_auth_method": "client_secret_post", - } if client.token_endpoint_auth_method == "none": client.client_secret = "" else: client.client_secret = gen_salt(48) - client.set_client_metadata(metadatas) + # this needs to be hardcoded for whatever reason + client.response_type = "code" + client.grant_type = "authorization_code\r\nclient_credentials\r\npassword" + client.token_endpoint_auth_method = "client_secret_post" db.session.add(client) db.session.commit() @@ -67,7 +64,7 @@ def create_client(): "client_secret": client.client_secret, "id": client.id, "name": client.client_name, - "redirect_uris": client.redirect_uris[0] if len(client.redirect_uris) else "", + "redirect_uri": client.redirect_uri, "website": client.client_uri, "vapid_key": None, # FIXME to implement this } @@ -140,7 +137,7 @@ def oauth_token(): d["redirect_uri"] = request.json["redirect_uri"] request.form = ImmutableMultiDict(d) - return authorization.create_token_response(request) + return authorization.create_token_response() @bp_api_v1_auth.route("/oauth/revoke", methods=["POST"]) diff --git a/api/migrations/versions/67_7df5c87e5fef_.py b/api/migrations/versions/67_7df5c87e5fef_.py deleted file mode 100644 index 23e99f37..00000000 --- a/api/migrations/versions/67_7df5c87e5fef_.py +++ /dev/null @@ -1,98 +0,0 @@ -"""oauth2_client + fs_uniquifier - -Revision ID: 7df5c87e5fef -Revises: f537ac7a67d6 -Create Date: 2021-02-23 19:07:51.420492 - -""" - -# revision identifiers, used by Alembic. -revision = "7df5c87e5fef" -down_revision = "f537ac7a67d6" - -from alembic import op # noqa: E402 -import sqlalchemy as sa # noqa: E402 -import uuid # noqa: E402 - -""" -Note: Authlib introduced major breaking change by removing thoses fields that I wasn't able to find in the changelogs -Theses will stay commented, they have been added into the models. -""" - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("oauth2_client", sa.Column("client_id_issued_at", sa.Integer(), nullable=False)) - op.add_column("oauth2_client", sa.Column("client_metadata", sa.Text(), nullable=True)) - op.add_column("oauth2_client", sa.Column("client_secret_expires_at", sa.Integer(), nullable=False)) - op.drop_column("oauth2_client", "policy_uri") - op.drop_column("oauth2_client", "tos_uri") - # op.drop_column("oauth2_client", "client_name") - # op.drop_column("oauth2_client", "response_type") - op.drop_column("oauth2_client", "expires_at") - op.drop_column("oauth2_client", "software_id") - # op.drop_column("oauth2_client", "redirect_uri") - op.drop_column("oauth2_client", "software_version") - # op.drop_column("oauth2_client", "client_uri") - op.drop_column("oauth2_client", "contact") - op.drop_column("oauth2_client", "i18n_metadata") - op.drop_column("oauth2_client", "logo_uri") - # op.drop_column("oauth2_client", "grant_type") - op.drop_column("oauth2_client", "jwks_uri") - op.drop_column("oauth2_client", "issued_at") - # op.drop_column("oauth2_client", "scope") - op.drop_column("oauth2_client", "jwks_text") - # op.drop_column("oauth2_client", "token_endpoint_auth_method") - op.add_column("oauth2_code", sa.Column("code_challenge", sa.Text(), nullable=True)) - op.add_column("oauth2_code", sa.Column("code_challenge_method", sa.String(length=48), nullable=True)) - op.add_column("oauth2_code", sa.Column("nonce", sa.Text(), nullable=True)) - - # fs_uniquifier - op.add_column("user", sa.Column("fs_uniquifier", sa.String(length=255), nullable=True)) - user_table = sa.Table( - "user", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("fs_uniquifier", sa.String) - ) - conn = op.get_bind() - for row in conn.execute(sa.select([user_table.c.id])): - conn.execute(user_table.update().values(fs_uniquifier=uuid.uuid4().hex).where(user_table.c.id == row["id"])) - op.alter_column("user", "fs_uniquifier", nullable=False) - - op.create_unique_constraint(None, "user", ["fs_uniquifier"]) - # end fs_uniquifier - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, "user", type_="unique") - op.drop_column("user", "fs_uniquifier") - op.drop_column("oauth2_code", "nonce") - op.drop_column("oauth2_code", "code_challenge_method") - op.drop_column("oauth2_code", "code_challenge") - # op.add_column( - # "oauth2_client", - # sa.Column("token_endpoint_auth_method", sa.VARCHAR(length=48), autoincrement=False, nullable=True), - # ) - op.add_column("oauth2_client", sa.Column("jwks_text", sa.TEXT(), autoincrement=False, nullable=True)) - # op.add_column("oauth2_client", sa.Column("scope", sa.TEXT(), autoincrement=False, nullable=False)) - op.add_column("oauth2_client", sa.Column("issued_at", sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column("oauth2_client", sa.Column("jwks_uri", sa.TEXT(), autoincrement=False, nullable=True)) - # op.add_column("oauth2_client", sa.Column("grant_type", sa.TEXT(), autoincrement=False, nullable=False)) - op.add_column("oauth2_client", sa.Column("logo_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("i18n_metadata", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("contact", sa.TEXT(), autoincrement=False, nullable=True)) - # op.add_column("oauth2_client", sa.Column("client_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column( - "oauth2_client", sa.Column("software_version", sa.VARCHAR(length=48), autoincrement=False, nullable=True) - ) - # op.add_column("oauth2_client", sa.Column("redirect_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("software_id", sa.VARCHAR(length=36), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("expires_at", sa.INTEGER(), autoincrement=False, nullable=False)) - # op.add_column("oauth2_client", sa.Column("response_type", sa.TEXT(), autoincrement=False, nullable=False)) - # op.add_column("oauth2_client", sa.Column("client_name", sa.VARCHAR(length=100), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("tos_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("policy_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.drop_column("oauth2_client", "client_secret_expires_at") - op.drop_column("oauth2_client", "client_metadata") - op.drop_column("oauth2_client", "client_id_issued_at") - # ### end Alembic commands ### diff --git a/api/migrations/versions/68_dff4edfb26b6_.py b/api/migrations/versions/68_dff4edfb26b6_.py deleted file mode 100644 index 8d74d34d..00000000 --- a/api/migrations/versions/68_dff4edfb26b6_.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Migrate to client_metadata - -Revision ID: dff4edfb26b6 -Revises: 7df5c87e5fef -Create Date: 2021-02-23 22:59:23.112766 - -""" - -# revision identifiers, used by Alembic. -revision = "dff4edfb26b6" -down_revision = "7df5c87e5fef" - -from alembic import op # noqa: E402 -import sqlalchemy as sa # noqa: E402 - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("oauth2_client", "redirect_uri") - op.drop_column("oauth2_client", "client_uri") - op.drop_column("oauth2_client", "response_type") - op.drop_column("oauth2_client", "client_name") - op.drop_column("oauth2_client", "grant_type") - op.drop_column("oauth2_client", "token_endpoint_auth_method") - op.drop_column("oauth2_client", "scope") - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("oauth2_client", sa.Column("scope", sa.TEXT(), autoincrement=False, nullable=False)) - op.add_column( - "oauth2_client", - sa.Column("token_endpoint_auth_method", sa.VARCHAR(length=48), autoincrement=False, nullable=True), - ) - op.add_column("oauth2_client", sa.Column("grant_type", sa.TEXT(), autoincrement=False, nullable=False)) - op.add_column("oauth2_client", sa.Column("client_name", sa.VARCHAR(length=100), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("response_type", sa.TEXT(), autoincrement=False, nullable=False)) - op.add_column("oauth2_client", sa.Column("client_uri", sa.TEXT(), autoincrement=False, nullable=True)) - op.add_column("oauth2_client", sa.Column("redirect_uri", sa.TEXT(), autoincrement=False, nullable=True)) - # ### end Alembic commands ### diff --git a/api/migrations/versions/f1993296be9e_.py b/api/migrations/versions/f1993296be9e_.py new file mode 100644 index 00000000..ac8e69ab --- /dev/null +++ b/api/migrations/versions/f1993296be9e_.py @@ -0,0 +1,35 @@ +"""Flask-Security fs_uniquifier migration + +Revision ID: f1993296be9e +Revises: f537ac7a67d6 +Create Date: 2021-02-25 10:50:29.973424 + +""" + +# revision identifiers, used by Alembic. +revision = "f1993296be9e" +down_revision = "f537ac7a67d6" + +from alembic import op # noqa: E402 +import sqlalchemy as sa # noqa: E402 +import uuid # noqa: E402 + + +def upgrade(): + op.add_column("user", sa.Column("fs_uniquifier", sa.String(length=255), nullable=True)) + user_table = sa.Table( + "user", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("fs_uniquifier", sa.String) + ) + conn = op.get_bind() + for row in conn.execute(sa.select([user_table.c.id])): + conn.execute(user_table.update().values(fs_uniquifier=uuid.uuid4().hex).where(user_table.c.id == row["id"])) + op.alter_column("user", "fs_uniquifier", nullable=False) + + op.create_unique_constraint(None, "user", ["fs_uniquifier"]) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "user", type_="unique") + op.drop_column("user", "fs_uniquifier") + # ### end Alembic commands ### diff --git a/api/models.py b/api/models.py index 7dd2699c..62dc2755 100644 --- a/api/models.py +++ b/api/models.py @@ -19,7 +19,7 @@ from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy import text as sa_text from little_boxes import activitypub as ap from urllib.parse import urlparse -from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin, OAuth2AuthorizationCodeMixin, OAuth2TokenMixin +from authlib_sqla import OAuth2ClientMixin, OAuth2AuthorizationCodeMixin, OAuth2TokenMixin import time import uuid from utils.defaults import Reel2bitsDefaults |