mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-30 17:50:29 +08:00
refactor(api): fix pyright errors in jieba, milvus, couchbase, oracle, and router (#34938)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
if keyword_table_dict:
|
||||
return dict(keyword_table_dict["__data__"]["table"])
|
||||
data: Any = keyword_table_dict["__data__"]
|
||||
return dict(data["table"])
|
||||
else:
|
||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
|
||||
@@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = self._tfidf.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
topK=max_keywords_per_chunk or 10,
|
||||
)
|
||||
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
||||
keywords = cast(list[str], keywords)
|
||||
|
||||
@@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
|
||||
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
stream=False, # pyright: ignore[reportArgumentType]
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
auth = PasswordAuthenticator(config.user, config.password)
|
||||
options = ClusterOptions(auth)
|
||||
self._cluster = Cluster(config.connection_string, options)
|
||||
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
|
||||
self._bucket = self._cluster.bucket(config.bucket_name)
|
||||
self._scope = self._bucket.scope(config.scope_name)
|
||||
self._bucket_name = config.bucket_name
|
||||
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
|
||||
def _load_collection_fields(self, fields: list[str] | None = None):
|
||||
if fields is None:
|
||||
# Load collection fields from remote server
|
||||
collection_info = self._client.describe_collection(self._collection_name)
|
||||
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
|
||||
fields = [field["name"] for field in collection_info["fields"]]
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
|
||||
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
|
||||
return False
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
milvus_version_raw = self._client.get_server_version()
|
||||
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
|
||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||
if "Zilliz Cloud" in milvus_version:
|
||||
return True
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
@@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
|
||||
oracledb.defaults.fetch_lobs = False
|
||||
|
||||
|
||||
class _OraclePoolParams(TypedDict, total=False):
|
||||
user: str
|
||||
password: str
|
||||
dsn: str
|
||||
min: int
|
||||
max: int
|
||||
increment: int
|
||||
config_dir: str | None
|
||||
wallet_location: str | None
|
||||
wallet_password: str | None
|
||||
|
||||
|
||||
class OracleVectorConfig(BaseModel):
|
||||
user: str
|
||||
password: str
|
||||
@@ -127,22 +139,18 @@ class OracleVector(BaseVector):
|
||||
return connection
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
pool_params = {
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"dsn": config.dsn,
|
||||
"min": 1,
|
||||
"max": 5,
|
||||
"increment": 1,
|
||||
}
|
||||
pool_params = _OraclePoolParams(
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
dsn=config.dsn,
|
||||
min=1,
|
||||
max=5,
|
||||
increment=1,
|
||||
)
|
||||
if config.is_autonomous:
|
||||
pool_params.update(
|
||||
{
|
||||
"config_dir": config.config_dir,
|
||||
"wallet_location": config.wallet_location,
|
||||
"wallet_password": config.wallet_password,
|
||||
}
|
||||
)
|
||||
pool_params["config_dir"] = config.config_dir
|
||||
pool_params["wallet_location"] = config.wallet_location
|
||||
pool_params["wallet_password"] = config.wallet_password
|
||||
return oracledb.create_pool(**pool_params)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user