fix: sync 35447 to lts (#35508)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Yunlu Wen
2026-04-23 13:30:59 +08:00
committed by GitHub
parent 2256e75f16
commit e7746cb256
2 changed files with 458 additions and 3 deletions

View File

@@ -56,12 +56,37 @@ from services.feature_service import FeatureService
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.
The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
Configuration assembly is cached per manager instance so call chains that
share one request-scoped manager can reuse the same provider graph instead
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
when a long-lived manager needs to observe writes performed within the same
instance scope.
"""
decoding_rsa_key: Any | None
decoding_cipher_rsa: Any | None
_configurations_cache: dict[str, ProviderConfigurations]
def __init__(self):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._configurations_cache = {}
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
"""Drop assembled provider configurations cached on this manager instance."""
if tenant_id is None:
self._configurations_cache.clear()
return
self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -100,6 +125,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
cached_configurations = self._configurations_cache.get(tenant_id)
if cached_configurations is not None:
return cached_configurations
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@@ -258,6 +287,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
self._configurations_cache[tenant_id] = provider_configurations
# Return the encapsulated object
return provider_configurations

View File

@@ -1,4 +1,6 @@
from unittest.mock import Mock, PropertyMock, patch
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import pytest
@@ -6,7 +8,14 @@ from core.entities.provider_entities import ModelSettings
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel
from models.provider_ids import ModelProviderID
@contextmanager
def _build_session_context(session: Mock):
"""Used with patch(Session, return_value=...) to emulate ``with Session(...) as s``."""
yield session
@pytest.fixture
@@ -228,3 +237,418 @@ def test_get_default_model_uses_first_available_active_model():
assert saved_default_model.model_name == "gpt-3.5-turbo"
assert saved_default_model.provider_name == "openai"
mock_session.commit.assert_called_once()
def test_get_default_model_returns_none_when_no_default_or_active_models():
mock_session = Mock()
mock_session.scalar.return_value = None
provider_configurations = Mock()
provider_configurations.get_models.return_value = []
manager = ProviderManager()
with (
patch("core.provider_manager.db.session", mock_session),
patch.object(manager, "get_configurations", return_value=provider_configurations),
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
):
result = manager.get_default_model("tenant-id", ModelType.LLM)
assert result is None
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
mock_factory_cls.assert_not_called()
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
def test_get_default_model_uses_tenant_id_factory_for_existing_default_record():
existing_default_model = TenantDefaultModel(
tenant_id="tenant-id",
provider_name="openai",
model_name="gpt-4",
model_type=ModelType.LLM,
)
mock_session = Mock()
mock_session.scalar.return_value = existing_default_model
manager = ProviderManager()
with (
patch("core.provider_manager.db.session", mock_session),
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
):
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
supported_model_types=[ModelType.LLM],
)
result = manager.get_default_model("tenant-id", ModelType.LLM)
mock_factory_cls.assert_called_once_with("tenant-id")
assert result is not None
assert result.model == "gpt-4"
assert result.provider.provider == "openai"
def test_get_configurations_uses_tenant_id_factory_and_adds_provider_aliases():
manager = ProviderManager()
provider_records = {"openai": [SimpleNamespace(provider_name="openai")]}
provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]}
preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")}
with (
patch.object(manager, "_get_all_providers", return_value=provider_records),
patch.object(manager, "_init_trial_provider_records", return_value=provider_records),
patch.object(manager, "_get_all_provider_models", return_value=provider_model_records),
patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
):
mock_factory_cls.return_value.get_providers.return_value = []
result = manager.get_configurations("tenant-id")
expected_alias = str(ModelProviderID("openai"))
mock_factory_cls.assert_called_once_with("tenant-id")
assert result.tenant_id == "tenant-id"
assert expected_alias in provider_records
assert expected_alias in provider_model_records
assert expected_alias in preferred_provider_records
@pytest.mark.parametrize(
("provider_name", "expected_provider_names"),
[
("openai", ["openai", "langgenius/openai/openai"]),
("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]),
("langgenius/gemini/google", ["langgenius/gemini/google", "google"]),
],
)
def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]):
assert ProviderManager._get_provider_names(provider_name) == expected_provider_names
def test_get_provider_model_bundle_raises_for_unknown_provider():
manager = ProviderManager()
with patch.object(manager, "get_configurations", return_value={}):
with pytest.raises(ValueError, match="Provider openai does not exist."):
manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
def test_get_configurations_builds_provider_configuration(
mock_provider_entity,
):
manager = ProviderManager()
provider_configuration = Mock()
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}),
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration) as mock_pc,
):
manager.get_configurations("tenant-id")
mock_pc.assert_called_once()
call_kw = mock_pc.call_args.kwargs
assert call_kw["tenant_id"] == "tenant-id"
assert call_kw["provider"] is mock_provider_entity
def test_get_configurations_reuses_cached_result_for_same_tenant(mock_provider_entity):
manager = ProviderManager()
provider_configuration = Mock()
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
patch(
"core.provider_manager.ProviderConfiguration",
return_value=provider_configuration,
) as mock_provider_configuration,
):
first = manager.get_configurations("tenant-id")
second = manager.get_configurations("tenant-id")
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with("tenant-id")
mock_provider_configuration.assert_called_once()
def test_clear_configurations_cache_rebuilds_requested_tenant(mock_provider_entity):
manager = ProviderManager()
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
provider_configuration_first = Mock()
provider_configuration_second = Mock()
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
patch(
"core.provider_manager.ProviderConfiguration",
side_effect=[provider_configuration_first, provider_configuration_second],
) as mock_provider_configuration,
):
first = manager.get_configurations("tenant-id")
manager.clear_configurations_cache("tenant-id")
second = manager.get_configurations("tenant-id")
assert first is not second
assert mock_get_all_providers.call_count == 2
assert mock_provider_configuration.call_count == 2
def test_get_provider_model_bundle_returns_selected_model_type_instance():
manager = ProviderManager()
provider_configuration = Mock()
model_type_instance = Mock()
provider_configuration.get_model_type_instance.return_value = model_type_instance
expected_bundle = Mock()
with (
patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}),
patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle,
):
result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM)
mock_bundle.assert_called_once_with(
configuration=provider_configuration,
model_type_instance=model_type_instance,
)
assert result is expected_bundle
def test_get_first_provider_first_model_returns_none_when_no_models():
manager = ProviderManager()
provider_configurations = Mock()
provider_configurations.get_models.return_value = []
with patch.object(manager, "get_configurations", return_value=provider_configurations):
result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM)
assert result == (None, None)
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False)
def test_get_first_provider_first_model_returns_first_model_and_provider():
manager = ProviderManager()
provider_configurations = Mock()
provider_configurations.get_models.return_value = [
Mock(model="gpt-4", provider=Mock(provider="openai")),
Mock(model="gpt-4o", provider=Mock(provider="openai")),
]
with patch.object(manager, "get_configurations", return_value=provider_configurations):
result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM)
assert result == ("openai", "gpt-4")
def test_update_default_model_record_raises_for_unknown_provider():
manager = ProviderManager()
with patch.object(manager, "get_configurations", return_value={}):
with pytest.raises(ValueError, match="Provider openai does not exist."):
manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4")
def test_update_default_model_record_raises_for_unknown_model():
manager = ProviderManager()
provider_configurations = MagicMock()
provider_configurations.__contains__.return_value = True
provider_configurations.get_models.return_value = [Mock(model="gpt-4")]
with patch.object(manager, "get_configurations", return_value=provider_configurations):
with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."):
manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo")
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
def test_update_default_model_record_updates_existing_record():
manager = ProviderManager()
provider_configurations = MagicMock()
provider_configurations.__contains__.return_value = True
provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")]
existing_default_model = TenantDefaultModel(
tenant_id="tenant-id",
provider_name="anthropic",
model_name="claude-3-sonnet",
model_type=ModelType.LLM,
)
mock_session = Mock()
mock_session.scalar.return_value = existing_default_model
with (
patch.object(manager, "get_configurations", return_value=provider_configurations),
patch("core.provider_manager.db.session", mock_session),
):
result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo")
assert result is existing_default_model
assert existing_default_model.provider_name == "openai"
assert existing_default_model.model_name == "gpt-3.5-turbo"
mock_session.commit.assert_called_once()
mock_session.add.assert_not_called()
def test_update_default_model_record_creates_new_record_stores_str_model_type_value():
manager = ProviderManager()
provider_configurations = MagicMock()
provider_configurations.__contains__.return_value = True
provider_configurations.get_models.return_value = [Mock(model="gpt-4")]
mock_session = Mock()
mock_session.scalar.return_value = None
with (
patch.object(manager, "get_configurations", return_value=provider_configurations),
patch("core.provider_manager.db.session", mock_session),
):
result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4")
mock_session.add.assert_called_once()
created_default_model = mock_session.add.call_args.args[0]
assert result is created_default_model
assert created_default_model.tenant_id == "tenant-id"
assert created_default_model.provider_name == "openai"
assert created_default_model.model_name == "gpt-4"
# ProviderManager persists ``ModelType`` string value in DB, not the origin Dify key.
assert created_default_model.model_type == ModelType.LLM.value
mock_session.commit.assert_called_once()
def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None:
session = Mock()
openai_provider = SimpleNamespace(provider_name="openai")
gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google")
session.scalars.return_value = [openai_provider, gemini_provider]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_providers("tenant-id")
assert list(result[str(ModelProviderID("openai"))]) == [openai_provider]
assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider]
@pytest.mark.parametrize(
"method_name",
[
"_get_all_provider_models",
"_get_all_provider_model_settings",
"_get_all_provider_model_credentials",
],
)
def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None:
session = Mock()
openai_primary = SimpleNamespace(provider_name="openai")
openai_secondary = SimpleNamespace(provider_name="openai")
anthropic_record = SimpleNamespace(provider_name="anthropic")
session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
):
result = getattr(ProviderManager, method_name)("tenant-id")
assert list(result["openai"]) == [openai_primary, openai_secondary]
assert list(result["anthropic"]) == [anthropic_record]
def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None:
session = Mock()
openai_preference = SimpleNamespace(provider_name="openai")
anthropic_preference = SimpleNamespace(provider_name="anthropic")
session.scalars.return_value = [openai_preference, anthropic_preference]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_preferred_model_providers("tenant-id")
assert result == {
"openai": openai_preference,
"anthropic": anthropic_preference,
}
def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None:
with (
patch("core.provider_manager.redis_client.get", return_value=b"False"),
patch("core.provider_manager.FeatureService.get_features") as mock_get_features,
patch("core.provider_manager.Session") as mock_session_cls,
):
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
assert result == {}
mock_get_features.assert_not_called()
mock_session_cls.assert_not_called()
def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None:
session = Mock()
openai_config = SimpleNamespace(provider_name="openai")
anthropic_config = SimpleNamespace(provider_name="anthropic")
session.scalars.return_value = [openai_config, anthropic_config]
with (
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
patch("core.provider_manager.redis_client.get", return_value=None),
patch("core.provider_manager.redis_client.setex") as mock_setex,
patch(
"core.provider_manager.FeatureService.get_features",
return_value=SimpleNamespace(model_load_balancing_enabled=True),
),
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
):
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True")
assert list(result["openai"]) == [openai_config]
assert list(result["anthropic"]) == [anthropic_config]