chore(api): adapt Graphon 0.2.2 upgrade (#35377)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
99
2026-04-18 19:16:24 +08:00
committed by GitHub
parent ae9c4244d6
commit 3e876e173a
134 changed files with 2154 additions and 1134 deletions

View File

@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
return value
try:

View File

@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,

View File

@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

View File

@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:

View File

@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)

View File

@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile

View File

@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)

View File

@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = runtime_state.exceptions_count
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,

View File

@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.CUSTOM,
file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),

View File

@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
@@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials

View File

@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded

View File

@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)

View File

@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
return HttpResponse(
status_code=response.status_code,
headers=dict(response.headers),
content=response.content,
url=str(response.url) if response.url else None,
reason_phrase=response.reason_phrase,
fallback_text=response.text,
)
class GraphonSSRFProxy:
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
ssrf_proxy = SSRFProxy()
graphon_ssrf_proxy = GraphonSSRFProxy()

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class ModelInstance:
@@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke

View File

@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

View File

@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):

View File

@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding

View File

@@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import inspect
import logging
import mimetypes
import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
_closed: bool
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def close(self) -> None:
"""Best-effort cleanup for downloaded temporary files."""
if getattr(self, "_closed", False):
return
self._closed = True
temp_file = getattr(self, "temp_file", None)
if temp_file is None:
return
try:
close_result = temp_file.close()
if inspect.isawaitable(close_result):
close_awaitable = getattr(close_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
def __del__(self):
if hasattr(self, "temp_file"):
self.temp_file.close()
self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""

View File

@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@@ -9,7 +9,7 @@ from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

View File

@@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,

View File

@@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
type=get_file_type_by_mime_type(tool_file.mimetype),
file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),

View File

@@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.value
if not isinstance(variable_selector, list) or not all(
isinstance(selector_part, str) for selector_part in variable_selector
):
raise ToolParameterError("Variable tool input must be a variable selector")
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value

View File

@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke

View File

@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
transfer_method_value = file_dict.get("transfer_method")
if not isinstance(transfer_method_value, str):
raise ValueError("Workflow file mapping is missing a valid transfer_method")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id

View File

@@ -1,8 +1,8 @@
"""Workflow-layer adapters for legacy human-input payload keys.
"""Workflow-to-Graphon adapters for persisted node payloads.
Stored workflow graphs and editor payloads may still use Dify-specific human
input recipient keys. Normalize them here before handing configs to
`graphon` so graph-owned models only see graph-neutral field names.
Stored workflow graphs and editor payloads still contain a small set of
Dify-owned field spellings and value shapes. Adapt them here before handing the
payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
normalized = normalize_human_input_node_data_for_graph(node_data)
normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
return normalized
return normalize_human_input_node_data_for_graph(normalized)
node_type = normalized.get("type")
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
return adapt_human_input_node_data_for_graph(normalized)
if node_type == BuiltinNodeTypes.TOOL:
return _adapt_tool_node_data_for_graph(normalized)
return normalized
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
normalized["data"] = normalize_node_data_for_graph(data_mapping)
normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(node_data)
raw_tool_configurations = normalized.get("tool_configurations")
if not isinstance(raw_tool_configurations, Mapping):
return normalized
existing_tool_parameters = normalized.get("tool_parameters")
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
normalized_tool_configurations: dict[str, Any] = {}
found_legacy_tool_inputs = False
for name, value in raw_tool_configurations.items():
if not isinstance(value, Mapping):
normalized_tool_configurations[name] = value
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
normalized_tool_configurations[name] = value
continue
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, dict(value))
flattened_value = _flatten_legacy_tool_configuration_value(
input_type=input_type,
input_value=input_value,
)
if flattened_value is not None:
normalized_tool_configurations[name] = flattened_value
if not found_legacy_tool_inputs:
return normalized
normalized["tool_parameters"] = normalized_tool_parameters
normalized["tool_configurations"] = normalized_tool_configurations
return normalized
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
return input_value
if (
input_type == "variable"
and isinstance(input_value, list)
and all(isinstance(item, str) for item in input_value)
):
return "{{#" + ".".join(input_value) + "#}}"
return None
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
"adapt_human_input_node_data_for_graph",
"adapt_node_config_for_graph",
"adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
"normalize_human_input_node_data_for_graph",
"normalize_node_config_for_graph",
"normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]

