refactor(services): migrate dataset_service and clear_free_plan_tenant_expired_logs to SQLAlchemy 2.0 select() API (#34970)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wdeveloper16
2026-04-12 03:21:52 +02:00
committed by GitHub
parent 452067db19
commit 7515eee0a8
5 changed files with 90 additions and 140 deletions

View File

@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:
for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.where(
records = session.scalars(
select(model).where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
).all()
if len(records) == 0:
continue
@@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
session.execute(
delete(model)
.where(
model.id.in_(record_ids), # type: ignore
)
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
app_ids = [app.id for app in apps]
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
messages = (
session.query(Message)
messages = session.scalars(
select(Message)
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(messages) == 0:
break
@@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]
# delete messages
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
session.execute(
delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
)
cls._clear_message_related_tables(session, tenant_id, message_ids)
@@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
conversations = (
session.query(Conversation)
conversations = session.scalars(
select(Conversation)
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(conversations) == 0:
break
@@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
)
conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.execute(
delete(Conversation)
.where(Conversation.id.in_(conversation_ids))
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
workflow_app_logs = session.scalars(
select(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(workflow_app_logs) == 0:
break
@@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id.in_(workflow_app_log_ids))
.execution_options(synchronize_session=False)
)
click.echo(
@@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
current_time = started_at
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
session.scalar(
select(func.count(Tenant.id)).where(
Tenant.created_at.between(current_time, current_time + test_interval)
)
)
or 0
)
if tenant_count <= 100:
interval = test_interval
@@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
rs = session.execute(
select(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)

View File

@@ -552,8 +552,8 @@ class DatasetService:
external_knowledge_api_id: External knowledge API identifier
"""
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
external_knowledge_binding = session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
)
if not external_knowledge_binding:
@@ -1454,15 +1454,17 @@ class DocumentService:
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
result = session.execute(
update(Document)
.where(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
.values(need_summary=need_summary)
.execution_options(synchronize_session=False)
)
updated_count = result.rowcount # type: ignore[union-attr,attr-defined]
session.commit()
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",

View File

@@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs:
def mock_session(self):
"""Create a mock database session."""
session = Mock(spec=Session)
session.query.return_value.filter.return_value.all.return_value = []
session.query.return_value.filter.return_value.delete.return_value = 0
session.scalars.return_value.all.return_value = []
return session
@pytest.fixture
@@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs:
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
# Should not call any database operations
mock_session.query.assert_not_called()
mock_session.scalars.assert_not_called()
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
"""Test when no related records are found."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call query for each related table but find no records
assert mock_session.query.call_count > 0
# Should call scalars for each related table but find no records
assert mock_session.scalars.call_count > 0
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_with_records_and_to_dict(
@@ -73,7 +72,7 @@ class TestClearFreePlanTenantExpiredLogs:
):
"""Test when records are found and have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
@@ -104,7 +103,7 @@ class TestClearFreePlanTenantExpiredLogs:
records.append(record)
# Mock records for first table only, empty for others
mock_session.query.return_value.where.return_value.all.side_effect = [
mock_session.scalars.return_value.all.side_effect = [
records,
[],
[],
@@ -126,13 +125,13 @@ class TestClearFreePlanTenantExpiredLogs:
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_storage.save.side_effect = Exception("Storage error")
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if backup fails
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
"""Test that method continues even when record serialization fails."""
@@ -141,23 +140,23 @@ class TestClearFreePlanTenantExpiredLogs:
record.id = "record-1"
record.to_dict.side_effect = Exception("Serialization error")
mock_session.query.return_value.where.return_value.all.return_value = [record]
mock_session.scalars.return_value.all.return_value = [record]
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if serialization fails
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
"""Test that deletion is called for found records."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call delete for each table that has records
assert mock_session.query.return_value.where.return_value.delete.called
# Should call execute(delete(...)) for each table that has records
assert mock_session.execute.called
def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes(
self, mock_session, sample_message_ids
@@ -167,12 +166,12 @@ class TestClearFreePlanTenantExpiredLogs:
record.to_dict.side_effect = Exception("Serialization error")
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = [record]
mock_session.scalars.return_value.all.return_value = [record]
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
mock_storage.save.assert_not_called()
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
class _ImmediateFuture:
@@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"})
log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"})
def make_query_with_batches(batches: list[list[object]]):
q = MagicMock()
q.where.return_value = q
q.limit.return_value = q
q.all.side_effect = batches
q.delete.return_value = 1
return q
msg_session_1 = MagicMock()
msg_session_1.query.side_effect = lambda model: (
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
)
msg_session_1.scalars.return_value.all.return_value = [msg1]
msg_session_2 = MagicMock()
msg_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
)
msg_session_2.scalars.return_value.all.return_value = []
conv_session_1 = MagicMock()
conv_session_1.query.side_effect = lambda model: (
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
)
conv_session_1.scalars.return_value.all.return_value = [conv1]
conv_session_2 = MagicMock()
conv_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
)
conv_session_2.scalars.return_value.all.return_value = []
wal_session_1 = MagicMock()
wal_session_1.query.side_effect = lambda model: (
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_1.scalars.return_value.all.return_value = [log1]
wal_session_2 = MagicMock()
wal_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_2.scalars.return_value.all.return_value = []
session_wrappers = [
_sessionmaker_wrapper_for_begin(msg_session_1),
@@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
# Total tenant count query
count_session = MagicMock()
count_query = MagicMock()
count_query.count.return_value = 2
count_session.query.return_value = count_query
count_session.scalar.return_value = 2
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
@@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
# Sessions used:
# 1) total tenant count
# 2) per-batch tenant scan (count + tenant list)
# 2) per-batch tenant scan (interval counts + tenant list)
total_session = MagicMock()
total_query = MagicMock()
total_query.count.return_value = 250
total_session.query.return_value = total_query
batch_session = MagicMock()
q1 = MagicMock()
q1.where.return_value = q1
q1.count.return_value = 200
q2 = MagicMock()
q2.where.return_value = q2
q2.count.return_value = 200
q3 = MagicMock()
q3.where.return_value = q3
q3.count.return_value = 200
q4 = MagicMock()
q4.where.return_value = q4
q4.count.return_value = 50 # choose this interval, then scale it
total_session.scalar.return_value = 250
rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")]
q_rs = MagicMock()
q_rs.where.return_value = q_rs
q_rs.order_by.return_value = rows
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
batch_session = MagicMock()
# 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h)
batch_session.scalar.side_effect = [200, 200, 200, 50]
batch_session.execute.return_value = rows
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
count_session = MagicMock()
count_query = MagicMock()
count_query.count.return_value = 100
count_session.query.return_value = count_query
count_session.scalar.return_value = 100
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
flask_app = service_module.Flask("test-app")
@@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
total_session = MagicMock()
total_query = MagicMock()
total_query.count.return_value = 250
total_session.query.return_value = total_query
batch_session = MagicMock()
# Count results for all 5 intervals, all > 100 => take the for-else path.
count_queries = []
for _ in range(5):
q = MagicMock()
q.where.return_value = q
q.count.return_value = 200
count_queries.append(q)
total_session.scalar.return_value = 250
rows = [SimpleNamespace(id="tenant-a")]
q_rs = MagicMock()
q_rs.where.return_value = q_rs
q_rs.order_by.return_value = rows
batch_session.query.side_effect = [*count_queries, q_rs]
batch_session = MagicMock()
# All 5 intervals have > 100 tenants => for-else falls through to min interval (1h)
batch_session.scalar.side_effect = [200, 200, 200, 200, 200]
batch_session.execute.return_value = rows
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])
assert process_tenant_mock.call_count == 1
assert len(count_queries) == 5
assert batch_session.query.call_count >= 6
assert batch_session.scalar.call_count == 5
def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
# Make message/conversation/workflow_app_log loops no-op (empty immediately)
empty_session = MagicMock()
q_empty = MagicMock()
q_empty.where.return_value = q_empty
q_empty.limit.return_value = q_empty
q_empty.all.return_value = []
empty_session.query.return_value = q_empty
empty_session.scalars.return_value.all.return_value = []
session_wrappers = [
_sessionmaker_wrapper_for_begin(empty_session),
_sessionmaker_wrapper_for_begin(empty_session),

View File

@@ -577,7 +577,7 @@ class TestDatasetServiceCreationAndUpdate:
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = binding
session.scalar.return_value = binding
session.add = MagicMock()
session_context = _make_session_context(session)
@@ -596,7 +596,7 @@ class TestDatasetServiceCreationAndUpdate:
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()

View File

@@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:
def test_update_documents_need_summary_updates_matching_documents_and_commits(self):
session = MagicMock()
session.query.return_value.filter.return_value.update.return_value = 2
session.execute.return_value.rowcount = 2
with patch("services.dataset_service.session_factory") as session_factory_mock:
session_factory_mock.create_session.return_value = _make_session_context(session)