mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 23:40:16 +08:00
test: implement Account/Tenant model integration tests to replace db-mocked unit tests (#34994)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,79 +1,202 @@
|
||||
# import secrets
|
||||
"""
|
||||
Integration tests for Account and Tenant model methods that interact with the database.
|
||||
|
||||
# import pytest
|
||||
# from sqlalchemy import select
|
||||
# from sqlalchemy.orm import Session
|
||||
# from sqlalchemy.orm.exc import DetachedInstanceError
|
||||
Migrated from unit_tests/models/test_account_models.py, replacing
|
||||
@patch("models.account.db") mock patches with real PostgreSQL operations.
|
||||
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models.account import Account, Tenant, TenantAccountJoin
|
||||
Covers:
|
||||
- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin)
|
||||
- Account.set_tenant_id (resolves tenant + role from real join row)
|
||||
- Account.get_by_openid (AccountIntegrate lookup then Account fetch)
|
||||
- Tenant.get_accounts (returns accounts linked via TenantAccountJoin)
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def session(db_session_with_containers):
|
||||
# with Session(db_session_with_containers.get_bind()) as session:
|
||||
# yield session
|
||||
def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None:
|
||||
"""Delete rows tracked during the test so committed state does not leak into the DB.
|
||||
|
||||
Rolls back any pending (uncommitted) session state first, then issues DELETE
|
||||
statements by primary key for each tracked entity (in reverse creation order)
|
||||
and commits. This cleans up rows created via either flush() or commit().
|
||||
"""
|
||||
db_session.rollback()
|
||||
for entity in reversed(tracked):
|
||||
db_session.execute(delete(type(entity)).where(type(entity).id == entity.id))
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def account(session):
|
||||
# account = Account(
|
||||
# name="test account",
|
||||
# email=f"test_{secrets.token_hex(8)}@example.com",
|
||||
# )
|
||||
# session.add(account)
|
||||
# session.commit()
|
||||
# return account
|
||||
def _build_tenant() -> Tenant:
|
||||
return Tenant(name=f"Tenant {uuid4()}")
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant(session):
|
||||
# tenant = Tenant(name="test tenant")
|
||||
# session.add(tenant)
|
||||
# session.commit()
|
||||
# return tenant
|
||||
def _build_account(email_prefix: str = "account") -> Account:
|
||||
return Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"{email_prefix}_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant_account_join(session, account, tenant):
|
||||
# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id)
|
||||
# session.add(tenant_join)
|
||||
# session.commit()
|
||||
# yield tenant_join
|
||||
# session.delete(tenant_join)
|
||||
# session.commit()
|
||||
class _DBTrackingTestBase:
|
||||
"""Base class providing a tracker list and shared row factories for account/tenant tests."""
|
||||
|
||||
_tracked: list
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]:
|
||||
self._tracked = []
|
||||
yield
|
||||
_cleanup_tracked_rows(db_session_with_containers, self._tracked)
|
||||
|
||||
def _create_tenant(self, db_session: Session) -> Tenant:
|
||||
tenant = _build_tenant()
|
||||
db_session.add(tenant)
|
||||
db_session.flush()
|
||||
self._tracked.append(tenant)
|
||||
return tenant
|
||||
|
||||
def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account:
|
||||
account = _build_account(email_prefix)
|
||||
db_session.add(account)
|
||||
db_session.flush()
|
||||
self._tracked.append(account)
|
||||
return account
|
||||
|
||||
def _create_join(
|
||||
self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True
|
||||
) -> TenantAccountJoin:
|
||||
join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current)
|
||||
db_session.add(join)
|
||||
db_session.flush()
|
||||
self._tracked.append(join)
|
||||
return join
|
||||
|
||||
|
||||
# class TestAccountTenant:
|
||||
# def test_set_current_tenant_should_reload_tenant(
|
||||
# self,
|
||||
# db_session_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session:
|
||||
# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one()
|
||||
# account.current_tenant = scoped_tenant
|
||||
# scoped_tenant.created_at = naive_utc_now()
|
||||
# # session.commit()
|
||||
class TestAccountCurrentTenantSetter(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.current_tenant property setter."""
|
||||
|
||||
# # Ensure the tenant used in assignment is detached.
|
||||
# with pytest.raises(DetachedInstanceError):
|
||||
# _ = scoped_tenant.name
|
||||
def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None:
|
||||
"""current_tenant getter returns the in-memory _current_tenant without DB access."""
|
||||
account = self._create_account(db_session_with_containers)
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account._current_tenant = tenant
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
assert account.current_tenant is tenant
|
||||
|
||||
# def test_set_tenant_id_should_load_tenant_as_not_expire(
|
||||
# self,
|
||||
# flask_app_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with flask_app_with_containers.test_request_context():
|
||||
# account.set_tenant_id(tenant.id)
|
||||
def test_current_tenant_setter_sets_tenant_and_role_when_join_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""Setting current_tenant loads the join row and assigns role when relationship exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
account.current_tenant = tenant
|
||||
|
||||
assert account._current_tenant is not None
|
||||
assert account._current_tenant.id == tenant.id
|
||||
assert account.role == TenantAccountRole.OWNER
|
||||
|
||||
def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""Setting current_tenant results in _current_tenant=None when no join row exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
|
||||
assert account._current_tenant is None
|
||||
|
||||
|
||||
class TestAccountSetTenantId(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.set_tenant_id method."""
|
||||
|
||||
def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""set_tenant_id loads the tenant and assigns role when a join row exists."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
assert account._current_tenant is not None
|
||||
assert account._current_tenant.id == tenant.id
|
||||
assert account.role == TenantAccountRole.ADMIN
|
||||
|
||||
def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""set_tenant_id does nothing when no join row matches the tenant."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account = self._create_account(db_session_with_containers)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
assert account._current_tenant is None
|
||||
|
||||
|
||||
class TestAccountGetByOpenId(_DBTrackingTestBase):
|
||||
"""Integration tests for Account.get_by_openid class method."""
|
||||
|
||||
def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""get_by_openid returns the Account when a matching AccountIntegrate row exists."""
|
||||
account = self._create_account(db_session_with_containers, email_prefix="openid")
|
||||
provider = "google"
|
||||
open_id = f"google_{uuid4()}"
|
||||
|
||||
integrate = AccountIntegrate(
|
||||
account_id=account.id,
|
||||
provider=provider,
|
||||
open_id=open_id,
|
||||
encrypted_token="token",
|
||||
)
|
||||
db_session_with_containers.add(integrate)
|
||||
db_session_with_containers.flush()
|
||||
self._tracked.append(integrate)
|
||||
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == account.id
|
||||
|
||||
def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None:
|
||||
"""get_by_openid returns None when no AccountIntegrate row matches."""
|
||||
result = Account.get_by_openid("github", f"github_{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTenantGetAccounts(_DBTrackingTestBase):
|
||||
"""Integration tests for Tenant.get_accounts method."""
|
||||
|
||||
def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None:
|
||||
"""get_accounts returns all accounts linked to the tenant via TenantAccountJoin."""
|
||||
tenant = self._create_tenant(db_session_with_containers)
|
||||
account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
|
||||
account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member")
|
||||
self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False)
|
||||
self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False)
|
||||
|
||||
accounts = tenant.get_accounts()
|
||||
|
||||
assert len(accounts) == 2
|
||||
account_ids = {a.id for a in accounts}
|
||||
assert account1.id in account_ids
|
||||
assert account2.id in account_ids
|
||||
|
||||
@@ -12,7 +12,6 @@ This test suite covers:
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -310,90 +309,6 @@ class TestAccountStatusTransitions:
|
||||
class TestTenantRelationshipIntegrity:
|
||||
"""Test suite for tenant relationship integrity."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_property(self, mock_db):
|
||||
"""Test the current_tenant property getter."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account._current_tenant = tenant
|
||||
|
||||
# Act
|
||||
result = account.current_tenant
|
||||
|
||||
# Assert
|
||||
assert result == tenant
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_with_valid_tenant(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant with a valid tenant relationship."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock TenantAccountJoin query result
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
mock_session.scalar.return_value = tenant_join
|
||||
|
||||
# Mock Tenant query result
|
||||
mock_session.scalars.return_value.one.return_value = tenant
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.OWNER
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_without_relationship(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock no TenantAccountJoin found
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant is None
|
||||
|
||||
def test_account_current_tenant_id_property(self):
|
||||
"""Test the current_tenant_id property."""
|
||||
# Arrange
|
||||
@@ -418,61 +333,6 @@ class TestTenantRelationshipIntegrity:
|
||||
# Assert
|
||||
assert tenant_id_none is None
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_method(self, mock_db, mock_session_class):
|
||||
"""Test the set_tenant_id method."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.ADMIN,
|
||||
)
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = (tenant, tenant_join)
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.ADMIN
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_with_no_relationship(self, mock_db, mock_session_class):
|
||||
"""Test set_tenant_id when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant_id)
|
||||
|
||||
# Assert - should not set tenant when no relationship exists
|
||||
# The method returns early without setting _current_tenant
|
||||
|
||||
|
||||
class TestAccountRolePermissions:
|
||||
"""Test suite for account role permissions."""
|
||||
@@ -605,51 +465,6 @@ class TestAccountRolePermissions:
|
||||
assert current_role == TenantAccountRole.EDITOR
|
||||
|
||||
|
||||
class TestAccountGetByOpenId:
|
||||
"""Test suite for get_by_openid class method."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_success(self, mock_db):
|
||||
"""Test successful retrieval of account by OpenID."""
|
||||
# Arrange
|
||||
provider = "google"
|
||||
open_id = "google_user_123"
|
||||
account_id = str(uuid4())
|
||||
|
||||
mock_account_integrate = MagicMock()
|
||||
mock_account_integrate.account_id = account_id
|
||||
|
||||
mock_account = Account(name="Test User", email="test@example.com")
|
||||
mock_account.id = account_id
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
|
||||
# Mock db.session.scalar() for Account lookup
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_account
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_not_found(self, mock_db):
|
||||
"""Test get_by_openid when account integrate doesn't exist."""
|
||||
# Arrange
|
||||
provider = "github"
|
||||
open_id = "github_user_456"
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() to return None
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTenantAccountJoinModel:
|
||||
"""Test suite for TenantAccountJoin model."""
|
||||
|
||||
@@ -760,31 +575,6 @@ class TestTenantModel:
|
||||
# Assert
|
||||
assert tenant.custom_config == '{"feature1": true, "feature2": "value"}'
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_tenant_get_accounts(self, mock_db):
|
||||
"""Test getting accounts associated with a tenant."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account1 = Account(name="User 1", email="user1@example.com")
|
||||
account1.id = str(uuid4())
|
||||
account2 = Account(name="User 2", email="user2@example.com")
|
||||
account2.id = str(uuid4())
|
||||
|
||||
# Mock the query chain
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = [account1, account2]
|
||||
mock_db.session.scalars.return_value = mock_scalars
|
||||
|
||||
# Act
|
||||
accounts = tenant.get_accounts()
|
||||
|
||||
# Assert
|
||||
assert len(accounts) == 2
|
||||
assert account1 in accounts
|
||||
assert account2 in accounts
|
||||
|
||||
|
||||
class TestTenantStatusEnum:
|
||||
"""Test suite for TenantStatus enum."""
|
||||
|
||||
Reference in New Issue
Block a user