View File

@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
"""Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
# Graph configs are initially validated against permissive shared node data.
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
validate_node_data = getattr(node_class, "validate_node_data", None)
if callable(validate_node_data):
return cast("BaseNodeData", validate_node_data(node_data))
return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .human_input_compat import (
from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()

View File

@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
if value is None or isinstance(value, (str, float)):
return value
if isinstance(value, int) and not isinstance(value, bool):
return value
return str(value)
def _normalize_metadata_filter_sequence_item(value: object) -> str:
return value if isinstance(value, str) else str(value)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = segment_group.value[0].to_object()
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values = []
for v in value: # type: ignore
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(segment_group.value[0].to_object())
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values

View File

@@ -148,11 +148,11 @@ def _build_from_local_file(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
@@ -196,11 +196,11 @@ def _build_from_remote_url(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
@@ -222,9 +222,9 @@ def _build_from_remote_url(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=filename,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@@ -263,9 +263,9 @@ def _build_from_tool_file(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
@@ -306,9 +306,9 @@ def _build_from_datasource_file(
)
return File(
id=mapping.get("datasource_file_id"),
file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),

View File

@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return value_type.exposed_type().value
return str(value_type.exposed_type())

View File

@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.exposed_type().value,
"value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):

View File

@@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
from core.workflow.human_input_compat import DeliveryMethodType
from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string

View File

@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
from graphon.file import File, FileTransferMethod
from graphon.file import File, FileTransferMethod, FileType
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
"""Build a graph `File` directly from serialized metadata."""
def _coerce_file_type(value: Any) -> FileType:
if isinstance(value, FileType):
return value
if isinstance(value, str):
return FileType.value_of(value)
raise ValueError("file type is required in file mapping")
mapping = dict(file_mapping)
transfer_method_value = mapping.get("transfer_method")
if isinstance(transfer_method_value, FileTransferMethod):
transfer_method = transfer_method_value
elif isinstance(transfer_method_value, str):
transfer_method = FileTransferMethod.value_of(transfer_method_value)
else:
raise ValueError("transfer_method is required in file mapping")
file_id = mapping.get("file_id")
if not isinstance(file_id, str) or not file_id:
legacy_id = mapping.get("id")
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
related_id = resolve_file_record_id(mapping)
if related_id is None:
raw_related_id = mapping.get("related_id")
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
remote_url = url if isinstance(url, str) and url else None
reference = mapping.get("reference")
if not isinstance(reference, str) or not reference:
reference = None
filename = mapping.get("filename")
if not isinstance(filename, str):
filename = None
extension = mapping.get("extension")
if not isinstance(extension, str):
extension = None
mime_type = mapping.get("mime_type")
if not isinstance(mime_type, str):
mime_type = None
size = mapping.get("size", -1)
if not isinstance(size, int):
size = -1
storage_key = mapping.get("storage_key")
if not isinstance(storage_key, str):
storage_key = None
tenant_id = mapping.get("tenant_id")
if not isinstance(tenant_id, str):
tenant_id = None
dify_model_identity = mapping.get("dify_model_identity")
if not isinstance(dify_model_identity, str):
dify_model_identity = FILE_MODEL_IDENTITY
tool_file_id = mapping.get("tool_file_id")
if not isinstance(tool_file_id, str):
tool_file_id = None
upload_file_id = mapping.get("upload_file_id")
if not isinstance(upload_file_id, str):
upload_file_id = None
datasource_file_id = mapping.get("datasource_file_id")
if not isinstance(datasource_file_id, str):
datasource_file_id = None
return File(
file_id=file_id,
tenant_id=tenant_id,
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
transfer_method=transfer_method,
remote_url=remote_url,
reference=reference,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
storage_key=storage_key,
dify_model_identity=dify_model_identity,
url=remote_url,
tool_file_id=tool_file_id,
upload_file_id=upload_file_id,
datasource_file_id=datasource_file_id,
)
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
"""Recursively rebuild serialized graph file payloads into `File` objects.
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
`model_validate_json()`. Dify keeps this recovery path at the model boundary
so historical JSON blobs remain readable without reintroducing global graph
patches or test-local coercion.
"""
if isinstance(value, list):
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
if isinstance(value, dict):
if maybe_file_object(value):
return build_file_from_mapping_without_lookup(file_mapping=value)
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
return value
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
if isinstance(url, str) and url:
mapping["remote_url"] = url
return File.model_validate(mapping)
return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,

View File

@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import build_file_from_stored_mapping
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
build_file_from_stored_mapping,
)
logger = logging.getLogger(__name__)
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
return File.model_validate(normalized_file)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
file_list.append(File.model_validate(normalized_file))
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)

View File

@@ -1,7 +1,6 @@
"""
Tencent APM tracing implementation with separated concerns
"""
"""Tencent APM tracing with idempotent client cleanup."""
import inspect
import logging
from sqlalchemy import select
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
explicitly in tests or implicitly during garbage collection, so shutdown
must be safe to call multiple times.
"""
trace_client: TencentTraceClient
_closed: bool
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
def close(self) -> None:
"""Synchronously and idempotently shutdown the underlying trace client."""
if getattr(self, "_closed", False):
return
self._closed = True
trace_client = getattr(self, "trace_client", None)
if trace_client is None:
return
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
shutdown_result = trace_client.shutdown()
if inspect.isawaitable(shutdown_result):
close_awaitable = getattr(shutdown_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
def __del__(self):
"""Ensure best-effort cleanup on garbage collection without retrying shutdown."""
self.close()

View File

@@ -1,5 +1,7 @@
import gc
import logging
from unittest.mock import MagicMock, patch
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dify_trace_tencent.config import TencentConfig
@@ -632,13 +634,38 @@ class TestTencentDataTrace:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
def test_del(self, tencent_data_trace):
def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
tencent_data_trace.__del__()
tencent_data_trace.close()
client.shutdown.assert_called_once()
def test_del_exception(self, tencent_data_trace):
def test_close_is_idempotent(self, tencent_data_trace):
client = tencent_data_trace.trace_client
tencent_data_trace.close()
tencent_data_trace.close()
client.shutdown.assert_called_once()
def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.__del__()
tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
shutdown = AsyncMock()
tencent_data_trace.trace_client.shutdown = shutdown
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
tencent_data_trace.close()
gc.collect()
shutdown.assert_called_once()
assert not [
warning
for warning in caught
if issubclass(warning.category, RuntimeWarning)
and "AsyncMockMixin._execute_mock_call" in str(warning.message)
]

View File

@@ -44,7 +44,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.1.2",
"graphon~=0.2.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
]

