mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 23:40:16 +08:00
Merge remote-tracking branch 'origin/deploy/dev' into feat/evaluation
# Conflicts: # .vite-hooks/pre-commit # api/controllers/console/__init__.py # api/core/agent/base_agent_runner.py # api/core/app/app_config/easy_ui_based_app/model_config/converter.py # api/core/app/apps/agent_chat/app_runner.py # api/core/entities/provider_configuration.py # api/core/helper/moderation.py # api/core/model_manager.py # api/core/rag/embedding/cached_embedding.py # api/core/rag/retrieval/dataset_retrieval.py # api/core/rag/splitter/fixed_text_splitter.py # api/core/workflow/nodes/datasource/datasource_node.py # api/core/workflow/nodes/knowledge_index/knowledge_index_node.py # api/models/human_input.py # api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py # api/services/workflow_service.py # api/tasks/trigger_processing_tasks.py # api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py # api/tests/integration_tests/workflow/nodes/test_http.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/controllers/service_api/app/test_conversation.py # api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py # api/tests/unit_tests/core/variables/test_segment.py # api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py # api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py # api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py # api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py # api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py # api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py # web/app/(commonLayout)/layout.tsx # web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx # web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx # web/app/components/app/workflow-log/__tests__/list.spec.tsx # web/app/components/apps/__tests__/list.spec.tsx # web/app/components/apps/list.tsx # web/app/components/base/chat/chat-with-history/header/operation.tsx # web/app/components/base/chat/chat-with-history/sidebar/operation.tsx # web/app/components/header/account-setting/data-source-page-new/operator.tsx # web/app/components/header/account-setting/members-page/operation/index.tsx # web/app/components/plugins/marketplace/sort-dropdown/__tests__/index.spec.tsx # web/app/components/plugins/marketplace/sort-dropdown/index.tsx # web/app/components/plugins/plugin-page/plugin-tasks/index.tsx # web/app/components/workflow/header/__tests__/test-run-menu.spec.tsx # web/app/components/workflow/header/test-run-menu.tsx # web/app/components/workflow/nodes/_base/components/next-step/operator.tsx # web/app/components/workflow/nodes/_base/components/panel-operator/index.tsx # web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx # web/app/components/workflow/nodes/assigner/components/operation-selector.tsx # web/app/components/workflow/operator/__tests__/more-actions.spec.tsx # web/app/components/workflow/operator/zoom-in-out.tsx # web/app/components/workflow/panel/version-history-panel/context-menu/menu-item.tsx # web/app/components/workflow/selection-contextmenu.tsx # web/eslint-suppressions.json Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
@@ -20,11 +20,11 @@
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { Button } from '@/app/components/base/ui/button'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
|
||||
18
.github/workflows/pyrefly-diff-comment.yml
vendored
18
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -76,13 +76,11 @@ jobs:
|
||||
diff += '\\n\\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
if (diff.trim()) {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
|
||||
});
|
||||
}
|
||||
|
||||
34
.github/workflows/web-tests.yml
vendored
34
.github/workflows/web-tests.yml
vendored
@@ -89,3 +89,37 @@ jobs:
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./packages/dify-ui
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install Chromium for Browser Mode
|
||||
run: vp exec playwright install --with-deps chromium
|
||||
|
||||
- name: Run dify-ui tests
|
||||
run: vp test run --coverage --silent=passed-only
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: packages/dify-ui/coverage
|
||||
flags: dify-ui
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import secrets
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
@@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account = db.session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
with Session(db.engine) as session:
|
||||
account = session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
session.commit()
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
@@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account = db.session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
db.session.commit()
|
||||
with Session(db.engine) as session:
|
||||
account = session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
|
||||
1
api/constants/dsl_version.py
Normal file
1
api/constants/dsl_version.py
Normal file
@@ -0,0 +1 @@
|
||||
CURRENT_APP_DSL_VERSION = "0.6.0"
|
||||
@@ -126,8 +126,6 @@ from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@@ -215,12 +213,12 @@ __all__ = [
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
|
||||
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
|
||||
class FullContentDict(TypedDict):
|
||||
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"value_type": str(variable_file.value_type.exposed_type()),
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value_type": str(v.value_type.exposed_type()),
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
|
||||
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -4,20 +4,6 @@ import uuid
|
||||
from decimal import Decimal
|
||||
from typing import Union, cast
|
||||
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
@@ -300,7 +300,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
tool_instance = tool_instances.get(prompt_tool.name)
|
||||
if tool_instance:
|
||||
self.update_prompt_message_tool(tool_instance, prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
@@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.protocols import WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
from graphon.http.protocols import HttpResponseProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@@ -350,7 +350,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
|
||||
@@ -352,11 +352,11 @@ class DatasourceManager:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.CUSTOM,
|
||||
file_type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
|
||||
@@ -8,16 +8,6 @@ from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -34,6 +24,16 @@ from core.entities.provider_entities import (
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:param session: optional database session
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
# fix origin data
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if credential_record and credential_record.encrypted_config:
|
||||
if not credential_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {"openai_api_key": credential_record.encrypted_config}
|
||||
@@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def _generate_provider_credential_name(self, session) -> str:
|
||||
"""
|
||||
@@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
else:
|
||||
credential_name = self._generate_provider_credential_name(session)
|
||||
credential_name = self._generate_provider_credential_name(pre_session)
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
try:
|
||||
new_record = ProviderCredential(
|
||||
@@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.flush()
|
||||
|
||||
if not provider_record:
|
||||
# If provider record does not exist, create it
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
@@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
credential_name=credential_name, session=pre_session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
|
||||
credentials = self.validate_provider_credentials(
|
||||
credentials=credentials, credential_id=credential_id, session=session
|
||||
)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
@@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
@@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
@@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
@@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
json.loads(credential_record.encrypted_config)
|
||||
if credential_record and credential_record.encrypted_config
|
||||
@@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
@@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=session
|
||||
model=model, model_type=model_type, session=pre_session
|
||||
)
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
try:
|
||||
@@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.add(credential)
|
||||
session.flush()
|
||||
|
||||
# save provider model
|
||||
if not provider_model_record:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
@@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
session=session,
|
||||
session=pre_session,
|
||||
exclude_id=credential_id,
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
@@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
|
||||
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@@ -2,14 +2,13 @@ import logging
|
||||
import secrets
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
from graphon.http.response import HttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -267,4 +268,47 @@ class SSRFProxy:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
|
||||
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
|
||||
return HttpResponse(
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
content=response.content,
|
||||
url=str(response.url) if response.url else None,
|
||||
reason_phrase=response.reason_phrase,
|
||||
fallback_text=response.text,
|
||||
)
|
||||
|
||||
|
||||
class GraphonSSRFProxy:
|
||||
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
graphon_ssrf_proxy = GraphonSSRFProxy()
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
@@ -25,9 +11,24 @@ from core.errors.error import ProviderTokenNotInitError
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
@@ -169,7 +170,7 @@ class ModelInstance:
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@@ -194,7 +195,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@@ -214,7 +215,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@@ -236,7 +237,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
@@ -253,7 +254,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@@ -278,7 +279,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@@ -306,7 +307,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@@ -325,7 +326,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
@@ -341,7 +342,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
@@ -358,14 +359,14 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import validate_project_name, validate_url
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
@@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
|
||||
return validate_project_name(v, default_name)
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langfuse tracing config.
|
||||
"""
|
||||
|
||||
public_key: str
|
||||
secret_key: str
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://api.langfuse.com")
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langsmith tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
project: str
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Opik tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
workspace: str | None = None
|
||||
url: str = "https://www.comet.com/opik/api/"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "Default Project")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Weave tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# aliyun uses two URL formats, which may include a URL path
|
||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TencentConfig(BaseTracingConfig):
|
||||
"""
|
||||
Tencent APM tracing config
|
||||
"""
|
||||
|
||||
token: str
|
||||
endpoint: str
|
||||
service_name: str
|
||||
|
||||
@field_validator("token")
|
||||
@classmethod
|
||||
def token_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Token cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
|
||||
|
||||
@field_validator("service_name")
|
||||
@classmethod
|
||||
def service_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
|
||||
class MLflowConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for MLflow tracing config.
|
||||
"""
|
||||
|
||||
tracking_uri: str = "http://localhost:5000"
|
||||
experiment_id: str = "0" # Default experiment id in MLflow is 0
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator("tracking_uri")
|
||||
@classmethod
|
||||
def tracking_uri_validator(cls, v, info: ValidationInfo):
|
||||
if isinstance(v, str) and v.startswith("databricks"):
|
||||
raise ValueError(
|
||||
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
|
||||
)
|
||||
return validate_url_with_path(v, "http://localhost:5000")
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
class DatabricksConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Databricks (Databricks-managed MLflow) tracing config.
|
||||
"""
|
||||
|
||||
experiment_id: str
|
||||
host: str
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
personal_access_token: str | None = None
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
try:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.LANGSMITH:
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
case TracingProviderEnum.LANGSMITH:
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.OPIK:
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
||||
case TracingProviderEnum.OPIK:
|
||||
from dify_trace_opik.config import OpikConfig
|
||||
from dify_trace_opik.opik_trace import OpikDataTrace
|
||||
|
||||
return {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.weave_trace.weave_trace import WeaveDataTrace
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from dify_trace_weave.config import WeaveConfig
|
||||
from dify_trace_weave.weave_trace import WeaveDataTrace
|
||||
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import ArizeConfig
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig
|
||||
|
||||
return {
|
||||
"config_class": ArizeConfig,
|
||||
"secret_keys": ["api_key", "space_id"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.PHOENIX:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import PhoenixConfig
|
||||
return {
|
||||
"config_class": ArizeConfig,
|
||||
"secret_keys": ["api_key", "space_id"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.PHOENIX:
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from dify_trace_arize_phoenix.config import PhoenixConfig
|
||||
|
||||
return {
|
||||
"config_class": PhoenixConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ALIYUN:
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
return {
|
||||
"config_class": PhoenixConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ALIYUN:
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
|
||||
return {
|
||||
"config_class": AliyunConfig,
|
||||
"secret_keys": ["license_key"],
|
||||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.MLFLOW:
|
||||
from core.ops.entities.config_entity import MLflowConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
return {
|
||||
"config_class": AliyunConfig,
|
||||
"secret_keys": ["license_key"],
|
||||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.MLFLOW:
|
||||
from dify_trace_mlflow.config import MLflowConfig
|
||||
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": MLflowConfig,
|
||||
"secret_keys": ["password"],
|
||||
"other_keys": ["tracking_uri", "experiment_id", "username"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.DATABRICKS:
|
||||
from core.ops.entities.config_entity import DatabricksConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
return {
|
||||
"config_class": MLflowConfig,
|
||||
"secret_keys": ["password"],
|
||||
"other_keys": ["tracking_uri", "experiment_id", "username"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.DATABRICKS:
|
||||
from dify_trace_mlflow.config import DatabricksConfig
|
||||
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": DatabricksConfig,
|
||||
"secret_keys": ["personal_access_token", "client_secret"],
|
||||
"other_keys": ["host", "client_id", "experiment_id"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": DatabricksConfig,
|
||||
"secret_keys": ["personal_access_token", "client_secret"],
|
||||
"other_keys": ["host", "client_id", "experiment_id"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from dify_trace_tencent.config import TencentConfig
|
||||
from dify_trace_tencent.tencent_trace import TencentDataTrace
|
||||
|
||||
return {
|
||||
"config_class": TencentConfig,
|
||||
"secret_keys": ["token"],
|
||||
"other_keys": ["endpoint", "service_name"],
|
||||
"trace_instance": TencentDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": TencentConfig,
|
||||
"secret_keys": ["token"],
|
||||
"other_keys": ["endpoint", "service_name"],
|
||||
"trace_instance": TencentDataTrace,
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
except ImportError:
|
||||
raise ImportError(f"Provider {provider} is not installed.")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
provider_schema.icon_small_dark.zh_hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
else provider_schema.icon_small_dark.en_us
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
@@ -5,7 +5,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ModelConfigWithCredentialsEntity,
|
||||
|
||||
@@ -4,8 +4,6 @@ import pickle
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
@@ -15,6 +13,8 @@ from core.model_manager import ModelInstance
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
"""Initialize with file path."""
|
||||
self._closed = False
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError(f"File path {self.file_path} is not a valid file or url")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Best-effort cleanup for downloaded temporary files."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
temp_file = getattr(self, "temp_file", None)
|
||||
if temp_file is None:
|
||||
return
|
||||
|
||||
try:
|
||||
close_result = temp_file.close()
|
||||
if inspect.isawaitable(close_result):
|
||||
close_awaitable = getattr(close_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
self.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
|
||||
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
try:
|
||||
# Create File object directly (similar to DatasetRetrieval)
|
||||
file_obj = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@@ -9,11 +9,6 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import and_, func, literal, or_, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@@ -69,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import (
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.helper import parse_uuid_str_or_none
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
@@ -517,11 +517,11 @@ class DatasetRetrieval:
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@@ -7,10 +7,9 @@ import re
|
||||
from collections.abc import Collection
|
||||
from typing import Any, Literal
|
||||
|
||||
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@@ -28,7 +28,7 @@ class ToolFileManager:
|
||||
def _build_graph_file_reference(tool_file: ToolFile) -> File:
|
||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||
return File(
|
||||
type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
file_type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
|
||||
@@ -1083,7 +1083,12 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.value
|
||||
if not isinstance(variable_selector, list) or not all(
|
||||
isinstance(selector_part, str) for selector_part in variable_selector
|
||||
):
|
||||
raise ToolParameterError("Variable tool input must be a variable selector")
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
|
||||
@@ -18,7 +18,7 @@ from graphon.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
transfer_method_value = file_dict.get("transfer_method")
|
||||
if not isinstance(transfer_method_value, str):
|
||||
raise ValueError("Workflow file mapping is missing a valid transfer_method")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Workflow-layer adapters for legacy human-input payload keys.
|
||||
"""Workflow-to-Graphon adapters for persisted node payloads.
|
||||
|
||||
Stored workflow graphs and editor payloads may still use Dify-specific human
|
||||
input recipient keys. Normalize them here before handing configs to
|
||||
`graphon` so graph-owned models only see graph-neutral field names.
|
||||
Stored workflow graphs and editor payloads still contain a small set of
|
||||
Dify-owned field spellings and value shapes. Adapt them here before handing the
|
||||
payload to Graphon so Graphon-owned models only see current contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
|
||||
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
|
||||
|
||||
|
||||
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
|
||||
normalized = normalize_human_input_node_data_for_graph(node_data)
|
||||
normalized = adapt_human_input_node_data_for_graph(node_data)
|
||||
raw_delivery_methods = normalized.get("delivery_methods")
|
||||
if not isinstance(raw_delivery_methods, list):
|
||||
return []
|
||||
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
|
||||
return False
|
||||
|
||||
|
||||
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
|
||||
|
||||
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return normalized
|
||||
return normalize_human_input_node_data_for_graph(normalized)
|
||||
node_type = normalized.get("type")
|
||||
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return adapt_human_input_node_data_for_graph(normalized)
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _adapt_tool_node_data_for_graph(normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_config)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
|
||||
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
|
||||
if data_mapping is None:
|
||||
return normalized
|
||||
|
||||
normalized["data"] = normalize_node_data_for_graph(data_mapping)
|
||||
normalized["data"] = adapt_node_data_for_graph(data_mapping)
|
||||
return normalized
|
||||
|
||||
|
||||
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(node_data)
|
||||
|
||||
raw_tool_configurations = normalized.get("tool_configurations")
|
||||
if not isinstance(raw_tool_configurations, Mapping):
|
||||
return normalized
|
||||
|
||||
existing_tool_parameters = normalized.get("tool_parameters")
|
||||
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
|
||||
normalized_tool_configurations: dict[str, Any] = {}
|
||||
found_legacy_tool_inputs = False
|
||||
|
||||
for name, value in raw_tool_configurations.items():
|
||||
if not isinstance(value, Mapping):
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, dict(value))
|
||||
|
||||
flattened_value = _flatten_legacy_tool_configuration_value(
|
||||
input_type=input_type,
|
||||
input_value=input_value,
|
||||
)
|
||||
if flattened_value is not None:
|
||||
normalized_tool_configurations[name] = flattened_value
|
||||
|
||||
if not found_legacy_tool_inputs:
|
||||
return normalized
|
||||
|
||||
normalized["tool_parameters"] = normalized_tool_parameters
|
||||
normalized["tool_configurations"] = normalized_tool_configurations
|
||||
return normalized
|
||||
|
||||
|
||||
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
|
||||
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
|
||||
return input_value
|
||||
|
||||
if (
|
||||
input_type == "variable"
|
||||
and isinstance(input_value, list)
|
||||
and all(isinstance(item, str) for item in input_value)
|
||||
):
|
||||
return "{{#" + ".".join(input_value) + "#}}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
@@ -291,9 +349,9 @@ __all__ = [
|
||||
"MemberRecipient",
|
||||
"WebAppDeliveryMethod",
|
||||
"_WebAppDeliveryConfig",
|
||||
"adapt_human_input_node_data_for_graph",
|
||||
"adapt_node_config_for_graph",
|
||||
"adapt_node_data_for_graph",
|
||||
"is_human_input_webapp_enabled",
|
||||
"normalize_human_input_node_data_for_graph",
|
||||
"normalize_node_config_for_graph",
|
||||
"normalize_node_data_for_graph",
|
||||
"parse_human_input_delivery_methods",
|
||||
]
|
||||
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_runtime import (
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
"""Resolve the production node class for the requested type/version."""
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_http_client = graphon_ssrf_proxy
|
||||
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
|
||||
self._dify_context,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
# Graph configs are initially validated against permissive shared node data.
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=False,
|
||||
include_llm_file_saver=False,
|
||||
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
validate_node_data = getattr(node_class, "validate_node_data", None)
|
||||
if callable(validate_node_data):
|
||||
return cast("BaseNodeData", validate_node_data(node_data))
|
||||
return node_data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.nodes.llm.runtime_protocols import (
|
||||
PreparedLLMProtocol,
|
||||
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .human_input_compat import (
|
||||
from .human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
DeliveryMethodType,
|
||||
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeExecutionType,
|
||||
@@ -12,13 +17,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment
|
||||
|
||||
from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from .exc import DatasourceNodeError
|
||||
|
||||
@@ -37,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -2,17 +2,15 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.template import Template
|
||||
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
|
||||
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.template import Template
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@@ -33,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -51,6 +50,18 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
|
||||
if value is None or isinstance(value, (str, float)):
|
||||
return value
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_sequence_item(value: object) -> str:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@@ -60,13 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@@ -283,18 +295,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
|
||||
@@ -10,8 +10,8 @@ from typing import Any
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_database import db
|
||||
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
@@ -135,29 +135,30 @@ def _build_from_local_file(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
with session_factory.create_session() as session:
|
||||
row = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type", "custom"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type", "custom"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
reference=build_file_reference(record_id=str(row.id)),
|
||||
size=row.size,
|
||||
storage_key=row.key,
|
||||
)
|
||||
return File(
|
||||
file_id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
reference=build_file_reference(record_id=str(row.id)),
|
||||
size=row.size,
|
||||
storage_key=row.key,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_remote_url(
|
||||
@@ -179,32 +180,33 @@ def _build_from_remote_url(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = standardize_file_type(
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
detected_file_type = standardize_file_type(
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
return File(
|
||||
file_id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
|
||||
url = mapping.get("url") or mapping.get("remote_url")
|
||||
if not url:
|
||||
@@ -220,9 +222,9 @@ def _build_from_remote_url(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=filename,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
@@ -247,30 +249,31 @@ def _build_from_tool_file(
|
||||
ToolFile.id == tool_file_id,
|
||||
ToolFile.tenant_id == tenant_id,
|
||||
)
|
||||
tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {tool_file_id} not found")
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {tool_file_id} not found")
|
||||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=tool_file.name,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
return File(
|
||||
file_id=mapping.get("id"),
|
||||
filename=tool_file.name,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _build_from_datasource_file(
|
||||
@@ -289,31 +292,32 @@ def _build_from_datasource_file(
|
||||
UploadFile.id == datasource_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if datasource_file is None:
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
with session_factory.create_session() as session:
|
||||
datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||
if datasource_file is None:
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
reference=build_file_reference(record_id=str(datasource_file.id)),
|
||||
extension=extension,
|
||||
mime_type=datasource_file.mime_type,
|
||||
size=datasource_file.size,
|
||||
storage_key=datasource_file.key,
|
||||
url=datasource_file.source_url,
|
||||
)
|
||||
return File(
|
||||
file_id=mapping.get("datasource_file_id"),
|
||||
filename=datasource_file.name,
|
||||
file_type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
reference=build_file_reference(record_id=str(datasource_file.id)),
|
||||
extension=extension,
|
||||
mime_type=datasource_file.mime_type,
|
||||
size=datasource_file.size,
|
||||
storage_key=datasource_file.key,
|
||||
url=datasource_file.source_url,
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
|
||||
|
||||
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
|
||||
|
||||
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type.exposed_type().value
|
||||
return str(v.value_type.exposed_type())
|
||||
else:
|
||||
value_type = v.get("value_type")
|
||||
if value_type is None:
|
||||
raise ValueError("value_type is required but not provided")
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": value.value_type.exposed_type().value,
|
||||
"value_type": str(value.value_type.exposed_type()),
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, dict):
|
||||
|
||||
@@ -6,8 +6,8 @@ from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
@@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
session.add(new_data_source_binding)
|
||||
session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
@@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
session.add(new_data_source_binding)
|
||||
session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -36,24 +37,27 @@ class WorkflowComment(Base):
|
||||
|
||||
__tablename__ = "workflow_comments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(db.Float)
|
||||
position_y: Mapped[float] = mapped_column(db.Float)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(sa.Float)
|
||||
position_y: Mapped[float] = mapped_column(sa.Float)
|
||||
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
|
||||
resolved: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
|
||||
# Relationships
|
||||
@@ -143,23 +147,26 @@ class WorkflowCommentReply(Base):
|
||||
|
||||
__tablename__ = "workflow_comment_replies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
Index("comment_replies_comment_idx", "comment_id"),
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
comment: Mapped["WorkflowComment"] = relationship(
|
||||
"WorkflowComment", back_populates="replies")
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@@ -187,24 +194,26 @@ class WorkflowCommentMention(Base):
|
||||
|
||||
__tablename__ = "workflow_comment_mentions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
Index("comment_mentions_comment_idx", "comment_id"),
|
||||
Index("comment_mentions_reply_idx", "reply_id"),
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
comment: Mapped["WorkflowComment"] = relationship(
|
||||
"WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]
|
||||
] = relationship("WorkflowCommentReply")
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
|
||||
@@ -3,11 +3,11 @@ from enum import StrEnum
|
||||
from typing import Annotated, Literal, Self, final
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from core.workflow.human_input_compat import DeliveryMethodType
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.helper import generate_string
|
||||
|
||||
from .base import Base, DefaultFieldsMixin
|
||||
|
||||
@@ -5,7 +5,8 @@ from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
|
||||
return tenant_resolver()
|
||||
|
||||
|
||||
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
|
||||
"""Build a graph `File` directly from serialized metadata."""
|
||||
|
||||
def _coerce_file_type(value: Any) -> FileType:
|
||||
if isinstance(value, FileType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return FileType.value_of(value)
|
||||
raise ValueError("file type is required in file mapping")
|
||||
|
||||
mapping = dict(file_mapping)
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if isinstance(transfer_method_value, FileTransferMethod):
|
||||
transfer_method = transfer_method_value
|
||||
elif isinstance(transfer_method_value, str):
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
else:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
|
||||
file_id = mapping.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
legacy_id = mapping.get("id")
|
||||
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
|
||||
|
||||
related_id = resolve_file_record_id(mapping)
|
||||
if related_id is None:
|
||||
raw_related_id = mapping.get("related_id")
|
||||
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
|
||||
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
remote_url = url if isinstance(url, str) and url else None
|
||||
|
||||
reference = mapping.get("reference")
|
||||
if not isinstance(reference, str) or not reference:
|
||||
reference = None
|
||||
|
||||
filename = mapping.get("filename")
|
||||
if not isinstance(filename, str):
|
||||
filename = None
|
||||
|
||||
extension = mapping.get("extension")
|
||||
if not isinstance(extension, str):
|
||||
extension = None
|
||||
|
||||
mime_type = mapping.get("mime_type")
|
||||
if not isinstance(mime_type, str):
|
||||
mime_type = None
|
||||
|
||||
size = mapping.get("size", -1)
|
||||
if not isinstance(size, int):
|
||||
size = -1
|
||||
|
||||
storage_key = mapping.get("storage_key")
|
||||
if not isinstance(storage_key, str):
|
||||
storage_key = None
|
||||
|
||||
tenant_id = mapping.get("tenant_id")
|
||||
if not isinstance(tenant_id, str):
|
||||
tenant_id = None
|
||||
|
||||
dify_model_identity = mapping.get("dify_model_identity")
|
||||
if not isinstance(dify_model_identity, str):
|
||||
dify_model_identity = FILE_MODEL_IDENTITY
|
||||
|
||||
tool_file_id = mapping.get("tool_file_id")
|
||||
if not isinstance(tool_file_id, str):
|
||||
tool_file_id = None
|
||||
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not isinstance(upload_file_id, str):
|
||||
upload_file_id = None
|
||||
|
||||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not isinstance(datasource_file_id, str):
|
||||
datasource_file_id = None
|
||||
|
||||
return File(
|
||||
file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=remote_url,
|
||||
reference=reference,
|
||||
related_id=related_id,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
storage_key=storage_key,
|
||||
dify_model_identity=dify_model_identity,
|
||||
url=remote_url,
|
||||
tool_file_id=tool_file_id,
|
||||
upload_file_id=upload_file_id,
|
||||
datasource_file_id=datasource_file_id,
|
||||
)
|
||||
|
||||
|
||||
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
|
||||
"""Recursively rebuild serialized graph file payloads into `File` objects.
|
||||
|
||||
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
|
||||
`model_validate_json()`. Dify keeps this recovery path at the model boundary
|
||||
so historical JSON blobs remain readable without reintroducing global graph
|
||||
patches or test-local coercion.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
if maybe_file_object(value):
|
||||
return build_file_from_mapping_without_lookup(file_mapping=value)
|
||||
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def build_file_from_stored_mapping(
|
||||
*,
|
||||
file_mapping: Mapping[str, Any],
|
||||
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
|
||||
pass
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
mapping["remote_url"] = url
|
||||
return File.model_validate(mapping)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=mapping)
|
||||
|
||||
return file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
||||
@@ -37,7 +37,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.variable_prefixes import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@@ -65,7 +65,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
from .utils.file_input_compat import build_file_from_stored_mapping
|
||||
from .utils.file_input_compat import (
|
||||
build_file_from_mapping_without_lookup,
|
||||
build_file_from_stored_mapping,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -339,7 +342,7 @@ class Workflow(Base): # bug
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
|
||||
@@ -1737,7 +1740,7 @@ class WorkflowDraftVariable(Base):
|
||||
return cast(Any, value)
|
||||
normalized_file = dict(value)
|
||||
normalized_file.pop("tenant_id", None)
|
||||
return File.model_validate(normalized_file)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
|
||||
elif isinstance(value, list) and value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
@@ -1747,7 +1750,7 @@ class WorkflowDraftVariable(Base):
|
||||
for item in value_list:
|
||||
normalized_file = dict(cast(dict[str, Any], item))
|
||||
normalized_file.pop("tenant_id", None)
|
||||
file_list.append(File.model_validate(normalized_file))
|
||||
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return cast(Any, value)
|
||||
|
||||
@@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API
|
||||
|
||||
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
|
||||
|
||||
## Excluding Providers
|
||||
|
||||
In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.
|
||||
|
||||
78
api/providers/trace/README.md
Normal file
78
api/providers/trace/README.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Trace providers
|
||||
|
||||
This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
|
||||
|
||||
Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
|
||||
|
||||
## Architecture
|
||||
|
||||
| Layer | Location | Role |
|
||||
|--------|----------|------|
|
||||
| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
|
||||
| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
|
||||
| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |
|
||||
|
||||
At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
|
||||
|
||||
## What you implement
|
||||
|
||||
### 1. Config model (`BaseTracingConfig`)
|
||||
|
||||
Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
|
||||
|
||||
Fields fall into two groups used by the manager:
|
||||
|
||||
- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
|
||||
- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
|
||||
|
||||
List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
|
||||
|
||||
### 2. Trace instance (`BaseTraceInstance`)
|
||||
|
||||
Subclass `BaseTraceInstance` and implement:
|
||||
|
||||
```python
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
...
|
||||
```
|
||||
|
||||
Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
|
||||
|
||||
- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
|
||||
- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
|
||||
- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
|
||||
|
||||
You may ignore categories your backend does not support; existing providers often no-op unhandled types.
|
||||
|
||||
Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
|
||||
|
||||
### 3. Register in the API core
|
||||
|
||||
Upstream changes are required so Dify knows your provider exists:
|
||||
|
||||
1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
|
||||
2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
|
||||
- `config_class`: your Pydantic config type
|
||||
- `secret_keys` / `other_keys`: lists of field names as above
|
||||
- `trace_instance`: your `BaseTraceInstance` subclass
|
||||
Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
|
||||
|
||||
If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
|
||||
|
||||
## Package layout
|
||||
|
||||
Each provider is a normal uv workspace member, for example:
|
||||
|
||||
- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
|
||||
- `api/providers/trace/trace-<name>/src/dify_trace_<name>/` — `config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
|
||||
|
||||
Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
|
||||
|
||||
## Wiring into the `api` workspace
|
||||
|
||||
In `api/pyproject.toml`:
|
||||
|
||||
1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
|
||||
2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
|
||||
|
||||
After changing metadata, run **`uv sync`** from `api/`.
|
||||
14
api/providers/trace/trace-aliyun/pyproject.toml
Normal file
14
api/providers/trace/trace-aliyun/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-trace-aliyun"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
# versions inherited from parent
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-exporter-otlp-proto-grpc",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-semantic-conventions",
|
||||
]
|
||||
description = "Dify ops tracing provider (Aliyun)."
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -6,7 +6,20 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
@@ -14,8 +27,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
DIFY_APP_ID,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
@@ -34,7 +47,7 @@ from core.ops.aliyun_trace.entities.semconv import (
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
from dify_trace_aliyun.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
@@ -46,19 +59,6 @@ from core.ops.aliyun_trace.utils import (
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.utils import validate_url_with_path
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# aliyun uses two URL formats, which may include a URL path
|
||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
|
||||
from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
@@ -6,7 +6,8 @@ from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
from core.rag.models.document import Document
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
@@ -15,7 +16,6 @@ from core.ops.aliyun_trace.entities.semconv import (
|
||||
OUTPUT_VALUE,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
|
||||
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
|
||||
|
||||
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
|
||||
from dify_trace_aliyun.data_exporter.traceclient import create_link
|
||||
|
||||
links = []
|
||||
if trace_id:
|
||||
@@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
from dify_trace_aliyun.data_exporter.traceclient import (
|
||||
INVALID_SPAN_ID,
|
||||
SpanBuilder,
|
||||
TraceClient,
|
||||
@@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
create_link,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -41,8 +40,8 @@ def trace_client_factory():
|
||||
|
||||
|
||||
class TestTraceClient:
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
|
||||
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
|
||||
mock_gethostname.return_value = "test-host"
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
@@ -56,7 +55,7 @@ class TestTraceClient:
|
||||
client.shutdown()
|
||||
assert client.done is True
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
@@ -64,8 +63,8 @@ class TestTraceClient:
|
||||
client.export(spans)
|
||||
mock_exporter.export.assert_called_once_with(spans)
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 405
|
||||
@@ -74,8 +73,8 @@ class TestTraceClient:
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.api_check() is True
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -84,8 +83,8 @@ class TestTraceClient:
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.api_check() is False
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
|
||||
mock_head.side_effect = httpx.RequestError("Connection error")
|
||||
|
||||
@@ -93,12 +92,12 @@ class TestTraceClient:
|
||||
with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
|
||||
client.api_check()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_add_span(self, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(
|
||||
service_name="test-service",
|
||||
@@ -134,8 +133,8 @@ class TestTraceClient:
|
||||
assert len(client.queue) == 2
|
||||
mock_notify.assert_called_once()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
|
||||
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
|
||||
|
||||
@@ -159,7 +158,7 @@ class TestTraceClient:
|
||||
assert len(client.queue) == 1
|
||||
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
mock_exporter.export.side_effect = Exception("Export failed")
|
||||
@@ -168,11 +167,11 @@ class TestTraceClient:
|
||||
mock_span = MagicMock(spec=ReadableSpan)
|
||||
client.queue.append(mock_span)
|
||||
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
|
||||
with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
|
||||
client._export_batch()
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
|
||||
# We need to test the wait timeout in _worker
|
||||
# But _worker runs in a thread. Let's mock condition.wait.
|
||||
@@ -189,7 +188,7 @@ class TestTraceClient:
|
||||
# mock_wait might have been called
|
||||
assert mock_wait.called or client.done
|
||||
|
||||
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
|
||||
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
|
||||
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
|
||||
mock_exporter = mock_exporter_class.return_value
|
||||
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
|
||||
@@ -268,7 +267,7 @@ def test_generate_span_id():
|
||||
assert span_id != INVALID_SPAN_ID
|
||||
|
||||
# Test retry loop
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
|
||||
with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand:
|
||||
mock_rand.side_effect = [INVALID_SPAN_ID, 999]
|
||||
span_id = generate_span_id()
|
||||
assert span_id == 999
|
||||
@@ -290,7 +289,7 @@ def test_convert_to_trace_id():
|
||||
def test_convert_string_to_id():
|
||||
assert convert_string_to_id("test") > 0
|
||||
# Test with None string
|
||||
with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
|
||||
with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen:
|
||||
mock_gen.return_value = 12345
|
||||
assert convert_string_to_id(None) == 12345
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import pytest
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
|
||||
|
||||
class TestTraceMetadata:
|
||||
def test_trace_metadata_init(self):
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
ACS_ARMS_SERVICE_FEATURE,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
@@ -4,12 +4,11 @@ from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
||||
import pytest
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
@@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import (
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@@ -1,9 +1,7 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from opentelemetry.trace import Link, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
@@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import (
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
from dify_trace_aliyun.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
@@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import (
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from opentelemetry.trace import Link, StatusCode
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
@@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = end_user_data
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
from dify_trace_aliyun.utils import db
|
||||
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
|
||||
@@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
from dify_trace_aliyun.utils import db
|
||||
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
|
||||
@@ -112,9 +112,9 @@ def test_get_workflow_node_status():
|
||||
def test_create_links_from_trace_id(monkeypatch):
|
||||
# Mock create_link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
import core.ops.aliyun_trace.data_exporter.traceclient
|
||||
import dify_trace_aliyun.data_exporter.traceclient
|
||||
|
||||
monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
|
||||
monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
|
||||
|
||||
# Trace ID None
|
||||
assert create_links_from_trace_id(None) == []
|
||||
@@ -0,0 +1,85 @@
|
||||
import pytest
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestAliyunConfig:
|
||||
"""Test cases for AliyunConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Aliyun configuration"""
|
||||
config = AliyunConfig(
|
||||
app_name="test_app",
|
||||
license_key="test_license_key",
|
||||
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
|
||||
)
|
||||
assert config.app_name == "test_app"
|
||||
assert config.license_key == "test_license_key"
|
||||
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="")
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation preserves path for Aliyun endpoints"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
|
||||
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
|
||||
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_license_key_required(self):
|
||||
"""Test that license_key is required and cannot be empty"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_valid_endpoint_format_examples(self):
|
||||
"""Test valid endpoint format examples from comments"""
|
||||
valid_endpoints = [
|
||||
# cms2.0 public endpoint
|
||||
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
|
||||
# cms2.0 intranet endpoint
|
||||
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
|
||||
# xtrace public endpoint
|
||||
"http://tracing-cn-heyuan.arms.aliyuncs.com",
|
||||
# xtrace intranet endpoint
|
||||
"http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
|
||||
]
|
||||
|
||||
for endpoint in valid_endpoints:
|
||||
config = AliyunConfig(license_key="test_license", endpoint=endpoint)
|
||||
assert config.endpoint == endpoint
|
||||
10
api/providers/trace/trace-arize-phoenix/pyproject.toml
Normal file
10
api/providers/trace/trace-arize-phoenix/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "dify-trace-arize-phoenix"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
]
|
||||
description = "Dify ops tracing provider (Arize / Phoenix)."
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -26,7 +26,6 @@ from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@@ -40,6 +39,7 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
@@ -0,0 +1,45 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.utils import validate_url_with_path
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://app.phoenix.arize.com")
|
||||
@@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import Tracer
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import (
|
||||
ArizePhoenixDataTrace,
|
||||
datetime_to_nanos,
|
||||
error_to_string,
|
||||
@@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
|
||||
setup_tracer,
|
||||
wrap_span_metadata,
|
||||
)
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
|
||||
from opentelemetry.sdk.trace import Tracer
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@@ -80,7 +80,7 @@ def test_datetime_to_nanos():
|
||||
expected = int(dt.timestamp() * 1_000_000_000)
|
||||
assert datetime_to_nanos(dt) == expected
|
||||
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
|
||||
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt:
|
||||
mock_now = MagicMock()
|
||||
mock_now.timestamp.return_value = 1704110400.0
|
||||
mock_dt.now.return_value = mock_now
|
||||
@@ -142,8 +142,8 @@ def test_wrap_span_metadata():
|
||||
assert res == {"a": 1, "b": 2, "created_from": "Dify"}
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_arize(mock_provider, mock_exporter):
|
||||
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
|
||||
setup_tracer(config)
|
||||
@@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter):
|
||||
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
|
||||
config = PhoenixConfig(endpoint="http://p.com", project="p")
|
||||
setup_tracer(config)
|
||||
@@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter):
|
||||
|
||||
def test_setup_tracer_exception():
|
||||
config = ArizeConfig(endpoint="http://a.com", project="p")
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
|
||||
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
|
||||
with pytest.raises(Exception, match="boom"):
|
||||
setup_tracer(config)
|
||||
|
||||
@@ -172,7 +172,7 @@ def test_setup_tracer_exception():
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance():
|
||||
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
|
||||
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup:
|
||||
mock_tracer = MagicMock(spec=Tracer)
|
||||
mock_processor = MagicMock()
|
||||
mock_setup.return_value = (mock_tracer, mock_processor)
|
||||
@@ -228,9 +228,9 @@ def test_trace_exception(trace_instance):
|
||||
trace_instance.trace(_make_workflow_info())
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
@@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac
|
||||
assert trace_instance.tracer.start_span.call_count >= 2
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_no_app_id(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
@@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance):
|
||||
trace_instance.workflow_trace(info)
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
|
||||
def test_message_trace_success(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
@@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance):
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
|
||||
def test_message_trace_with_error(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
@@ -1,6 +1,6 @@
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues
|
||||
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
|
||||
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
|
||||
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestArizeConfig:
|
||||
"""Test cases for ArizeConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Arize configuration"""
|
||||
config = ArizeConfig(
|
||||
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.space_id == "test_space"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = ArizeConfig()
|
||||
assert config.api_key is None
|
||||
assert config.space_id is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = ArizeConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_project_validation_none(self):
|
||||
"""Test project validation with None value"""
|
||||
config = ArizeConfig(project=None)
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = ArizeConfig(endpoint="")
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="ftp://invalid.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="invalid.com")
|
||||
|
||||
|
||||
class TestPhoenixConfig:
|
||||
"""Test cases for PhoenixConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Phoenix configuration"""
|
||||
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = PhoenixConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = PhoenixConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation with path"""
|
||||
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
|
||||
assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
|
||||
|
||||
def test_endpoint_validation_without_path(self):
|
||||
"""Test endpoint validation without path"""
|
||||
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
10
api/providers/trace/trace-langfuse/pyproject.toml
Normal file
10
api/providers/trace/trace-langfuse/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "dify-trace-langfuse"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"langfuse>=4.2.0,<5.0.0",
|
||||
]
|
||||
description = "Dify ops tracing provider (Langfuse)."
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -0,0 +1,19 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.utils import validate_url_with_path
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langfuse tracing config.
|
||||
"""
|
||||
|
||||
public_key: str
|
||||
secret_key: str
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://api.langfuse.com")
|
||||
@@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.entities.langfuse_trace_entity import (
|
||||
GenerationUsage,
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
@@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
@@ -5,8 +5,16 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.entities.langfuse_trace_entity import (
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
LangfuseTrace,
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
LangfuseTrace,
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser
|
||||
from models.enums import MessageStatus
|
||||
@@ -43,7 +43,7 @@ def langfuse_config():
|
||||
def trace_instance(langfuse_config, monkeypatch):
|
||||
# Mock Langfuse client to avoid network calls
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
|
||||
|
||||
instance = LangFuseDataTrace(langfuse_config)
|
||||
return instance
|
||||
@@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch):
|
||||
|
||||
def test_init(langfuse_config, monkeypatch):
|
||||
mock_langfuse = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
|
||||
instance = LangFuseDataTrace(langfuse_config)
|
||||
@@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
|
||||
# Mock DB and Repositories
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
# Mock node executions
|
||||
node_llm = MagicMock()
|
||||
@@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
@@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
error="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_execution.return_value = []
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
@@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
workflow_app_log_id="log-1",
|
||||
error="",
|
||||
)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
@@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
@@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
|
||||
repo.get_by_workflow_execution.return_value = [node]
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestLangfuseConfig:
|
||||
"""Test cases for LangfuseConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Langfuse configuration"""
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
assert config.host == "https://custom.langfuse.com"
|
||||
|
||||
def test_valid_config_with_path(self):
|
||||
host = "https://custom.langfuse.com/api/v1"
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
assert config.host == host
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
@@ -4,14 +4,15 @@ from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
|
||||
def _create_trace_instance() -> LangFuseDataTrace:
|
||||
with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True):
|
||||
with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True):
|
||||
return LangFuseDataTrace(
|
||||
LangfuseConfig(
|
||||
public_key="public-key",
|
||||
@@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime:
|
||||
patch.object(trace, "add_span"),
|
||||
patch.object(trace, "add_generation") as add_generation,
|
||||
patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()),
|
||||
patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()),
|
||||
patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()),
|
||||
patch(
|
||||
"core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
"dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
return_value=repository,
|
||||
),
|
||||
):
|
||||
10
api/providers/trace/trace-langsmith/pyproject.toml
Normal file
10
api/providers/trace/trace-langsmith/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "dify-trace-langsmith"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"langsmith~=0.7.30",
|
||||
]
|
||||
description = "Dify ops tracing provider (LangSmith)."
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -0,0 +1,20 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.utils import validate_url
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langsmith tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
project: str
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
@@ -10,7 +10,6 @@ from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
@@ -22,13 +21,14 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from dify_trace_langsmith.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -3,8 +3,14 @@ from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from dify_trace_langsmith.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
|
||||
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser
|
||||
|
||||
@@ -38,7 +38,7 @@ def langsmith_config():
|
||||
def trace_instance(langsmith_config, monkeypatch):
|
||||
# Mock LangSmith client
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
|
||||
|
||||
instance = LangSmithDataTrace(langsmith_config)
|
||||
return instance
|
||||
@@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch):
|
||||
|
||||
def test_init(langsmith_config, monkeypatch):
|
||||
mock_client_class = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
|
||||
instance = LangSmithDataTrace(langsmith_config)
|
||||
@@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch):
|
||||
|
||||
# Mock dependencies
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
# Mock node executions
|
||||
node_llm = MagicMock()
|
||||
@@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
@@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
repo = MagicMock()
|
||||
repo.get_by_workflow_execution.return_value = []
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
@@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_info.error = ""
|
||||
|
||||
mock_session = MagicMock()
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
@@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
|
||||
# Mock EndUser lookup
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
@@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestLangSmithConfig:
|
||||
"""Test cases for LangSmithConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid LangSmith configuration"""
|
||||
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.smith.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangSmithConfig(api_key="key", project="project")
|
||||
assert config.endpoint == "https://api.smith.langchain.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
|
||||
10
api/providers/trace/trace-mlflow/pyproject.toml
Normal file
10
api/providers/trace/trace-mlflow/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "dify-trace-mlflow"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"mlflow-skinny>=3.11.1",
|
||||
]
|
||||
description = "Dify ops tracing provider (MLflow / Databricks)."
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user