diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 400df138b8..1aaa5d3a62 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,5 +1,3 @@ -from typing import Any - import flask_login from flask import make_response, request from flask_restx import Resource @@ -42,7 +40,7 @@ from libs.token import ( set_csrf_token_to_cookie, set_refresh_token_to_cookie, ) -from services.account_service import AccountService, RegisterService, TenantService +from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService from services.billing_service import BillingService from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError @@ -101,7 +99,7 @@ class LoginApi(Resource): raise EmailPasswordLoginLimitError() invite_token = args.invite_token - invitation_data: dict[str, Any] | None = None + invitation_data: InvitationDetailDict | None = None if invite_token: invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) if invitation_data is None: diff --git a/api/services/account_service.py b/api/services/account_service.py index ee4c199df8..28c736a1e9 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -83,6 +83,12 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +class InvitationDetailDict(TypedDict): + account: Account + data: InvitationData + tenant: Tenant + + def _try_join_enterprise_default_workspace(account_id: str) -> None: """Best-effort join to enterprise default workspace.""" if not dify_config.ENTERPRISE_ENABLED: @@ -1585,7 +1591,7 @@ class RegisterService: @classmethod def get_invitation_if_token_valid( cls, workspace_id: str | None, email: str | None, token: str - ) -> dict[str, Any] | None: + ) -> InvitationDetailDict | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1647,7 +1653,7 @@ class RegisterService: @classmethod def get_invitation_with_case_fallback( cls, workspace_id: str | None, email: str | None, token: str - ) -> dict[str, Any] | None: + ) -> InvitationDetailDict | None: invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) if invitation or not email or email == email.lower(): return invitation