View File

@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account

View File

@@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user

View File

@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,

View File

@@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
node_type_enum = NodeType(node_type)
node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config

View File

@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
json_str = dumps_with_segments(result.value, ensure_ascii=False)
json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)

View File

@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"

View File

@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
normalize_human_input_node_data_for_graph,
adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@@ -791,7 +791,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
node_type_enum = NodeType(node_type)
node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@@ -1096,7 +1096,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
normalize_human_input_node_data_for_graph(node_config["data"]),
adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@@ -1237,9 +1237,10 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
id=node_config["id"],
config=node_config,
node_id=node_config["id"],
config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@@ -1529,7 +1530,7 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")

View File

@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool

View File

@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
id="n",
config={
"id": "n",
"data": {
"type": "datasource",
"version": "1",
"title": "Datasource",
"provider_type": "plugin",
"provider_name": "p",
"plugin_id": "plug",
"datasource_name": "ds",
},
},
node_id="n",
config=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",
provider_type="plugin",
provider_name="p",
plugin_id="plug",
datasource_name="ds",
),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)

View File

@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
id=str(uuid.uuid4()),
config=code_config,
node_id=str(uuid.uuid4()),
config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,

View File

@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
id=str(uuid.uuid4()),
config=graph_config["nodes"][1],
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),

