fix style

This commit is contained in:
Yansong Zhang
2026-04-14 17:24:09 +08:00
parent 99ef50e6f0
commit 729677ca2d
3 changed files with 142 additions and 7 deletions

View File

@@ -4,8 +4,6 @@ import uuid
from collections.abc import Sequence
import httpx
logger = logging.getLogger(__name__)
from httpx import DigestAuth
from configs import dify_config
@@ -15,6 +13,8 @@ from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
logger = logging.getLogger(__name__)
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
_tidb_http_client: httpx.Client = get_pooled_http_client(
"tidb:cloud",

View File

@@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds:
assert exc_info.value.status_code == 500
def test_delete_by_ids_with_large_batch(self, vector_instance):
"""Test deletion with a large batch of IDs."""
# Create 1000 IDs
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
"""Test deletion with exactly 1000 IDs triggers a single batch."""
ids = [f"doc_{i}" for i in range(1000)]
vector_instance.delete_by_ids(ids)
# Verify single delete call with all IDs
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
@@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds:
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
# Verify all 1000 IDs are in the batch
assert len(field_condition.match.any) == 1000
assert "doc_0" in field_condition.match.any
assert "doc_999" in field_condition.match.any
def test_delete_by_ids_splits_into_batches(self, vector_instance):
"""Test deletion with >1000 IDs triggers multiple batched calls."""
ids = [f"doc_{i}" for i in range(2500)]
vector_instance.delete_by_ids(ids)
assert vector_instance._client.delete.call_count == 3
batches = []
for call in vector_instance._client.delete.call_args_list:
filter_selector = call[1]["points_selector"]
field_condition = filter_selector.filter.must[0]
batches.append(field_condition.match.any)
assert len(batches[0]) == 1000
assert len(batches[1]) == 1000
assert len(batches[2]) == 500
def test_delete_by_ids_filter_structure(self, vector_instance):
"""Test that the filter structure is correctly constructed."""
ids = ["doc1", "doc2"]

View File

@@ -0,0 +1,120 @@
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
class TestFetchQdrantEndpoint:
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
mock_get_cluster.return_value = {
"status": {
"connection_strings": {
"standard": {"host": "gateway01.us-east-1.tidbcloud.com"}
}
}
}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
mock_get_cluster.return_value = None
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_host_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {"connection_strings": {"standard": {}}}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_status_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_connection_strings_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_on_exception(self, mock_get_cluster):
mock_get_cluster.side_effect = RuntimeError("network error")
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_standard_key_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {"connection_strings": {}}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
class TestCreateTidbServerlessClusterQdrantEndpoint:
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
mock_fetch_ep.assert_called_once_with("url", "pub", "priv", "c-1")
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_result_qdrant_endpoint_none_when_fetch_fails(self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] is None
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
"""Verify that batch_create includes qdrant_endpoint per cluster."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
cluster_name = "abc123"
mock_http.post.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
)
mock_redis.setex = MagicMock()
mock_redis.get.return_value = b"password123"
result = TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)
assert len(result) == 1
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"