From 7515eee0a8f241348feaba7b00eb5085f35223ec Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Sun, 12 Apr 2026 03:21:52 +0200 Subject: [PATCH] refactor(services): migrate dataset_service and clear_free_plan_tenant_expired_logs to SQLAlchemy 2.0 select() API (#34970) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../clear_free_plan_tenant_expired_logs.py | 76 +++++----- api/services/dataset_service.py | 14 +- ...est_clear_free_plan_tenant_expired_logs.py | 134 +++++------------- .../services/test_dataset_service_dataset.py | 4 +- .../services/test_dataset_service_document.py | 2 +- 5 files changed, 90 insertions(+), 140 deletions(-) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b0f7efaccd..ea12e40420 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app from graphon.model_runtime.utils.encoders import jsonable_encoder -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs: for model, table_name in related_tables: # Query records related to expired messages - records = ( - session.query(model) - .where( + records = session.scalars( + select(model).where( model.message_id.in_(batch_message_ids), # type: ignore ) - .all() - ) + ).all() if len(records) == 0: continue @@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception("Failed to save %s records", table_name) - session.query(model).where( - model.id.in_(record_ids), # type: ignore - ).delete(synchronize_session=False) + session.execute( + delete(model) + .where( + model.id.in_(record_ids), # type: ignore + ) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs: app_ids = [app.id for app in apps] while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - messages = ( - session.query(Message) + messages = session.scalars( + select(Message) .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(messages) == 0: break @@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).where( - Message.id.in_(message_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False) + ) cls._clear_message_related_tables(session, tenant_id, message_ids) @@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs: while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - conversations = ( - session.query(Conversation) + conversations = session.scalars( + select(Conversation) .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(conversations) == 0: break @@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).where( - Conversation.id.in_(conversation_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Conversation) + .where(Conversation.id.in_(conversation_ids)) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs: while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - workflow_app_logs = ( - session.query(WorkflowAppLog) + workflow_app_logs = session.scalars( + select(WorkflowAppLog) .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(workflow_app_logs) == 0: break @@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs: workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs - session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( - synchronize_session=False + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id.in_(workflow_app_log_ids)) + .execution_options(synchronize_session=False) ) click.echo( @@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs: current_time = started_at with sessionmaker(db.engine).begin() as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 9c71902849..b2920c1006 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -552,8 +552,8 @@ class DatasetService: external_knowledge_api_id: External knowledge API identifier """ with sessionmaker(db.engine).begin() as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + external_knowledge_binding = session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) ) if not external_knowledge_binding: @@ -1454,15 +1454,17 @@ class DocumentService: document_id_list: list[str] = [str(document_id) for document_id in document_ids] with session_factory.create_session() as session: - updated_count = ( - session.query(Document) - .filter( + result = session.execute( + update(Document) + .where( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) - .update({Document.need_summary: need_summary}, synchronize_session=False) + .values(need_summary=need_summary) + .execution_options(synchronize_session=False) ) + updated_count = result.rowcount # type: ignore[union-attr,attr-defined] session.commit() logger.info( "Updated need_summary to %s for %d documents in dataset %s", diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 3e989c55a3..1bbd214110 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs: def mock_session(self): """Create a mock database session.""" session = Mock(spec=Session) - session.query.return_value.filter.return_value.all.return_value = [] - session.query.return_value.filter.return_value.delete.return_value = 0 + session.scalars.return_value.all.return_value = [] return session @pytest.fixture @@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs: ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", []) # Should not call any database operations - mock_session.query.assert_not_called() + mock_session.scalars.assert_not_called() mock_storage.save.assert_not_called() def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): """Test when no related records are found.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - # Should call query for each related table but find no records - assert mock_session.query.call_count > 0 + # Should call scalars for each related table but find no records + assert mock_session.scalars.call_count > 0 mock_storage.save.assert_not_called() def test_clear_message_related_tables_with_records_and_to_dict( @@ -73,7 +72,7 @@ class TestClearFreePlanTenantExpiredLogs: ): """Test when records are found and have to_dict method.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) @@ -104,7 +103,7 @@ class TestClearFreePlanTenantExpiredLogs: records.append(record) # Mock records for first table only, empty for others - mock_session.query.return_value.where.return_value.all.side_effect = [ + mock_session.scalars.return_value.all.side_effect = [ records, [], [], @@ -126,13 +125,13 @@ class TestClearFreePlanTenantExpiredLogs: with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: mock_storage.save.side_effect = Exception("Storage error") - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if backup fails - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): """Test that method continues even when record serialization fails.""" @@ -141,23 +140,23 @@ class TestClearFreePlanTenantExpiredLogs: record.id = "record-1" record.to_dict.side_effect = Exception("Serialization error") - mock_session.query.return_value.where.return_value.all.return_value = [record] + mock_session.scalars.return_value.all.return_value = [record] # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if serialization fails - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): """Test that deletion is called for found records.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - # Should call delete for each table that has records - assert mock_session.query.return_value.where.return_value.delete.called + # Should call execute(delete(...)) for each table that has records + assert mock_session.execute.called def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes( self, mock_session, sample_message_ids @@ -167,12 +166,12 @@ class TestClearFreePlanTenantExpiredLogs: record.to_dict.side_effect = Exception("Serialization error") with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = [record] + mock_session.scalars.return_value.all.return_value = [record] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) mock_storage.save.assert_not_called() - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called class _ImmediateFuture: @@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"}) log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"}) - def make_query_with_batches(batches: list[list[object]]): - q = MagicMock() - q.where.return_value = q - q.limit.return_value = q - q.all.side_effect = batches - q.delete.return_value = 1 - return q - msg_session_1 = MagicMock() - msg_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() - ) + msg_session_1.scalars.return_value.all.return_value = [msg1] + msg_session_2 = MagicMock() - msg_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.Message else MagicMock() - ) + msg_session_2.scalars.return_value.all.return_value = [] conv_session_1 = MagicMock() - conv_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() - ) + conv_session_1.scalars.return_value.all.return_value = [conv1] conv_session_2 = MagicMock() - conv_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() - ) + conv_session_2.scalars.return_value.all.return_value = [] wal_session_1 = MagicMock() - wal_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() - ) + wal_session_1.scalars.return_value.all.return_value = [log1] wal_session_2 = MagicMock() - wal_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() - ) + wal_session_2.scalars.return_value.all.return_value = [] session_wrappers = [ _sessionmaker_wrapper_for_begin(msg_session_1), @@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py # Total tenant count query count_session = MagicMock() - count_query = MagicMock() - count_query.count.return_value = 2 - count_session.query.return_value = count_query + count_session.scalar.return_value = 2 monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session)) @@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt # Sessions used: # 1) total tenant count - # 2) per-batch tenant scan (count + tenant list) + # 2) per-batch tenant scan (interval counts + tenant list) total_session = MagicMock() - total_query = MagicMock() - total_query.count.return_value = 250 - total_session.query.return_value = total_query - - batch_session = MagicMock() - q1 = MagicMock() - q1.where.return_value = q1 - q1.count.return_value = 200 - q2 = MagicMock() - q2.where.return_value = q2 - q2.count.return_value = 200 - q3 = MagicMock() - q3.where.return_value = q3 - q3.count.return_value = 200 - q4 = MagicMock() - q4.where.return_value = q4 - q4.count.return_value = 50 # choose this interval, then scale it + total_session.scalar.return_value = 250 rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")] - q_rs = MagicMock() - q_rs.where.return_value = q_rs - q_rs.order_by.return_value = rows - - batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] + batch_session = MagicMock() + # 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h) + batch_session.scalar.side_effect = [200, 200, 200, 50] + batch_session.execute.return_value = rows sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)] monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0)) @@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) count_session = MagicMock() - count_query = MagicMock() - count_query.count.return_value = 100 - count_session.query.return_value = count_query + count_session.scalar.return_value = 100 monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session)) flask_app = service_module.Flask("test-app") @@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) total_session = MagicMock() - total_query = MagicMock() - total_query.count.return_value = 250 - total_session.query.return_value = total_query - - batch_session = MagicMock() - # Count results for all 5 intervals, all > 100 => take the for-else path. - count_queries = [] - for _ in range(5): - q = MagicMock() - q.where.return_value = q - q.count.return_value = 200 - count_queries.append(q) + total_session.scalar.return_value = 250 rows = [SimpleNamespace(id="tenant-a")] - q_rs = MagicMock() - q_rs.where.return_value = q_rs - q_rs.order_by.return_value = rows - - batch_session.query.side_effect = [*count_queries, q_rs] + batch_session = MagicMock() + # All 5 intervals have > 100 tenants => for-else falls through to min interval (1h) + batch_session.scalar.side_effect = [200, 200, 200, 200, 200] + batch_session.execute.return_value = rows sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)] monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0)) @@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) assert process_tenant_mock.call_count == 1 - assert len(count_queries) == 5 - assert batch_session.query.call_count >= 6 + assert batch_session.scalar.call_count == 5 def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None: @@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte # Make message/conversation/workflow_app_log loops no-op (empty immediately) empty_session = MagicMock() - q_empty = MagicMock() - q_empty.where.return_value = q_empty - q_empty.limit.return_value = q_empty - q_empty.all.return_value = [] - empty_session.query.return_value = q_empty + empty_session.scalars.return_value.all.return_value = [] session_wrappers = [ _sessionmaker_wrapper_for_begin(empty_session), _sessionmaker_wrapper_for_begin(empty_session), diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index b2c40763ea..c65ce24b3c 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -577,7 +577,7 @@ class TestDatasetServiceCreationAndUpdate: def test_update_external_knowledge_binding_updates_changed_binding_values(self): binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") session = MagicMock() - session.query.return_value.filter_by.return_value.first.return_value = binding + session.scalar.return_value = binding session.add = MagicMock() session_context = _make_session_context(session) @@ -596,7 +596,7 @@ class TestDatasetServiceCreationAndUpdate: def test_update_external_knowledge_binding_raises_for_missing_binding(self): session = MagicMock() - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None session_context = _make_session_context(session) mock_sessionmaker = MagicMock() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index 9b4734b7ad..3f9386e704 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers: def test_update_documents_need_summary_updates_matching_documents_and_commits(self): session = MagicMock() - session.query.return_value.filter.return_value.update.return_value = 2 + session.execute.return_value.rowcount = 2 with patch("services.dataset_service.session_factory") as session_factory_mock: session_factory_mock.create_session.return_value = _make_session_context(session)