From 440602f52a92855671cfcbf7b35011f1e3b6eaa2 Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Sun, 12 Apr 2026 03:37:16 +0200 Subject: [PATCH] refactor(services): migrate summary_index_service to SQLAlchemy 2.0 select() API (#34971) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/summary_index_service.py | 211 +++++++++------ .../services/test_summary_index_service.py | 255 ++++-------------- 2 files changed, 171 insertions(+), 295 deletions(-) diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 8760d60de0..c906e3bca3 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -8,6 +8,7 @@ from typing import TypedDict, cast from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import select from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -109,8 +110,13 @@ class SummaryIndexService: """ with session_factory.create_session() as session: # Check if summary record already exists - existing_summary = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + existing_summary = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if existing_summary: @@ -309,8 +315,10 @@ class SummaryIndexService: summary_record_id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: @@ -323,10 +331,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -487,8 +498,10 @@ class SummaryIndexService: with session_factory.create_session() as error_session: # Try to find the record by id first # Note: Using assignment only (no type annotation) to avoid redeclaration error - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: # Try to find by chunk_id and dataset_id @@ -500,10 +513,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: @@ -551,14 +567,12 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query existing summary records - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset.id, ) - .all() - ) + ).all() existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} # Create or update records @@ -603,8 +617,13 @@ class SummaryIndexService: error: Error message """ with session_factory.create_session() as session: - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -639,8 +658,13 @@ class SummaryIndexService: with session_factory.create_session() as session: try: # Get or refresh summary record in this session - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -710,8 +734,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to generate summary for segment %s", segment.id) # Update summary record with error status - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: summary_record_in_session.status = SummaryStatus.ERROR @@ -769,17 +798,17 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query segments (only enabled segments) - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, # Only generate summaries for enabled segments + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments ) if segment_ids: - query = query.filter(DocumentSegment.id.in_(segment_ids)) + stmt = stmt.where(DocumentSegment.id.in_(segment_ids)) - segments = query.all() + segments = list(session.scalars(stmt).all()) if not segments: logger.info("No segments found for document %s", document.id) @@ -848,15 +877,15 @@ class SummaryIndexService: from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=True, # Only disable enabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -911,15 +940,15 @@ class SummaryIndexService: return with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=False, # Only enable disabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -935,13 +964,13 @@ class SummaryIndexService: enabled_count = 0 for summary in summaries: # Get the original segment - segment = ( - session.query(DocumentSegment) - .filter_by( - id=summary.chunk_id, - dataset_id=dataset.id, + segment = session.scalar( + select(DocumentSegment) + .where( + DocumentSegment.id == summary.chunk_id, + DocumentSegment.dataset_id == dataset.id, ) - .first() + .limit(1) ) # Summary.enabled stays in sync with chunk.enabled, @@ -988,12 +1017,12 @@ class SummaryIndexService: segment_ids: List of segment IDs to delete summaries for. If None, delete all. """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -1046,10 +1075,13 @@ class SummaryIndexService: # Check if summary_content is empty (whitespace-only strings are considered empty) if not summary_content or not summary_content.strip(): # If summary is empty, only delete existing summary vector and record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1077,8 +1109,13 @@ class SummaryIndexService: return None # Find existing summary record - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1162,8 +1199,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to update summary for segment %s", segment.id) # Update summary record with error status if it exists - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: summary_record.status = SummaryStatus.ERROR @@ -1185,14 +1227,14 @@ class SummaryIndexService: DocumentSegmentSummary instance if found, None otherwise """ with session_factory.create_session() as session: - return ( - session.query(DocumentSegmentSummary) + return session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .first() + .limit(1) ) @staticmethod @@ -1211,15 +1253,13 @@ class SummaryIndexService: return {} with session_factory.create_session() as session: - summary_records = ( - session.query(DocumentSegmentSummary) - .where( + summary_records = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .all() - ) + ).all() return {summary.chunk_id: summary for summary in summary_records} @@ -1239,16 +1279,16 @@ class SummaryIndexService: List of DocumentSegmentSummary instances (only enabled summaries) """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter( + stmt = select(DocumentSegmentSummary).where( DocumentSegmentSummary.document_id == document_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - return query.all() + return list(session.scalars(stmt).all()) @staticmethod def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: @@ -1265,16 +1305,15 @@ class SummaryIndexService: """ # Get all segments for this document (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == tenant_id, - ) - .all() + segment_ids = list( + session.scalars( + select(DocumentSegment.id).where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + ).all() ) - segment_ids = [seg.id for seg in segments] if not segment_ids: return None @@ -1312,15 +1351,13 @@ class SummaryIndexService: # Get all segments for these documents (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id, DocumentSegment.document_id) - .where( + segments = session.execute( + select(DocumentSegment.id, DocumentSegment.document_id).where( DocumentSegment.document_id.in_(document_ids), DocumentSegment.status != "re_segment", DocumentSegment.tenant_id == tenant_id, ) - .all() - ) + ).all() # Group segments by document_id document_segments_map: dict[str, list[str]] = {} diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index cbf3e121d8..e17d4134ac 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes existing.disabled_by = "u" session = MagicMock(name="session") - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = existing - session.query.return_value = query + session.scalar.return_value = existing create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock(name="session") - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat # New session used after vectorization succeeds (record not found by id nor chunk_id). session = MagicMock(name="session") - q1 = MagicMock() - q1.filter_by.return_value = q1 - q1.first.side_effect = [None, None] - session.query.return_value = q1 + session.scalar.side_effect = [None, None] create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes # error_session should find record and commit status update error_session = MagicMock(name="error_session") - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = summary - error_session.query.return_value = q + error_session.scalar.return_value = summary create_session_mock = MagicMock(return_value=_SessionContext(error_session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo existing.enabled = False session = MagicMock() - query = MagicMock() - query.filter.return_value = query - query.all.return_value = [existing] - session.query.return_value = query + session.scalars.return_value.all.return_value = [existing] monkeypatch.setattr( summary_module, @@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon record = _summary_record() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) record = _summary_record(summary_content="") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, @@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch record = _summary_record(summary_content="") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, @@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch existing = _summary_record(summary_content="old", node_id="old-node") existing.id = "other-id" session = MagicMock(name="session") - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, existing] # miss by id, hit by chunk_id - session.query.return_value = q + session.scalar.side_effect = [None, existing] # miss by id, hit by chunk_id monkeypatch.setattr( summary_module, "session_factory", @@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte existing = _summary_record(summary_content="old", node_id="old-node") session = MagicMock(name="session") - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = existing # hit by id - session.query.return_value = q + session.scalar.return_value = existing # hit by id monkeypatch.setattr( summary_module, "session_factory", @@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon return None error_session = MagicMock() - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = summary - error_session.query.return_value = q + error_session.scalar.return_value = summary create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc ) session = MagicMock() - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, None] # miss by id and chunk_id - session.query.return_value = q + session.scalar.side_effect = [None, None] # miss by id and chunk_id error_session = MagicMock() - eq = MagicMock() - eq.filter_by.return_value = eq - eq.first.return_value = summary - error_session.query.return_value = eq + error_session.scalar.return_value = summary create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) # Force the created record to be None so the "should not be None" guard triggers. + # Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause. + monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock())) monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): @@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_ ) error_session = MagicMock(name="error_session") - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, None] # not found by id, not found by chunk_id - error_session.query.return_value = q + error_session.scalar.side_effect = [None, None] # not found by id, not found by chunk_id monkeypatch.setattr( summary_module, @@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py seg2.id = "seg-2" session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [seg1, seg2] - session.query.return_value = query + session.scalars.return_value.all.return_value = [seg1, seg2] monkeypatch.setattr( summary_module, @@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu seg = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [seg] - session.query.return_value = query + session.scalars.return_value.all.return_value = [seg] monkeypatch.setattr( summary_module, "session_factory", @@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu segment_ids=[seg.id], only_parent_chunks=True, ) - query.filter.assert_called() + session.scalars.assert_called() def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: @@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: summary2 = _summary_record(summary_content="s", node_id=None) session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [summary1, summary2] - session.query.return_value = query + session.scalars.return_value.all.return_value = [summary1, summary2] monkeypatch.setattr( summary_module, @@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt segment.status = SegmentStatus.COMPLETED session = MagicMock() - summary_query = MagicMock() - summary_query.filter_by.return_value = summary_query - summary_query.filter.return_value = summary_query - summary_query.all.return_value = [summary] - - seg_query = MagicMock() - seg_query.filter_by.return_value = seg_query - seg_query.first.return_value = segment - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - return summary_query - return seg_query - - session.query.side_effect = query_side_effect + session.scalars.return_value.all.return_value = [summary] + session.scalar.return_value = segment monkeypatch.setattr( summary_module, @@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect good_segment.status = SegmentStatus.COMPLETED session = MagicMock() - summary_query = MagicMock() - summary_query.filter_by.return_value = summary_query - summary_query.filter.return_value = summary_query - summary_query.all.return_value = [summary1, summary2, summary3] + session.scalars.return_value.all.return_value = [summary1, summary2, summary3] + session.scalar.side_effect = [bad_segment, good_segment, good_segment] - seg_query = MagicMock() - seg_query.filter_by.return_value = seg_query - seg_query.first.side_effect = [bad_segment, good_segment, good_segment] - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - return summary_query - return seg_query - - session.query.side_effect = query_side_effect monkeypatch.setattr( summary_module, "session_factory", @@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: summary = _summary_record(summary_content="sum", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [summary] - session.query.return_value = query + session.scalars.return_value.all.return_value = [summary] vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record session.flush.side_effect = RuntimeError("flush boom") monkeypatch.setattr( summary_module, @@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: record = _summary_record(summary_content="sum", node_id="n1") session = MagicMock() + session.scalar.return_value = record + session.scalars.return_value.all.return_value = [record] - q1 = MagicMock() - q1.where.return_value = q1 - q1.first.return_value = record - - q2 = MagicMock() - q2.filter.return_value = q2 - q2.all.return_value = [record] - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - # first call used by get_segment_summary, second by get_document_summaries - if not hasattr(query_side_effect, "_called"): - query_side_effect._called = True # type: ignore[attr-defined] - return q1 - return q2 - return MagicMock() - - session.query.side_effect = query_side_effect monkeypatch.setattr( summary_module, "session_factory", @@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No record2 = _summary_record() record2.chunk_id = "seg-2" session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [record1, record2] - session.query.return_value = q + session.scalars.return_value.all.return_value = [record1, record2] monkeypatch.setattr( summary_module, "session_factory", @@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [] - session.query.return_value = q + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] - session.query.return_value = q + session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] monkeypatch.setattr( summary_module, "session_factory", @@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, @@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None: def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.all.return_value = [SimpleNamespace(id="seg-1")] - session.query.return_value = query + session.scalars.return_value.all.return_value = ["seg-1"] # get_document_summary_index_status returns IDs create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" # Multiple docs - query2 = MagicMock() - query2.where.return_value = query2 - query2.all.return_value = [seg_row] session2 = MagicMock() - session2.query.return_value = query2 + session2.execute.return_value.all.return_value = [seg_row] # get_documents_summary_index_status uses execute monkeypatch.setattr( summary_module, "session_factory",