From 43e4e161b5a8a4b2f57a2a3b468d43b3cee97cb5 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Wed, 15 Apr 2026 13:01:41 +0800 Subject: [PATCH] fix unit test --- .../unit_tests/test_tidb_on_qdrant_vector.py | 62 ++++++++++++- .../tests/unit_tests/test_tidb_service.py | 88 +++++++++++++++++++ 2 files changed, 149 insertions(+), 1 deletion(-) diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index e4fca9f931..4ebb6aa22f 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -1,10 +1,11 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import httpx import pytest from dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector import ( TidbOnQdrantConfig, TidbOnQdrantVector, + TidbOnQdrantVectorFactory, ) from qdrant_client.http import models as rest from qdrant_client.http.exceptions import UnexpectedResponse @@ -172,3 +173,62 @@ class TestTidbOnQdrantVectorDeleteByIds: # Verify MatchAny structure assert isinstance(field_condition.match, rest.MatchAny) assert field_condition.match.any == ids + + +class TestInitVectorEndpointSelection: + """Test that init_vector selects the correct qdrant endpoint.""" + + def _make_dataset(self, tenant_id="t-1", dataset_id="d-1", index_struct_dict=None): + ds = MagicMock() + ds.tenant_id = tenant_id + ds.id = dataset_id + ds.index_struct_dict = index_struct_dict + return ds + + def _make_binding(self, account="acc", password="pwd", qdrant_endpoint=None, cluster_id="c-1"): + b = MagicMock() + b.account = account + b.password = password + b.qdrant_endpoint = qdrant_endpoint + b.cluster_id = cluster_id + return b + + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.current_app") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.dify_config") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient") + def test_uses_binding_endpoint_when_present(self, mock_qc, mock_db, mock_config, mock_app): + binding = self._make_binding(qdrant_endpoint="https://qdrant-custom.tidb.com") + mock_db.session.scalars.return_value.one_or_none.return_value = binding + mock_config.TIDB_ON_QDRANT_URL = "https://qdrant-global.tidb.com" + mock_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT = 20 + mock_config.TIDB_ON_QDRANT_GRPC_PORT = 6334 + mock_config.TIDB_ON_QDRANT_GRPC_ENABLED = False + mock_config.QDRANT_REPLICATION_FACTOR = 1 + mock_app.config = {"root_path": "/app"} + + ds = self._make_dataset(index_struct_dict={"type": "tidb_on_qdrant", "vector_store": {"class_prefix": "col"}}) + factory = TidbOnQdrantVectorFactory() + result = factory.init_vector(ds, [], MagicMock()) + + assert result._client_config.endpoint == "https://qdrant-custom.tidb.com" + + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.current_app") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.dify_config") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient") + def test_falls_back_to_global_when_binding_endpoint_is_none(self, mock_qc, mock_db, mock_config, mock_app): + binding = self._make_binding(qdrant_endpoint=None) + mock_db.session.scalars.return_value.one_or_none.return_value = binding + mock_config.TIDB_ON_QDRANT_URL = "https://qdrant-global.tidb.com" + mock_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT = 20 + mock_config.TIDB_ON_QDRANT_GRPC_PORT = 6334 + mock_config.TIDB_ON_QDRANT_GRPC_ENABLED = False + mock_config.QDRANT_REPLICATION_FACTOR = 1 + mock_app.config = {"root_path": "/app"} + + ds = self._make_dataset(index_struct_dict={"type": "tidb_on_qdrant", "vector_store": {"class_prefix": "col"}}) + factory = TidbOnQdrantVectorFactory() + result = factory.init_vector(ds, [], MagicMock()) + + assert result._client_config.endpoint == "https://qdrant-global.tidb.com" diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py index 92e52e8321..9af70bbefe 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import pytest from dify_vdb_tidb_on_qdrant.tidb_service import TidbService @@ -118,3 +119,90 @@ class TestBatchCreateTidbServerlessClusterQdrantEndpoint: assert len(result) == 1 assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, project_id="proj", api_url="url", iam_url="iam", + public_key="pub", private_key="priv", region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, project_id="proj", api_url="url", iam_url="iam", + public_key="pub", private_key="priv", region="us-east-1", + )