summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDashie <dashie@sigpipe.me>2021-02-25 11:05:12 +0100
committerDashie <dashie@sigpipe.me>2021-02-25 11:05:12 +0100
commita2f864ff2e49dd198f5eadd147af360841644518 (patch)
tree3973e763c4f3b4aa7c0412765a83ec7aa0225346
parent37e0f874a1c6a108a8940329e2e1e534d8d328b9 (diff)
Import Authlib SQLAlchemy backend from before integrations split; Fix migrations
-rw-r--r--api/authlib_sqla.py335
-rw-r--r--api/controllers/api/v1/auth.py23
-rw-r--r--api/migrations/versions/67_7df5c87e5fef_.py98
-rw-r--r--api/migrations/versions/68_dff4edfb26b6_.py41
-rw-r--r--api/migrations/versions/f1993296be9e_.py35
-rw-r--r--api/models.py2
6 files changed, 381 insertions, 153 deletions
diff --git a/api/authlib_sqla.py b/api/authlib_sqla.py
new file mode 100644
index 00000000..4b3300f2
--- /dev/null
+++ b/api/authlib_sqla.py
@@ -0,0 +1,335 @@
+import time
+import json
+from sqlalchemy import Column, String, Boolean, Text, Integer
+from sqlalchemy.ext.hybrid import hybrid_property
+from authlib.oauth2.rfc6749 import (
+ ClientMixin,
+ TokenMixin,
+ AuthorizationCodeMixin,
+)
+from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope
+from authlib.oidc.core import AuthorizationCodeMixin as OIDCCodeMixin
+
+
+class OAuth2ClientMixin(ClientMixin):
+ client_id = Column(String(48), index=True)
+ client_secret = Column(String(120))
+ issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
+ expires_at = Column(Integer, nullable=False, default=0)
+
+ redirect_uri = Column(Text)
+ token_endpoint_auth_method = Column(String(48), default="client_secret_basic")
+ grant_type = Column(Text, nullable=False, default="")
+ response_type = Column(Text, nullable=False, default="")
+ scope = Column(Text, nullable=False, default="")
+
+ client_name = Column(String(100))
+ client_uri = Column(Text)
+ logo_uri = Column(Text)
+ contact = Column(Text)
+ tos_uri = Column(Text)
+ policy_uri = Column(Text)
+ jwks_uri = Column(Text)
+ jwks_text = Column(Text)
+ i18n_metadata = Column(Text)
+
+ software_id = Column(String(36))
+ software_version = Column(String(48))
+
+ def __repr__(self):
+ return "<Client: {}>".format(self.client_id)
+
+ @hybrid_property
+ def redirect_uris(self):
+ if self.redirect_uri:
+ return self.redirect_uri.splitlines()
+ return []
+
+ @redirect_uris.setter
+ def redirect_uris(self, value):
+ self.redirect_uri = "\n".join(value)
+
+ @hybrid_property
+ def grant_types(self):
+ if self.grant_type:
+ return self.grant_type.splitlines()
+ return []
+
+ @grant_types.setter
+ def grant_types(self, value):
+ self.grant_type = "\n".join(value)
+
+ @hybrid_property
+ def response_types(self):
+ if self.response_type:
+ return self.response_type.splitlines()
+ return []
+
+ @response_types.setter
+ def response_types(self, value):
+ self.response_type = "\n".join(value)
+
+ @hybrid_property
+ def contacts(self):
+ if self.contact:
+ return json.loads(self.contact)
+ return []
+
+ @contacts.setter
+ def contacts(self, value):
+ self.contact = json.dumps(value)
+
+ @hybrid_property
+ def jwks(self):
+ if self.jwks_text:
+ return json.loads(self.jwks_text)
+ return None
+
+ @jwks.setter
+ def jwks(self, value):
+ self.jwks_text = json.dumps(value)
+
+ @hybrid_property
+ def client_metadata(self):
+ """Implementation for Client Metadata in OAuth 2.0 Dynamic Client
+ Registration Protocol via `Section 2`_.
+
+ .. _`Section 2`: https://tools.ietf.org/html/rfc7591#section-2
+ """
+ keys = [
+ "redirect_uris",
+ "token_endpoint_auth_method",
+ "grant_types",
+ "response_types",
+ "client_name",
+ "client_uri",
+ "logo_uri",
+ "scope",
+ "contacts",
+ "tos_uri",
+ "policy_uri",
+ "jwks_uri",
+ "jwks",
+ ]
+ metadata = {k: getattr(self, k) for k in keys}
+ if self.i18n_metadata:
+ metadata.update(json.loads(self.i18n_metadata))
+ return metadata
+
+ @client_metadata.setter
+ def client_metadata(self, value):
+ i18n_metadata = {}
+ for k in value:
+ if hasattr(self, k):
+ setattr(self, k, value[k])
+ elif "#" in k:
+ i18n_metadata[k] = value[k]
+
+ self.i18n_metadata = json.dumps(i18n_metadata)
+
+ @property
+ def client_info(self):
+ """Implementation for Client Info in OAuth 2.0 Dynamic Client
+ Registration Protocol via `Section 3.2.1`_.
+
+ .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
+ """
+ return dict(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ client_id_issued_at=self.issued_at,
+ client_secret_expires_at=self.expires_at,
+ )
+
+ def get_client_id(self):
+ return self.client_id
+
+ def get_default_redirect_uri(self):
+ if self.redirect_uris:
+ return self.redirect_uris[0]
+
+ def get_allowed_scope(self, scope):
+ if not scope:
+ return ""
+ allowed = set(self.scope.split())
+ scopes = scope_to_list(scope)
+ return list_to_scope([s for s in scopes if s in allowed])
+
+ def check_redirect_uri(self, redirect_uri):
+ return redirect_uri in self.redirect_uris
+
+ def has_client_secret(self):
+ return bool(self.client_secret)
+
+ def check_client_secret(self, client_secret):
+ return self.client_secret == client_secret
+
+ def check_token_endpoint_auth_method(self, method):
+ return self.token_endpoint_auth_method == method
+
+ def check_response_type(self, response_type):
+ if self.response_type:
+ return response_type in self.response_types
+ return False
+
+ def check_grant_type(self, grant_type):
+ if self.grant_type:
+ return grant_type in self.grant_types
+ return False
+
+
+class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
+ code = Column(String(120), unique=True, nullable=False)
+ client_id = Column(String(48))
+ redirect_uri = Column(Text, default="")
+ response_type = Column(Text, default="")
+ scope = Column(Text, default="")
+ auth_time = Column(Integer, nullable=False, default=lambda: int(time.time()))
+
+ def is_expired(self):
+ return self.auth_time + 300 < time.time()
+
+ def get_redirect_uri(self):
+ return self.redirect_uri
+
+ def get_scope(self):
+ return self.scope
+
+ def get_auth_time(self):
+ return self.auth_time
+
+
+class OIDCAuthorizationCodeMixin(OAuth2AuthorizationCodeMixin, OIDCCodeMixin):
+ nonce = Column(Text)
+
+ def get_nonce(self):
+ return self.nonce
+
+
+class OAuth2TokenMixin(TokenMixin):
+ client_id = Column(String(48))
+ token_type = Column(String(40))
+ access_token = Column(String(255), unique=True, nullable=False)
+ refresh_token = Column(String(255), index=True)
+ scope = Column(Text, default="")
+ revoked = Column(Boolean, default=False)
+ issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
+ expires_in = Column(Integer, nullable=False, default=0)
+
+ def get_client_id(self):
+ return self.client_id
+
+ def get_scope(self):
+ return self.scope
+
+ def get_expires_in(self):
+ return self.expires_in
+
+ def get_expires_at(self):
+ return self.issued_at + self.expires_in
+
+
+def create_query_client_func(session, client_model):
+ """Create an ``query_client`` function that can be used in authorization
+ server.
+
+ :param session: SQLAlchemy session
+ :param client_model: Client model class
+ """
+
+ def query_client(client_id):
+ q = session.query(client_model)
+ return q.filter_by(client_id=client_id).first()
+
+ return query_client
+
+
+def create_save_token_func(session, token_model):
+ """Create an ``save_token`` function that can be used in authorization
+ server.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+
+ def save_token(token, request):
+ if request.user:
+ user_id = request.user.get_user_id()
+ else:
+ user_id = None
+ client = request.client
+ item = token_model(client_id=client.client_id, user_id=user_id, **token)
+ session.add(item)
+ session.commit()
+
+ return save_token
+
+
+def create_query_token_func(session, token_model):
+ """Create an ``query_token`` function for revocation, introspection
+ token endpoints.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+
+ def query_token(token, token_type_hint, client):
+ q = session.query(token_model)
+ q = q.filter_by(client_id=client.client_id, revoked=False)
+ if token_type_hint == "access_token":
+ return q.filter_by(access_token=token).first()
+ elif token_type_hint == "refresh_token":
+ return q.filter_by(refresh_token=token).first()
+ # without token_type_hint
+ item = q.filter_by(access_token=token).first()
+ if item:
+ return item
+ return q.filter_by(refresh_token=token).first()
+
+ return query_token
+
+
+def create_revocation_endpoint(session, token_model):
+ """Create a revocation endpoint class with SQLAlchemy session
+ and token model.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+ from authlib.oauth2.rfc7009 import RevocationEndpoint
+
+ query_token = create_query_token_func(session, token_model)
+
+ class _RevocationEndpoint(RevocationEndpoint):
+ def query_token(self, token, token_type_hint, client):
+ return query_token(token, token_type_hint, client)
+
+ def revoke_token(self, token):
+ token.revoked = True
+ session.add(token)
+ session.commit()
+
+ return _RevocationEndpoint
+
+
+def create_bearer_token_validator(session, token_model):
+ """Create an bearer token validator class with SQLAlchemy session
+ and token model.
+
+ :param session: SQLAlchemy session
+ :param token_model: Token model class
+ """
+ from authlib.oauth2.rfc6750 import BearerTokenValidator
+
+ class _BearerTokenValidator(BearerTokenValidator):
+ def authenticate_token(self, token_string):
+ q = session.query(token_model)
+ return q.filter_by(access_token=token_string).first()
+
+ def request_invalid(self, request):
+ return False
+
+ def token_revoked(self, token):
+ return token.revoked
+
+ return _BearerTokenValidator
diff --git a/api/controllers/api/v1/auth.py b/api/controllers/api/v1/auth.py
index 10261675..865bc5a6 100644
--- a/api/controllers/api/v1/auth.py
+++ b/api/controllers/api/v1/auth.py
@@ -41,23 +41,20 @@ def create_client():
return response
client = OAuth2Client()
+ client.client_name = req.get("client_name")
+ client.client_uri = req.get("website", None)
+ client.redirect_uri = req.get("redirect_uris")
+ client.scope = req.get("scopes")
client.client_id = gen_salt(24)
- metadatas = {
- "client_name": req.get("client_name"),
- "client_uri": req.get("website", None),
- "redirect_uris": [req.get("redirect_uris")], # TODO, should make a check or something
- "scope": req.get("scopes"),
- # this needs to be hardcoded for whatever reason
- "response_type": "code",
- "grant_types": ["authorization_code", "client_credentials", "password"],
- "token_endpoint_auth_method": "client_secret_post",
- }
if client.token_endpoint_auth_method == "none":
client.client_secret = ""
else:
client.client_secret = gen_salt(48)
- client.set_client_metadata(metadatas)
+ # this needs to be hardcoded for whatever reason
+ client.response_type = "code"
+ client.grant_type = "authorization_code\r\nclient_credentials\r\npassword"
+ client.token_endpoint_auth_method = "client_secret_post"
db.session.add(client)
db.session.commit()
@@ -67,7 +64,7 @@ def create_client():
"client_secret": client.client_secret,
"id": client.id,
"name": client.client_name,
- "redirect_uris": client.redirect_uris[0] if len(client.redirect_uris) else "",
+ "redirect_uri": client.redirect_uri,
"website": client.client_uri,
"vapid_key": None, # FIXME to implement this
}
@@ -140,7 +137,7 @@ def oauth_token():
d["redirect_uri"] = request.json["redirect_uri"]
request.form = ImmutableMultiDict(d)
- return authorization.create_token_response(request)
+ return authorization.create_token_response()
@bp_api_v1_auth.route("/oauth/revoke", methods=["POST"])
diff --git a/api/migrations/versions/67_7df5c87e5fef_.py b/api/migrations/versions/67_7df5c87e5fef_.py
deleted file mode 100644
index 23e99f37..00000000
--- a/api/migrations/versions/67_7df5c87e5fef_.py
+++ /dev/null
@@ -1,98 +0,0 @@
-"""oauth2_client + fs_uniquifier
-
-Revision ID: 7df5c87e5fef
-Revises: f537ac7a67d6
-Create Date: 2021-02-23 19:07:51.420492
-
-"""
-
-# revision identifiers, used by Alembic.
-revision = "7df5c87e5fef"
-down_revision = "f537ac7a67d6"
-
-from alembic import op # noqa: E402
-import sqlalchemy as sa # noqa: E402
-import uuid # noqa: E402
-
-"""
-Note: Authlib introduced major breaking change by removing thoses fields that I wasn't able to find in the changelogs
-Theses will stay commented, they have been added into the models.
-"""
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.add_column("oauth2_client", sa.Column("client_id_issued_at", sa.Integer(), nullable=False))
- op.add_column("oauth2_client", sa.Column("client_metadata", sa.Text(), nullable=True))
- op.add_column("oauth2_client", sa.Column("client_secret_expires_at", sa.Integer(), nullable=False))
- op.drop_column("oauth2_client", "policy_uri")
- op.drop_column("oauth2_client", "tos_uri")
- # op.drop_column("oauth2_client", "client_name")
- # op.drop_column("oauth2_client", "response_type")
- op.drop_column("oauth2_client", "expires_at")
- op.drop_column("oauth2_client", "software_id")
- # op.drop_column("oauth2_client", "redirect_uri")
- op.drop_column("oauth2_client", "software_version")
- # op.drop_column("oauth2_client", "client_uri")
- op.drop_column("oauth2_client", "contact")
- op.drop_column("oauth2_client", "i18n_metadata")
- op.drop_column("oauth2_client", "logo_uri")
- # op.drop_column("oauth2_client", "grant_type")
- op.drop_column("oauth2_client", "jwks_uri")
- op.drop_column("oauth2_client", "issued_at")
- # op.drop_column("oauth2_client", "scope")
- op.drop_column("oauth2_client", "jwks_text")
- # op.drop_column("oauth2_client", "token_endpoint_auth_method")
- op.add_column("oauth2_code", sa.Column("code_challenge", sa.Text(), nullable=True))
- op.add_column("oauth2_code", sa.Column("code_challenge_method", sa.String(length=48), nullable=True))
- op.add_column("oauth2_code", sa.Column("nonce", sa.Text(), nullable=True))
-
- # fs_uniquifier
- op.add_column("user", sa.Column("fs_uniquifier", sa.String(length=255), nullable=True))
- user_table = sa.Table(
- "user", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("fs_uniquifier", sa.String)
- )
- conn = op.get_bind()
- for row in conn.execute(sa.select([user_table.c.id])):
- conn.execute(user_table.update().values(fs_uniquifier=uuid.uuid4().hex).where(user_table.c.id == row["id"]))
- op.alter_column("user", "fs_uniquifier", nullable=False)
-
- op.create_unique_constraint(None, "user", ["fs_uniquifier"])
- # end fs_uniquifier
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_constraint(None, "user", type_="unique")
- op.drop_column("user", "fs_uniquifier")
- op.drop_column("oauth2_code", "nonce")
- op.drop_column("oauth2_code", "code_challenge_method")
- op.drop_column("oauth2_code", "code_challenge")
- # op.add_column(
- # "oauth2_client",
- # sa.Column("token_endpoint_auth_method", sa.VARCHAR(length=48), autoincrement=False, nullable=True),
- # )
- op.add_column("oauth2_client", sa.Column("jwks_text", sa.TEXT(), autoincrement=False, nullable=True))
- # op.add_column("oauth2_client", sa.Column("scope", sa.TEXT(), autoincrement=False, nullable=False))
- op.add_column("oauth2_client", sa.Column("issued_at", sa.INTEGER(), autoincrement=False, nullable=False))
- op.add_column("oauth2_client", sa.Column("jwks_uri", sa.TEXT(), autoincrement=False, nullable=True))
- # op.add_column("oauth2_client", sa.Column("grant_type", sa.TEXT(), autoincrement=False, nullable=False))
- op.add_column("oauth2_client", sa.Column("logo_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("i18n_metadata", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("contact", sa.TEXT(), autoincrement=False, nullable=True))
- # op.add_column("oauth2_client", sa.Column("client_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column(
- "oauth2_client", sa.Column("software_version", sa.VARCHAR(length=48), autoincrement=False, nullable=True)
- )
- # op.add_column("oauth2_client", sa.Column("redirect_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("software_id", sa.VARCHAR(length=36), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("expires_at", sa.INTEGER(), autoincrement=False, nullable=False))
- # op.add_column("oauth2_client", sa.Column("response_type", sa.TEXT(), autoincrement=False, nullable=False))
- # op.add_column("oauth2_client", sa.Column("client_name", sa.VARCHAR(length=100), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("tos_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("policy_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.drop_column("oauth2_client", "client_secret_expires_at")
- op.drop_column("oauth2_client", "client_metadata")
- op.drop_column("oauth2_client", "client_id_issued_at")
- # ### end Alembic commands ###
diff --git a/api/migrations/versions/68_dff4edfb26b6_.py b/api/migrations/versions/68_dff4edfb26b6_.py
deleted file mode 100644
index 8d74d34d..00000000
--- a/api/migrations/versions/68_dff4edfb26b6_.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Migrate to client_metadata
-
-Revision ID: dff4edfb26b6
-Revises: 7df5c87e5fef
-Create Date: 2021-02-23 22:59:23.112766
-
-"""
-
-# revision identifiers, used by Alembic.
-revision = "dff4edfb26b6"
-down_revision = "7df5c87e5fef"
-
-from alembic import op # noqa: E402
-import sqlalchemy as sa # noqa: E402
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_column("oauth2_client", "redirect_uri")
- op.drop_column("oauth2_client", "client_uri")
- op.drop_column("oauth2_client", "response_type")
- op.drop_column("oauth2_client", "client_name")
- op.drop_column("oauth2_client", "grant_type")
- op.drop_column("oauth2_client", "token_endpoint_auth_method")
- op.drop_column("oauth2_client", "scope")
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.add_column("oauth2_client", sa.Column("scope", sa.TEXT(), autoincrement=False, nullable=False))
- op.add_column(
- "oauth2_client",
- sa.Column("token_endpoint_auth_method", sa.VARCHAR(length=48), autoincrement=False, nullable=True),
- )
- op.add_column("oauth2_client", sa.Column("grant_type", sa.TEXT(), autoincrement=False, nullable=False))
- op.add_column("oauth2_client", sa.Column("client_name", sa.VARCHAR(length=100), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("response_type", sa.TEXT(), autoincrement=False, nullable=False))
- op.add_column("oauth2_client", sa.Column("client_uri", sa.TEXT(), autoincrement=False, nullable=True))
- op.add_column("oauth2_client", sa.Column("redirect_uri", sa.TEXT(), autoincrement=False, nullable=True))
- # ### end Alembic commands ###
diff --git a/api/migrations/versions/f1993296be9e_.py b/api/migrations/versions/f1993296be9e_.py
new file mode 100644
index 00000000..ac8e69ab
--- /dev/null
+++ b/api/migrations/versions/f1993296be9e_.py
@@ -0,0 +1,35 @@
+"""Flask-Security fs_uniquifier migration
+
+Revision ID: f1993296be9e
+Revises: f537ac7a67d6
+Create Date: 2021-02-25 10:50:29.973424
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "f1993296be9e"
+down_revision = "f537ac7a67d6"
+
+from alembic import op # noqa: E402
+import sqlalchemy as sa # noqa: E402
+import uuid # noqa: E402
+
+
+def upgrade():
+ op.add_column("user", sa.Column("fs_uniquifier", sa.String(length=255), nullable=True))
+ user_table = sa.Table(
+ "user", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("fs_uniquifier", sa.String)
+ )
+ conn = op.get_bind()
+ for row in conn.execute(sa.select([user_table.c.id])):
+ conn.execute(user_table.update().values(fs_uniquifier=uuid.uuid4().hex).where(user_table.c.id == row["id"]))
+ op.alter_column("user", "fs_uniquifier", nullable=False)
+
+ op.create_unique_constraint(None, "user", ["fs_uniquifier"])
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint(None, "user", type_="unique")
+ op.drop_column("user", "fs_uniquifier")
+ # ### end Alembic commands ###
diff --git a/api/models.py b/api/models.py
index 7dd2699c..62dc2755 100644
--- a/api/models.py
+++ b/api/models.py
@@ -19,7 +19,7 @@ from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy import text as sa_text
from little_boxes import activitypub as ap
from urllib.parse import urlparse
-from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin, OAuth2AuthorizationCodeMixin, OAuth2TokenMixin
+from authlib_sqla import OAuth2ClientMixin, OAuth2AuthorizationCodeMixin, OAuth2TokenMixin
import time
import uuid
from utils.defaults import Reel2bitsDefaults