View File

@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,

View File

@@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
node_id="start",
config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
id="human",
config={"id": "human", "data": human_data.model_dump()},
node_id="human",
config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
node_id="end",
config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)

View File

@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
id=str(uuid4()), # Generate new UUID for File.id
file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,

View File

@@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
from core.workflow.human_input_compat import DeliveryMethodType
from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)

View File

@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
type=FileType.IMAGE,
file_id="test_file_id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
type=FileType.IMAGE,
file_id="test_file_id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",

View File

@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
from fields._value_type_serializer import serialize_value_type
from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
def test_variable_response_normalizes_string_value_type_alias(self):
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER.value,
}
)
assert response.value_type == "number"
def test_variable_response_normalizes_callable_exposed_type(self):
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
}
)
assert response.value_type == "string"
def test_serialize_value_type_supports_segments_and_mappings(self):
assert serialize_value_type(StringSegment(value="hello")) == "string"
assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{

View File

@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
id=file_id,
type=FileType.DOCUMENT,
file_id=file_id,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",

View File

@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
def init_node_data(self, data):
self._node_data = _StubToolNodeData.model_validate(data)
def __init__(
self,
node_id: str,
config: _StubToolNodeData,
*,
graph_init_params,
graph_runtime_state,
**_kwargs: Any,
) -> None:
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def _get_error_strategy(self):
return self._node_data.error_strategy
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
original_create_node = DifyNodeFactory.create_node
original_resolve_node_class = node_factory_module.resolve_workflow_node_class
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_data = typed_node_config["data"]
if node_data.type == BuiltinNodeTypes.TOOL:
return _StubToolNode(
id=str(typed_node_config["id"]),
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
return original_create_node(self, typed_node_config)
def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
if node_type == BuiltinNodeTypes.TOOL:
return _StubToolNode
return original_resolve_node_class(node_type=node_type, node_version=node_version)
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:

View File

@@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
id="file-id",
type=FileType.IMAGE,
file_id="file-id",
file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)

View File

@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
self.id = id
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state

View File

@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",

View File

@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
assert simple_provider.label.en_US == "OpenAI"
assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]

View File

@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
id="test-file",
file_id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
def test_file_model_validate_accepts_legacy_tenant_id():
data = {
"id": "test-file",
"tenant_id": "test-tenant-id",
"type": "image",
"transfer_method": "tool_file",
"related_id": "test-related-id",
"filename": "image.png",
"extension": ".png",
"mime_type": "image/png",
"size": 67,
"storage_key": "test-storage-key",
"url": "https://example.com/image.png",
# Extra legacy fields
"tool_file_id": "tool-file-123",
"upload_file_id": "upload-file-456",
"datasource_file_id": "datasource-file-789",
}
def test_file_constructor_accepts_legacy_tenant_id():
file = File(
file_id="test-file",
tenant_id="test-tenant-id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
tool_file_id="tool-file-123",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
file = File.model_validate(data)
assert file.related_id == "test-related-id"
assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()

View File

@@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
SSRFProxy,
_get_user_provided_host_header,
_to_graphon_http_response,
graphon_ssrf_proxy,
make_request,
max_retries_exceeded_error,
request_error,
)
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
response = httpx.Response(
201,
headers={"X-Test": "1"},
content=b"payload",
request=httpx.Request("GET", "https://example.com/resource"),
)
wrapped = _to_graphon_http_response(response)
assert wrapped.status_code == 201
assert wrapped.headers == {"x-test": "1", "content-length": "7"}
assert wrapped.content == b"payload"
assert wrapped.url == "https://example.com/resource"
assert wrapped.reason_phrase == "Created"
assert wrapped.text == "payload"
def test_ssrf_proxy_exposes_expected_error_types() -> None:
proxy = SSRFProxy()
assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
assert proxy.request_error is request_error
assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
assert graphon_ssrf_proxy.request_error is request_error
@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
response = httpx.Response(
200,
headers={"X-Test": "1"},
content=b"ok",
request=httpx.Request("GET", "https://example.com/resource"),
)
with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
wrapped = getattr(graphon_ssrf_proxy, method_name)(
"https://example.com/resource",
max_retries=3,
headers={"X-Test": "1"},
)
mock_method.assert_called_once_with(
url="https://example.com/resource",
max_retries=3,
headers={"X-Test": "1"},
)
assert wrapped.status_code == 200
assert wrapped.url == "https://example.com/resource"
assert wrapped.content == b"ok"

