mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 15:20:15 +08:00
refactor(services): migrate summary_index_service to SQLAlchemy 2.0 select() API (#34971)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import TypedDict, cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@@ -109,8 +110,13 @@ class SummaryIndexService:
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
# Check if summary record already exists
|
||||
existing_summary = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
existing_summary = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if existing_summary:
|
||||
@@ -309,8 +315,10 @@ class SummaryIndexService:
|
||||
summary_record_id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(DocumentSegmentSummary.id == summary_record_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@@ -323,10 +331,13 @@ class SummaryIndexService:
|
||||
dataset.id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@@ -487,8 +498,10 @@ class SummaryIndexService:
|
||||
with session_factory.create_session() as error_session:
|
||||
# Try to find the record by id first
|
||||
# Note: Using assignment only (no type annotation) to avoid redeclaration error
|
||||
summary_record_in_session = (
|
||||
error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
|
||||
summary_record_in_session = error_session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(DocumentSegmentSummary.id == summary_record_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not summary_record_in_session:
|
||||
# Try to find by chunk_id and dataset_id
|
||||
@@ -500,10 +513,13 @@ class SummaryIndexService:
|
||||
dataset.id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
error_session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record_in_session = error_session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record_in_session:
|
||||
@@ -551,14 +567,12 @@ class SummaryIndexService:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
# Query existing summary records
|
||||
existing_summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
existing_summaries = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
|
||||
|
||||
# Create or update records
|
||||
@@ -603,8 +617,13 @@ class SummaryIndexService:
|
||||
error: Error message
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@@ -639,8 +658,13 @@ class SummaryIndexService:
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
# Get or refresh summary record in this session
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@@ -710,8 +734,13 @@ class SummaryIndexService:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate summary for segment %s", segment.id)
|
||||
# Update summary record with error status
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
@@ -769,17 +798,17 @@ class SummaryIndexService:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
# Query segments (only enabled segments)
|
||||
query = session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True, # Only generate summaries for enabled segments
|
||||
stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegment.id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
|
||||
|
||||
segments = query.all()
|
||||
segments = list(session.scalars(stmt).all())
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
@@ -848,15 +877,15 @@ class SummaryIndexService:
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=True, # Only disable enabled summaries
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@@ -911,15 +940,15 @@ class SummaryIndexService:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=False, # Only enable disabled summaries
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@@ -935,13 +964,13 @@ class SummaryIndexService:
|
||||
enabled_count = 0
|
||||
for summary in summaries:
|
||||
# Get the original segment
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.filter_by(
|
||||
id=summary.chunk_id,
|
||||
dataset_id=dataset.id,
|
||||
segment = session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.id == summary.chunk_id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Summary.enabled stays in sync with chunk.enabled,
|
||||
@@ -988,12 +1017,12 @@ class SummaryIndexService:
|
||||
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
|
||||
stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@@ -1046,10 +1075,13 @@ class SummaryIndexService:
|
||||
# Check if summary_content is empty (whitespace-only strings are considered empty)
|
||||
if not summary_content or not summary_content.strip():
|
||||
# If summary is empty, only delete existing summary vector and record
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@@ -1077,8 +1109,13 @@ class SummaryIndexService:
|
||||
return None
|
||||
|
||||
# Find existing summary record
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@@ -1162,8 +1199,13 @@ class SummaryIndexService:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Update summary record with error status if it exists
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
@@ -1185,14 +1227,14 @@ class SummaryIndexService:
|
||||
DocumentSegmentSummary instance if found, None otherwise
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
return (
|
||||
session.query(DocumentSegmentSummary)
|
||||
return session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1211,15 +1253,13 @@ class SummaryIndexService:
|
||||
return {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
summary_records = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
summary_records = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return {summary.chunk_id: summary for summary in summary_records}
|
||||
|
||||
@@ -1239,16 +1279,16 @@ class SummaryIndexService:
|
||||
List of DocumentSegmentSummary instances (only enabled summaries)
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter(
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.document_id == document_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
return query.all()
|
||||
return list(session.scalars(stmt).all())
|
||||
|
||||
@staticmethod
|
||||
def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
|
||||
@@ -1265,16 +1305,15 @@ class SummaryIndexService:
|
||||
"""
|
||||
# Get all segments for this document (excluding qa_model and re_segment)
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment.id)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
segment_ids = list(
|
||||
session.scalars(
|
||||
select(DocumentSegment.id).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
segment_ids = [seg.id for seg in segments]
|
||||
|
||||
if not segment_ids:
|
||||
return None
|
||||
@@ -1312,15 +1351,13 @@ class SummaryIndexService:
|
||||
|
||||
# Get all segments for these documents (excluding qa_model and re_segment)
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment.id, DocumentSegment.document_id)
|
||||
.where(
|
||||
segments = session.execute(
|
||||
select(DocumentSegment.id, DocumentSegment.document_id).where(
|
||||
DocumentSegment.document_id.in_(document_ids),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group segments by document_id
|
||||
document_segments_map: dict[str, list[str]] = {}
|
||||
|
||||
Reference in New Issue
Block a user