refactor(services): migrate summary_index_service to SQLAlchemy 2.0 select() API (#34971)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wdeveloper16
2026-04-12 03:37:16 +02:00
committed by GitHub
parent 510120410b
commit 440602f52a
2 changed files with 171 additions and 295 deletions

View File

@@ -8,6 +8,7 @@ from typing import TypedDict, cast
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
@@ -109,8 +110,13 @@ class SummaryIndexService:
"""
with session_factory.create_session() as session:
# Check if summary record already exists
existing_summary = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
existing_summary = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if existing_summary:
@@ -309,8 +315,10 @@ class SummaryIndexService:
summary_record_id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
@@ -323,10 +331,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@@ -487,8 +498,10 @@ class SummaryIndexService:
with session_factory.create_session() as error_session:
# Try to find the record by id first
# Note: Using assignment only (no type annotation) to avoid redeclaration error
summary_record_in_session = (
error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
# Try to find by chunk_id and dataset_id
@@ -500,10 +513,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
error_session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
@@ -551,14 +567,12 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query existing summary records
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset.id,
)
.all()
)
).all()
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
# Create or update records
@@ -603,8 +617,13 @@ class SummaryIndexService:
error: Error message
"""
with session_factory.create_session() as session:
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@@ -639,8 +658,13 @@ class SummaryIndexService:
with session_factory.create_session() as session:
try:
# Get or refresh summary record in this session
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@@ -710,8 +734,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to generate summary for segment %s", segment.id)
# Update summary record with error status
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
summary_record_in_session.status = SummaryStatus.ERROR
@@ -769,17 +798,17 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query segments (only enabled segments)
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True, # Only generate summaries for enabled segments
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments
)
if segment_ids:
query = query.filter(DocumentSegment.id.in_(segment_ids))
stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
segments = query.all()
segments = list(session.scalars(stmt).all())
if not segments:
logger.info("No segments found for document %s", document.id)
@@ -848,15 +877,15 @@ class SummaryIndexService:
from libs.datetime_utils import naive_utc_now
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=True, # Only disable enabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@@ -911,15 +940,15 @@ class SummaryIndexService:
return
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=False, # Only enable disabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@@ -935,13 +964,13 @@ class SummaryIndexService:
enabled_count = 0
for summary in summaries:
# Get the original segment
segment = (
session.query(DocumentSegment)
.filter_by(
id=summary.chunk_id,
dataset_id=dataset.id,
segment = session.scalar(
select(DocumentSegment)
.where(
DocumentSegment.id == summary.chunk_id,
DocumentSegment.dataset_id == dataset.id,
)
.first()
.limit(1)
)
# Summary.enabled stays in sync with chunk.enabled,
@@ -988,12 +1017,12 @@ class SummaryIndexService:
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@@ -1046,10 +1075,13 @@ class SummaryIndexService:
# Check if summary_content is empty (whitespace-only strings are considered empty)
if not summary_content or not summary_content.strip():
# If summary is empty, only delete existing summary vector and record
summary_record = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@@ -1077,8 +1109,13 @@ class SummaryIndexService:
return None
# Find existing summary record
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@@ -1162,8 +1199,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to update summary for segment %s", segment.id)
# Update summary record with error status if it exists
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
summary_record.status = SummaryStatus.ERROR
@@ -1185,14 +1227,14 @@ class SummaryIndexService:
DocumentSegmentSummary instance if found, None otherwise
"""
with session_factory.create_session() as session:
return (
session.query(DocumentSegmentSummary)
return session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.first()
.limit(1)
)
@staticmethod
@@ -1211,15 +1253,13 @@ class SummaryIndexService:
return {}
with session_factory.create_session() as session:
summary_records = (
session.query(DocumentSegmentSummary)
.where(
summary_records = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.all()
)
).all()
return {summary.chunk_id: summary for summary in summary_records}
@@ -1239,16 +1279,16 @@ class SummaryIndexService:
List of DocumentSegmentSummary instances (only enabled summaries)
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter(
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
return query.all()
return list(session.scalars(stmt).all())
@staticmethod
def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
@@ -1265,16 +1305,15 @@ class SummaryIndexService:
"""
# Get all segments for this document (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
segment_ids = list(
session.scalars(
select(DocumentSegment.id).where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
).all()
)
segment_ids = [seg.id for seg in segments]
if not segment_ids:
return None
@@ -1312,15 +1351,13 @@ class SummaryIndexService:
# Get all segments for these documents (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
segments = session.execute(
select(DocumentSegment.id, DocumentSegment.document_id).where(
DocumentSegment.document_id.in_(document_ids),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
)
).all()
# Group segments by document_id
document_segments_map: dict[str, list[str]] = {}