mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 23:40:16 +08:00
refactor: replace bare dict with TypedDicts in annotation_service (#34998)
This commit is contained in:
@@ -1,11 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import delete, or_, select, update
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
|
||||
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
|
||||
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnotationJobStatusDict(TypedDict):
|
||||
job_id: str
|
||||
@@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EnableAnnotationArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to enable_app_annotation."""
|
||||
|
||||
score_threshold: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class UpsertAnnotationArgs(TypedDict, total=False):
|
||||
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
|
||||
|
||||
answer: str
|
||||
content: str
|
||||
message_id: str
|
||||
question: str
|
||||
|
||||
|
||||
class InsertAnnotationArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
|
||||
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class UpdateAnnotationArgs(TypedDict, total=False):
|
||||
"""Expected shape of the args dict passed to update_app_annotation_directly.
|
||||
|
||||
Both fields are optional at the type level; the service validates at runtime
|
||||
and raises ValueError if either is missing.
|
||||
"""
|
||||
|
||||
answer: str
|
||||
question: str
|
||||
|
||||
|
||||
class UpdateAnnotationSettingArgs(TypedDict):
|
||||
"""Expected shape of the args dict passed to update_app_annotation_setting."""
|
||||
|
||||
score_threshold: float
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@@ -62,8 +102,9 @@ class AppAnnotationService:
|
||||
if answer is None:
|
||||
raise ValueError("Either 'answer' or 'content' must be provided")
|
||||
|
||||
if args.get("message_id"):
|
||||
message_id = str(args["message_id"])
|
||||
raw_message_id = args.get("message_id")
|
||||
if raw_message_id:
|
||||
message_id = str(raw_message_id)
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
|
||||
)
|
||||
@@ -87,9 +128,10 @@ class AppAnnotationService:
|
||||
account_id=current_user.id,
|
||||
)
|
||||
else:
|
||||
question = args.get("question")
|
||||
if not question:
|
||||
maybe_question = args.get("question")
|
||||
if not maybe_question:
|
||||
raise ValueError("'question' is required when 'message_id' is not provided")
|
||||
question = maybe_question
|
||||
|
||||
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
|
||||
db.session.add(annotation)
|
||||
@@ -110,7 +152,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
|
||||
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@@ -217,7 +259,7 @@ class AppAnnotationService:
|
||||
return annotations
|
||||
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@@ -251,7 +293,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
@@ -270,7 +312,11 @@ class AppAnnotationService:
|
||||
if question is None:
|
||||
raise ValueError("'question' is required")
|
||||
|
||||
annotation.content = args["answer"]
|
||||
answer = args.get("answer")
|
||||
if answer is None:
|
||||
raise ValueError("'answer' is required")
|
||||
|
||||
annotation.content = answer
|
||||
annotation.question = question
|
||||
|
||||
db.session.commit()
|
||||
@@ -613,7 +659,7 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(
|
||||
cls, app_id: str, annotation_setting_id: str, args: dict
|
||||
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
|
||||
) -> AnnotationSettingDict:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
|
||||
Reference in New Issue
Block a user