View File

@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

View File

@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
assert providers[0].label.en_US == "OpenAI"
assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:

View File

@@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",

View File

@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",

View File

@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation

View File

@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform

View File

@@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
import gc
import io
import os
import tempfile
import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
def test_del_closes_temp_file():
def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
WordExtractor.__del__(extractor)
extractor.close()
extractor.temp_file.close.assert_called_once()
def test_close_is_idempotent():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
extractor.close()
extractor.close()
extractor.temp_file.close.assert_called_once()
def test_close_handles_async_close_mock():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
extractor.temp_file.close = AsyncMock()
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
extractor.close()
gc.collect()
extractor.temp_file.close.assert_called_once()
assert not [
warning
for warning in caught
if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
]
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):

View File

@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",

View File

@@ -1,8 +1,9 @@
import dataclasses
from typing import Annotated
import orjson
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
type SegmentUnion = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
),
Discriminator(get_segment_discriminator),
]
def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
"""Test serialization/deserialization of all segment types"""
"""Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
loaded = _Segments.model_validate_json(json_str)
loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
"""Test serialization/deserialization of all variable types"""
"""Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
loaded = _Variables.model_validate_json(json_str)
loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)

View File

@@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,

View File

@@ -1,12 +1,13 @@
"""
Mock node factory for testing workflows with third-party service dependencies.
"""Mock node factory for third-party-service workflow tests.
This module provides a MockNodeFactory that automatically detects and mocks nodes
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
The factory follows the same config adaptation path as production
`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
node_id = typed_node_config["id"]
# Create mock node instance
mock_class = self._mock_node_types[node_type]
resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
return super().create_node(typed_node_config)
return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""

View File

@@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
id: str,
config: Mapping[str, Any],
node_id: str,
config: Any,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
node_id=start_config["id"],
config=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
id=human_a_config["id"],
config=human_a_config,
node_id=human_a_config["id"],
config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
id=human_b_config["id"],
config=human_b_config,
node_id=human_b_config["id"],
config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
id=end_config["id"],
config=end_config,
node_id=end_config["id"],
config=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -9,6 +9,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,20 +67,15 @@ def test_execute_answer():
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
config=AnswerNodeData(
title="123",
type="answer",
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
),
)
# Mock db.session.close()

View File

@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
id="n",
config={
"id": "n",
"data": {
"type": "datasource",
"version": "1",
"title": "Datasource",
"provider_type": "plugin",
"provider_name": "p",
"plugin_id": "plug",
"datasource_name": "ds",
},
},
node_id="n",
config=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",
provider_type="plugin",
provider_name="p",
plugin_id="plug",
datasource_name="ds",
),
graph_init_params=gp,
graph_runtime_state=gs,
)

View File

@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
assert default_config["retry_config"]["max_retries"] == 7
def test_get_default_config_with_malformed_http_request_config_raises_value_error():
with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
def test_get_default_config_with_malformed_http_request_config_raises_type_error():
with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
start_at=time.perf_counter(),
)
return HttpRequestNode(
id="http-node",
config=node_config,
node_id="http-node",
config=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@@ -1,4 +1,4 @@
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
from graphon.runtime import VariablePool

View File

@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryMethodType,
EmailDeliveryConfig,
EmailDeliveryMethod,
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.status_value = HumanInputFormStatus.SUBMITTED
def _build_human_input_node(
*,
node_id: str,
node_data: HumanInputNodeData | Mapping[str, Any],
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
runtime: DifyHumanInputNodeRuntime,
) -> HumanInputNode:
typed_node_data = (
node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
)
return HumanInputNode(
node_id=node_id,
config=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=runtime,
)
class TestDeliveryMethod:
"""Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
data[field_name] = value
with pytest.raises(ValidationError) as exc_info:
UserAction(**data)
UserAction.model_validate(data)
errors = exc_info.value.errors()
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
node = HumanInputNode(
id=config["id"],
config=config,
node = _build_human_input_node(
node_id=config["id"],
node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
node = HumanInputNode(
id=config["id"],
config=config,
node = _build_human_input_node(
node_id=config["id"],
node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
node = HumanInputNode(
id=config["id"],
config=config,
node = _build_human_input_node(
node_id=config["id"],
node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
node = HumanInputNode(
id=config["id"],
config=config,
node = _build_human_input_node(
node_id=config["id"],
node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
form_repository = InMemoryHumanInputFormRepository()
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
node = HumanInputNode(
id=config["id"],
config=config,
node = _build_human_input_node(
node_id=config["id"],
node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,

View File

@@ -11,6 +11,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -25,6 +26,28 @@ class _FakeFormRepository:
return self._form
def _create_human_input_node(
*,
config: dict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
repo: _FakeFormRepository,
) -> HumanInputNode:
node_data = (
config["data"]
if isinstance(config["data"], HumanInputNodeData)
else HumanInputNodeData.model_validate(config["data"])
)
return HumanInputNode(
node_id=config["id"],
config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
)
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
)
repo = _FakeFormRepository(fake_form)
return HumanInputNode(
id="node-1",
return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
repo=repo,
)
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
)
repo = _FakeFormRepository(fake_form)
return HumanInputNode(
id="node-1",
return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
repo=repo,
)

View File

@@ -5,6 +5,7 @@ import pytest
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
from graphon.nodes.iteration.entities import IterationNodeData
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
from graphon.nodes.iteration.iteration_node import IterationNode
from graphon.runtime import (
@@ -44,17 +45,14 @@ def _build_iteration_node(
) -> IterationNode:
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
id="iteration-node",
config={
"id": "iteration-node",
"data": {
"type": "iteration",
"title": "Iteration",
"iterator_selector": ["start", "items"],
"output_selector": ["iteration-node", "output"],
"start_node_id": start_node_id,
},
},
node_id="iteration-node",
config=IterationNodeData(
type="iteration",
title="Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration-node", "output"],
start_node_id=start_node_id,
),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@@ -93,6 +93,25 @@ def sample_chunks():
}
def _build_node(
*,
node_id: str,
node_data: KnowledgeIndexNodeData | dict[str, object],
graph_init_params,
graph_runtime_state,
) -> KnowledgeIndexNode:
return KnowledgeIndexNode(
node_id=node_id,
config=(
node_data
if isinstance(node_data, KnowledgeIndexNodeData)
else KnowledgeIndexNodeData.model_validate(node_data)
),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
class TestKnowledgeIndexNode:
"""
Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
}
# Act
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
"data": sample_node_data.model_dump(),
}
node = KnowledgeIndexNode(
id=node_id,
config=config,
node = _build_node(
node_id=node_id,
node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

Some files were not shown because too many files have changed in this diff Show More