mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 23:40:16 +08:00
refactor(services): migrate dataset_service and clear_free_plan_tenant_expired_logs to SQLAlchemy 2.0 select() API (#34970)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import click
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
for model, table_name in related_tables:
|
||||
# Query records related to expired messages
|
||||
records = (
|
||||
session.query(model)
|
||||
.where(
|
||||
records = session.scalars(
|
||||
select(model).where(
|
||||
model.message_id.in_(batch_message_ids), # type: ignore
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if len(records) == 0:
|
||||
continue
|
||||
@@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
except Exception:
|
||||
logger.exception("Failed to save %s records", table_name)
|
||||
|
||||
session.query(model).where(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(model)
|
||||
.where(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
app_ids = [app.id for app in apps]
|
||||
while True:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
messages = (
|
||||
session.query(Message)
|
||||
messages = session.scalars(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.app_id.in_(app_ids),
|
||||
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if len(messages) == 0:
|
||||
break
|
||||
|
||||
@@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
message_ids = [message.id for message in messages]
|
||||
|
||||
# delete messages
|
||||
session.query(Message).where(
|
||||
Message.id.in_(message_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
cls._clear_message_related_tables(session, tenant_id, message_ids)
|
||||
|
||||
@@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
while True:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
conversations = (
|
||||
session.query(Conversation)
|
||||
conversations = session.scalars(
|
||||
select(Conversation)
|
||||
.where(
|
||||
Conversation.app_id.in_(app_ids),
|
||||
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if len(conversations) == 0:
|
||||
break
|
||||
@@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
)
|
||||
|
||||
conversation_ids = [conversation.id for conversation in conversations]
|
||||
session.query(Conversation).where(
|
||||
Conversation.id.in_(conversation_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(Conversation)
|
||||
.where(Conversation.id.in_(conversation_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
while True:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
workflow_app_logs = session.scalars(
|
||||
select(WorkflowAppLog)
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if len(workflow_app_logs) == 0:
|
||||
break
|
||||
@@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
|
||||
|
||||
# delete workflow app logs
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(WorkflowAppLog)
|
||||
.where(WorkflowAppLog.id.in_(workflow_app_log_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
click.echo(
|
||||
@@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
current_time = started_at
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
@@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
session.scalar(
|
||||
select(func.count(Tenant.id)).where(
|
||||
Tenant.created_at.between(current_time, current_time + test_interval)
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
@@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
rs = session.execute(
|
||||
select(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
|
||||
@@ -552,8 +552,8 @@ class DatasetService:
|
||||
external_knowledge_api_id: External knowledge API identifier
|
||||
"""
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
external_knowledge_binding = (
|
||||
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
||||
external_knowledge_binding = session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not external_knowledge_binding:
|
||||
@@ -1454,15 +1454,17 @@ class DocumentService:
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
updated_count = (
|
||||
session.query(Document)
|
||||
.filter(
|
||||
result = session.execute(
|
||||
update(Document)
|
||||
.where(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
.values(need_summary=need_summary)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
updated_count = result.rowcount # type: ignore[union-attr,attr-defined]
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Updated need_summary to %s for %d documents in dataset %s",
|
||||
|
||||
@@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
def mock_session(self):
|
||||
"""Create a mock database session."""
|
||||
session = Mock(spec=Session)
|
||||
session.query.return_value.filter.return_value.all.return_value = []
|
||||
session.query.return_value.filter.return_value.delete.return_value = 0
|
||||
session.scalars.return_value.all.return_value = []
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
@@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
|
||||
|
||||
# Should not call any database operations
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.scalars.assert_not_called()
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
|
||||
"""Test when no related records are found."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call query for each related table but find no records
|
||||
assert mock_session.query.call_count > 0
|
||||
# Should call scalars for each related table but find no records
|
||||
assert mock_session.scalars.call_count > 0
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_with_records_and_to_dict(
|
||||
@@ -73,7 +72,7 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
):
|
||||
"""Test when records are found and have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
mock_session.scalars.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
@@ -104,7 +103,7 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
records.append(record)
|
||||
|
||||
# Mock records for first table only, empty for others
|
||||
mock_session.query.return_value.where.return_value.all.side_effect = [
|
||||
mock_session.scalars.return_value.all.side_effect = [
|
||||
records,
|
||||
[],
|
||||
[],
|
||||
@@ -126,13 +125,13 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_storage.save.side_effect = Exception("Storage error")
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
mock_session.scalars.return_value.all.return_value = sample_records
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if backup fails
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
|
||||
"""Test that method continues even when record serialization fails."""
|
||||
@@ -141,23 +140,23 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
record.id = "record-1"
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [record]
|
||||
mock_session.scalars.return_value.all.return_value = [record]
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if serialization fails
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
|
||||
"""Test that deletion is called for found records."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
mock_session.scalars.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call delete for each table that has records
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
# Should call execute(delete(...)) for each table that has records
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes(
|
||||
self, mock_session, sample_message_ids
|
||||
@@ -167,12 +166,12 @@ class TestClearFreePlanTenantExpiredLogs:
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [record]
|
||||
mock_session.scalars.return_value.all.return_value = [record]
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
mock_storage.save.assert_not_called()
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
assert mock_session.execute.called
|
||||
|
||||
|
||||
class _ImmediateFuture:
|
||||
@@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
|
||||
conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"})
|
||||
log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"})
|
||||
|
||||
def make_query_with_batches(batches: list[list[object]]):
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.limit.return_value = q
|
||||
q.all.side_effect = batches
|
||||
q.delete.return_value = 1
|
||||
return q
|
||||
|
||||
msg_session_1 = MagicMock()
|
||||
msg_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_1.scalars.return_value.all.return_value = [msg1]
|
||||
|
||||
msg_session_2 = MagicMock()
|
||||
msg_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_2.scalars.return_value.all.return_value = []
|
||||
|
||||
conv_session_1 = MagicMock()
|
||||
conv_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_1.scalars.return_value.all.return_value = [conv1]
|
||||
|
||||
conv_session_2 = MagicMock()
|
||||
conv_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_2.scalars.return_value.all.return_value = []
|
||||
|
||||
wal_session_1 = MagicMock()
|
||||
wal_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_1.scalars.return_value.all.return_value = [log1]
|
||||
|
||||
wal_session_2 = MagicMock()
|
||||
wal_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_2.scalars.return_value.all.return_value = []
|
||||
|
||||
session_wrappers = [
|
||||
_sessionmaker_wrapper_for_begin(msg_session_1),
|
||||
@@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
|
||||
|
||||
# Total tenant count query
|
||||
count_session = MagicMock()
|
||||
count_query = MagicMock()
|
||||
count_query.count.return_value = 2
|
||||
count_session.query.return_value = count_query
|
||||
count_session.scalar.return_value = 2
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||
|
||||
@@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
|
||||
|
||||
# Sessions used:
|
||||
# 1) total tenant count
|
||||
# 2) per-batch tenant scan (count + tenant list)
|
||||
# 2) per-batch tenant scan (interval counts + tenant list)
|
||||
total_session = MagicMock()
|
||||
total_query = MagicMock()
|
||||
total_query.count.return_value = 250
|
||||
total_session.query.return_value = total_query
|
||||
|
||||
batch_session = MagicMock()
|
||||
q1 = MagicMock()
|
||||
q1.where.return_value = q1
|
||||
q1.count.return_value = 200
|
||||
q2 = MagicMock()
|
||||
q2.where.return_value = q2
|
||||
q2.count.return_value = 200
|
||||
q3 = MagicMock()
|
||||
q3.where.return_value = q3
|
||||
q3.count.return_value = 200
|
||||
q4 = MagicMock()
|
||||
q4.where.return_value = q4
|
||||
q4.count.return_value = 50 # choose this interval, then scale it
|
||||
total_session.scalar.return_value = 250
|
||||
|
||||
rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")]
|
||||
q_rs = MagicMock()
|
||||
q_rs.where.return_value = q_rs
|
||||
q_rs.order_by.return_value = rows
|
||||
|
||||
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
|
||||
batch_session = MagicMock()
|
||||
# 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h)
|
||||
batch_session.scalar.side_effect = [200, 200, 200, 50]
|
||||
batch_session.execute.return_value = rows
|
||||
|
||||
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||
@@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
count_session = MagicMock()
|
||||
count_query = MagicMock()
|
||||
count_query.count.return_value = 100
|
||||
count_session.query.return_value = count_query
|
||||
count_session.scalar.return_value = 100
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||
|
||||
flask_app = service_module.Flask("test-app")
|
||||
@@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
|
||||
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
|
||||
|
||||
total_session = MagicMock()
|
||||
total_query = MagicMock()
|
||||
total_query.count.return_value = 250
|
||||
total_session.query.return_value = total_query
|
||||
|
||||
batch_session = MagicMock()
|
||||
# Count results for all 5 intervals, all > 100 => take the for-else path.
|
||||
count_queries = []
|
||||
for _ in range(5):
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.count.return_value = 200
|
||||
count_queries.append(q)
|
||||
total_session.scalar.return_value = 250
|
||||
|
||||
rows = [SimpleNamespace(id="tenant-a")]
|
||||
q_rs = MagicMock()
|
||||
q_rs.where.return_value = q_rs
|
||||
q_rs.order_by.return_value = rows
|
||||
|
||||
batch_session.query.side_effect = [*count_queries, q_rs]
|
||||
batch_session = MagicMock()
|
||||
# All 5 intervals have > 100 tenants => for-else falls through to min interval (1h)
|
||||
batch_session.scalar.side_effect = [200, 200, 200, 200, 200]
|
||||
batch_session.execute.return_value = rows
|
||||
|
||||
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||
@@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
|
||||
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])
|
||||
|
||||
assert process_tenant_mock.call_count == 1
|
||||
assert len(count_queries) == 5
|
||||
assert batch_session.query.call_count >= 6
|
||||
assert batch_session.scalar.call_count == 5
|
||||
|
||||
|
||||
def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
|
||||
|
||||
# Make message/conversation/workflow_app_log loops no-op (empty immediately)
|
||||
empty_session = MagicMock()
|
||||
q_empty = MagicMock()
|
||||
q_empty.where.return_value = q_empty
|
||||
q_empty.limit.return_value = q_empty
|
||||
q_empty.all.return_value = []
|
||||
empty_session.query.return_value = q_empty
|
||||
empty_session.scalars.return_value.all.return_value = []
|
||||
session_wrappers = [
|
||||
_sessionmaker_wrapper_for_begin(empty_session),
|
||||
_sessionmaker_wrapper_for_begin(empty_session),
|
||||
|
||||
@@ -577,7 +577,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
|
||||
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
|
||||
session = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
session.scalar.return_value = binding
|
||||
session.add = MagicMock()
|
||||
session_context = _make_session_context(session)
|
||||
|
||||
@@ -596,7 +596,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
|
||||
session = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
session_context = _make_session_context(session)
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
|
||||
@@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:
|
||||
|
||||
def test_update_documents_need_summary_updates_matching_documents_and_commits(self):
|
||||
session = MagicMock()
|
||||
session.query.return_value.filter.return_value.update.return_value = 2
|
||||
session.execute.return_value.rowcount = 2
|
||||
|
||||
with patch("services.dataset_service.session_factory") as session_factory_mock:
|
||||
session_factory_mock.create_session.return_value = _make_session_context(session)
|
||||
|
||||
Reference in New Issue
Block a user