refactor: replace bare dict with TypedDicts in annotation_service (#34998)

This commit is contained in:
wdeveloper16
2026-04-13 05:46:33 +02:00
committed by GitHub
parent 17da0e4146
commit 8436470fcb
3 changed files with 110 additions and 29 deletions

View File

@@ -1,11 +1,8 @@
import logging
import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
logger = logging.getLogger(__name__)
class AnnotationJobStatusDict(TypedDict):
job_id: str
@@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class EnableAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to enable_app_annotation."""
score_threshold: float
embedding_provider_name: str
embedding_model_name: str
class UpsertAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
answer: str
content: str
message_id: str
question: str
class InsertAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
question: str
answer: str
class UpdateAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to update_app_annotation_directly.
Both fields are optional at the type level; the service validates at runtime
and raises ValueError if either is missing.
"""
answer: str
question: str
class UpdateAnnotationSettingArgs(TypedDict):
"""Expected shape of the args dict passed to update_app_annotation_setting."""
score_threshold: float
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -62,8 +102,9 @@ class AppAnnotationService:
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
raw_message_id = args.get("message_id")
if raw_message_id:
message_id = str(raw_message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
@@ -87,9 +128,10 @@ class AppAnnotationService:
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
maybe_question = args.get("question")
if not maybe_question:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
@@ -110,7 +152,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -217,7 +259,7 @@ class AppAnnotationService:
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -251,7 +293,7 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -270,7 +312,11 @@ class AppAnnotationService:
if question is None:
raise ValueError("'question' is required")
annotation.content = args["answer"]
answer = args.get("answer")
if answer is None:
raise ValueError("'answer' is required")
annotation.content = answer
annotation.question = question
db.session.commit()
@@ -613,7 +659,7 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info