refactor(api): type pipeline template retrieval dicts with TypedDict (#34874)

This commit is contained in:
YBoy
2026-04-17 02:13:54 -06:00
committed by GitHub
parent e70e4fa41d
commit 0020aa8f59
4 changed files with 65 additions and 23 deletions

View File

@@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
@@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param template_id: Template ID
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class CustomizedTemplateItemDict(TypedDict):
id: str
name: str
description: str
icon: dict[str, Any]
position: int
chunk_structure: str
class CustomizedTemplatesResultDict(TypedDict):
pipeline_templates: list[CustomizedTemplateItemDict]
class CustomizedTemplateDetailDict(TypedDict):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
created_by: str
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from database
@@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result
return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
).all()
recommended_pipelines_results = []
recommended_pipelines_results: list[CustomizedTemplateItemDict] = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
recommended_pipeline_result: CustomizedTemplateItemDict = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class PipelineTemplateItemDict(TypedDict):
id: str
name: str
description: str
icon: dict[str, Any]
copyright: str
privacy_policy: str
position: int
chunk_structure: str
class PipelineTemplatesResultDict(TypedDict):
pipeline_templates: list[PipelineTemplateItemDict]
class PipelineTemplateDetailDict(TypedDict):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_db(language)
return result
return self.fetch_pipeline_templates_from_db(language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
).all()
)
recommended_pipelines_results = []
recommended_pipelines_results: list[PipelineTemplateItemDict] = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
recommended_pipeline_result: PipelineTemplateItemDict = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,

View File

@@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result: dict[str, Any] | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
return self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result
return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
return self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
return result
return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
def get_type(self) -> str:
return PipelineTemplateType.REMOTE