diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index ce626bbd7e..fb6eaa370a 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel, model_validator @@ -13,6 +13,13 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client +class AnalyticdbClientParamsDict(TypedDict): + access_key_id: str + access_key_secret: str + region_id: str + read_timeout: int + + class AnalyticdbVectorOpenAPIConfig(BaseModel): access_key_id: str access_key_secret: str @@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") return values - def to_analyticdb_client_params(self): - return { + def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict: + result: AnalyticdbClientParamsDict = { "access_key_id": self.access_key_id, "access_key_secret": self.access_key_secret, "region_id": self.region_id, "read_timeout": self.read_timeout, } + return result class AnalyticdbVectorOpenAPI: diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3e0420b9d0..73787c2f00 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, TypedDict import chromadb from chromadb import QueryResult, Settings @@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset +class ChromaParamsDict(TypedDict): + host: str + port: int + ssl: bool + tenant: str + database: str + settings: Settings + + class ChromaConfig(BaseModel): host: str port: int @@ -23,14 +32,13 @@ class ChromaConfig(BaseModel): auth_provider: str | None = None auth_credentials: str | None = None - def to_chroma_params(self): + def to_chroma_params(self) -> ChromaParamsDict: settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, chroma_client_auth_credentials=self.auth_credentials, ) - - return { + result: ChromaParamsDict = { "host": self.host, "port": self.port, "ssl": False, @@ -38,6 +46,7 @@ class ChromaConfig(BaseModel): "database": self.database, "settings": settings, } + return result class ChromaVector(BaseVector):