diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index cead33d14f..9c8b095b9f 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
return value
try:
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index f6319573e0..e32ba5f66c 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
- "value_type": variable_file.value_type.exposed_type().value,
+ "value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
- "value_type": v.value_type.exposed_type().value,
+ "value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index c4353ca7b8..ca4b18cb5e 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 790602ef5d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index dbd7527fc6..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 09ddce327e..cae0eee0df 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index dfe6133cb6..e2e07ebaff 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index 68e5e5f0c8..3a6f9d575a 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
+from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
- return ssrf_proxy.get(url, follow_redirects=follow_redirects)
+ return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
index 87f005a250..d521304615 100644
--- a/api/core/app/workflow/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
- execution.exceptions_count = runtime_state.exceptions_count
+ execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index dc831e5cac..f0dcb13b62 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.CUSTOM,
+ file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 6bbf163c9d..38b87e2cd1 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
- if key in provider_credential_secret_variables:
+ if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
@@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
- if key in provider_credential_secret_variables:
+ if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index b96a9ce380..38864a1830 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
+ inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index dc37a36943..f169f247cf 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index e38592bb7b..91e92712b7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
+from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
+def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
+ """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
+ return HttpResponse(
+ status_code=response.status_code,
+ headers=dict(response.headers),
+ content=response.content,
+ url=str(response.url) if response.url else None,
+ reason_phrase=response.reason_phrase,
+ fallback_text=response.text,
+ )
+
+
+class GraphonSSRFProxy:
+ """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
+
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]:
+ return max_retries_exceeded_error
+
+ @property
+ def request_error(self) -> type[Exception]:
+ return request_error
+
+ def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
+
+ def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
+
+ def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
+
+ def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
+
+ def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
+
+ def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
+
+
ssrf_proxy = SSRFProxy()
+graphon_ssrf_proxy = GraphonSSRFProxy()
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index d8d8dfedd8..86d0e3baaa 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
-from typing import IO, Any, Literal, Optional, Union, cast, overload
+from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
+P = ParamSpec("P")
+R = TypeVar("R")
class ModelInstance:
@@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke_multimodal_rerank,
+ self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
- def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+ def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index e3fba4ef3a..4e66d58b5e 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
- provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
+ provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
- provider_schema.icon_small_dark.zh_Hans
+ provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
- else provider_schema.icon_small_dark.en_US
+ else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 8f1d51f08a..7c6280fe93 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 4926f44f16..a9995778f7 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 052fca930d..0330a43b28 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
+import inspect
import logging
import mimetypes
import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
+ _closed: bool
+
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
+ self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
+ def close(self) -> None:
+ """Best-effort cleanup for downloaded temporary files."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ temp_file = getattr(self, "temp_file", None)
+ if temp_file is None:
+ return
+
+ try:
+ close_result = temp_file.close()
+ if inspect.isawaitable(close_result):
+ close_awaitable = getattr(close_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
+ except Exception:
+ logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
+
def __del__(self):
- if hasattr(self, "temp_file"):
- self.temp_file.close()
+ self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index f8242efe31..7ffa9afafd 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 1453fe020b..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 2581c354dd..66b375dad1 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -9,7 +9,7 @@ from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index 02625e242f..740d727e26 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index b3424cd9a5..c87e8a3ae0 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
- type=get_file_type_by_mime_type(tool_file.mimetype),
+ file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f4588904d3..87cf6d7085 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ variable_selector = tool_input.value
+ if not isinstance(variable_selector, list) or not all(
+ isinstance(selector_part, str) for selector_part in variable_selector
+ ):
+ raise ToolParameterError("Variable tool input must be a variable selector")
+ variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 9e1d41cb39..a3623d4ecd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 52ab605963..cd8c6352b5 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
- transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+ transfer_method_value = file_dict.get("transfer_method")
+ if not isinstance(transfer_method_value, str):
+ raise ValueError("Workflow file mapping is missing a valid transfer_method")
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py
similarity index 74%
rename from api/core/workflow/human_input_compat.py
rename to api/core/workflow/human_input_adapter.py
index 75a0a0c202..4b765e6aea 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_adapter.py
@@ -1,8 +1,8 @@
-"""Workflow-layer adapters for legacy human-input payload keys.
+"""Workflow-to-Graphon adapters for persisted node payloads.
-Stored workflow graphs and editor payloads may still use Dify-specific human
-input recipient keys. Normalize them here before handing configs to
-`graphon` so graph-owned models only see graph-neutral field names.
+Stored workflow graphs and editor payloads still contain a small set of
+Dify-owned field spellings and value shapes. Adapt them here before handing the
+payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
-def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
- normalized = normalize_human_input_node_data_for_graph(node_data)
+ normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
-def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
- if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
- return normalized
- return normalize_human_input_node_data_for_graph(normalized)
+ node_type = normalized.get("type")
+ if node_type == BuiltinNodeTypes.HUMAN_INPUT:
+ return adapt_human_input_node_data_for_graph(normalized)
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _adapt_tool_node_data_for_graph(normalized)
+ return normalized
-def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
- normalized["data"] = normalize_node_data_for_graph(data_mapping)
+ normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
+def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
+ normalized = dict(node_data)
+
+ raw_tool_configurations = normalized.get("tool_configurations")
+ if not isinstance(raw_tool_configurations, Mapping):
+ return normalized
+
+ existing_tool_parameters = normalized.get("tool_parameters")
+ normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
+ normalized_tool_configurations: dict[str, Any] = {}
+ found_legacy_tool_inputs = False
+
+ for name, value in raw_tool_configurations.items():
+ if not isinstance(value, Mapping):
+ normalized_tool_configurations[name] = value
+ continue
+
+ input_type = value.get("type")
+ input_value = value.get("value")
+ if input_type not in {"mixed", "variable", "constant"}:
+ normalized_tool_configurations[name] = value
+ continue
+
+ found_legacy_tool_inputs = True
+ normalized_tool_parameters.setdefault(name, dict(value))
+
+ flattened_value = _flatten_legacy_tool_configuration_value(
+ input_type=input_type,
+ input_value=input_value,
+ )
+ if flattened_value is not None:
+ normalized_tool_configurations[name] = flattened_value
+
+ if not found_legacy_tool_inputs:
+ return normalized
+
+ normalized["tool_parameters"] = normalized_tool_parameters
+ normalized["tool_configurations"] = normalized_tool_configurations
+ return normalized
+
+
+def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
+ if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
+ return input_value
+
+ if (
+ input_type == "variable"
+ and isinstance(input_value, list)
+ and all(isinstance(item, str) for item in input_value)
+ ):
+ return "{{#" + ".".join(input_value) + "#}}"
+
+ return None
+
+
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
+ "adapt_human_input_node_data_for_graph",
+ "adapt_node_config_for_graph",
+ "adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
- "normalize_human_input_node_data_for_graph",
- "normalize_node_config_for_graph",
- "normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 351da3444f..de4eae1b22 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ """Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
- self._http_request_http_client = ssrf_proxy
+ self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
+ # Graph configs are initially validated against permissive shared node data.
+ # Re-validate using the resolved node class so workflow-local node schemas
+ # stay explicit and constructors receive the concrete typed payload.
+ resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
- return node_class.validate_node_data(node_data)
+ validate_node_data = getattr(node_class, "validate_node_data", None)
+ if callable(validate_node_data):
+ return cast("BaseNodeData", validate_node_data(node_data))
+ return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py
index 2e632e56f0..b8725853c4 100644
--- a/api/core/workflow/node_runtime.py
+++ b/api/core/workflow/node_runtime.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from .human_input_compat import (
+from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResult: ...
+
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunk, None, None]: ...
+
def invoke_llm(
self,
*,
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResultWithStructuredOutput: ...
+
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
def invoke_llm_with_structured_output(
self,
*,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 7b000101b0..68a24e86b1 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: AgentNodeData,
+ *,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
- *,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index e4f6b3b470..f3006c4242 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: DatasourceNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d5cab05dbe..9c1b7ab2c4 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeIndexNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
- super().__init__(id, config, graph_init_params, graph_runtime_state)
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 47ad14b499..25f73e446d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
+ if value is None or isinstance(value, (str, float)):
+ return value
+ if isinstance(value, int) and not isinstance(value, bool):
+ return value
+ return str(value)
+
+
+def _normalize_metadata_filter_sequence_item(value: object) -> str:
+ return value if isinstance(value, str) else str(value)
+
+
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeRetrievalNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
+ resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
- resolved_value = segment_group.value[0].to_object()
+ resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
- resolved_values = []
- for v in value: # type: ignore
+ resolved_values: list[str] = []
+ for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
- resolved_values.append(segment_group.value[0].to_object())
+ resolved_values.append(
+ _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
+ )
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py
index ce1fa441c2..1d2ad4d445 100644
--- a/api/factories/file_factory/builders.py
+++ b/api/factories/file_factory/builders.py
@@ -148,11 +148,11 @@ def _build_from_local_file(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
@@ -196,11 +196,11 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
@@ -222,9 +222,9 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=filename,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@@ -263,9 +263,9 @@ def _build_from_tool_file(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=tool_file.name,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
@@ -306,9 +306,9 @@ def _build_from_datasource_file(
)
return File(
- id=mapping.get("datasource_file_id"),
+ file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
- type=file_type,
+ file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
index b5acbbbcb4..d518114777 100644
--- a/api/fields/_value_type_serializer.py
+++ b/api/fields/_value_type_serializer.py
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
- return v.value_type.exposed_type().value
+ return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index cf4a71d545..e4219ba1ee 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f9b5e98936..6e947858ba 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.exposed_type().value,
+ "value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):
diff --git a/api/models/human_input.py b/api/models/human_input.py
index b4c7a634b6..7447d3efcb 100644
--- a/api/models/human_input.py
+++ b/api/models/human_input.py
@@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from core.workflow.human_input_compat import DeliveryMethodType
+from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string
diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py
index a2dc8f6157..77dcbd13d4 100644
--- a/api/models/utils/file_input_compat.py
+++ b/api/models/utils/file_input_compat.py
@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
-from graphon.file import File, FileTransferMethod
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
+def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
+ """Build a graph `File` directly from serialized metadata."""
+
+ def _coerce_file_type(value: Any) -> FileType:
+ if isinstance(value, FileType):
+ return value
+ if isinstance(value, str):
+ return FileType.value_of(value)
+ raise ValueError("file type is required in file mapping")
+
+ mapping = dict(file_mapping)
+ transfer_method_value = mapping.get("transfer_method")
+ if isinstance(transfer_method_value, FileTransferMethod):
+ transfer_method = transfer_method_value
+ elif isinstance(transfer_method_value, str):
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
+ else:
+ raise ValueError("transfer_method is required in file mapping")
+
+ file_id = mapping.get("file_id")
+ if not isinstance(file_id, str) or not file_id:
+ legacy_id = mapping.get("id")
+ file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
+
+ related_id = resolve_file_record_id(mapping)
+ if related_id is None:
+ raw_related_id = mapping.get("related_id")
+ related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
+
+ remote_url = mapping.get("remote_url")
+ if not isinstance(remote_url, str) or not remote_url:
+ url = mapping.get("url")
+ remote_url = url if isinstance(url, str) and url else None
+
+ reference = mapping.get("reference")
+ if not isinstance(reference, str) or not reference:
+ reference = None
+
+ filename = mapping.get("filename")
+ if not isinstance(filename, str):
+ filename = None
+
+ extension = mapping.get("extension")
+ if not isinstance(extension, str):
+ extension = None
+
+ mime_type = mapping.get("mime_type")
+ if not isinstance(mime_type, str):
+ mime_type = None
+
+ size = mapping.get("size", -1)
+ if not isinstance(size, int):
+ size = -1
+
+ storage_key = mapping.get("storage_key")
+ if not isinstance(storage_key, str):
+ storage_key = None
+
+ tenant_id = mapping.get("tenant_id")
+ if not isinstance(tenant_id, str):
+ tenant_id = None
+
+ dify_model_identity = mapping.get("dify_model_identity")
+ if not isinstance(dify_model_identity, str):
+ dify_model_identity = FILE_MODEL_IDENTITY
+
+ tool_file_id = mapping.get("tool_file_id")
+ if not isinstance(tool_file_id, str):
+ tool_file_id = None
+
+ upload_file_id = mapping.get("upload_file_id")
+ if not isinstance(upload_file_id, str):
+ upload_file_id = None
+
+ datasource_file_id = mapping.get("datasource_file_id")
+ if not isinstance(datasource_file_id, str):
+ datasource_file_id = None
+
+ return File(
+ file_id=file_id,
+ tenant_id=tenant_id,
+ file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
+ transfer_method=transfer_method,
+ remote_url=remote_url,
+ reference=reference,
+ related_id=related_id,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ storage_key=storage_key,
+ dify_model_identity=dify_model_identity,
+ url=remote_url,
+ tool_file_id=tool_file_id,
+ upload_file_id=upload_file_id,
+ datasource_file_id=datasource_file_id,
+ )
+
+
+def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
+ """Recursively rebuild serialized graph file payloads into `File` objects.
+
+ `graphon` 0.2.2 no longer accepts legacy serialized file mappings via
+ `model_validate_json()`. Dify keeps this recovery path at the model boundary
+ so historical JSON blobs remain readable without reintroducing global graph
+ patches or test-local coercion.
+ """
+ if isinstance(value, list):
+ return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
+
+ if isinstance(value, dict):
+ if maybe_file_object(value):
+ return build_file_from_mapping_without_lookup(file_mapping=value)
+ return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
+
+ return value
+
+
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
- remote_url = mapping.get("remote_url")
- if not isinstance(remote_url, str) or not remote_url:
- url = mapping.get("url")
- if isinstance(url, str) and url:
- mapping["remote_url"] = url
- return File.model_validate(mapping)
+ return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,
diff --git a/api/models/workflow.py b/api/models/workflow.py
index dfda03c2ee..d127244b0f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
-from .utils.file_input_compat import build_file_from_stored_mapping
+from .utils.file_input_compat import (
+ build_file_from_mapping_without_lookup,
+ build_file_from_stored_mapping,
+)
logger = logging.getLogger(__name__)
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
- return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
- return File.model_validate(normalized_file)
+ return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
- file_list.append(File.model_validate(normalized_file))
+ file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index cfcf6b307e..a8c480e4a5 100644
--- a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,7 +1,6 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
+import inspect
import logging
from sqlalchemy import select
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
+
+ The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+ explicitly in tests or implicitly during garbage collection, so shutdown
+ must be safe to call multiple times.
"""
+ trace_client: TencentTraceClient
+ _closed: bool
+
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
+ self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
- def __del__(self):
- """Ensure proper cleanup on garbage collection."""
+ def close(self) -> None:
+ """Synchronously and idempotently shutdown the underlying trace client."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ trace_client = getattr(self, "trace_client", None)
+ if trace_client is None:
+ return
+
try:
- if hasattr(self, "trace_client"):
- self.trace_client.shutdown()
+ shutdown_result = trace_client.shutdown()
+ if inspect.isawaitable(shutdown_result):
+ close_awaitable = getattr(shutdown_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def __del__(self):
+ """Ensure best-effort cleanup on garbage collection without retrying shutdown."""
+ self.close()
diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
index a91a0aa558..54524b09ca 100644
--- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
@@ -1,5 +1,7 @@
+import gc
import logging
-from unittest.mock import MagicMock, patch
+import warnings
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dify_trace_tencent.config import TencentConfig
@@ -632,13 +634,38 @@ class TestTencentDataTrace:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
- def test_del(self, tencent_data_trace):
+ def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
client.shutdown.assert_called_once()
- def test_del_exception(self, tencent_data_trace):
+ def test_close_is_idempotent(self, tencent_data_trace):
+ client = tencent_data_trace.trace_client
+
+ tencent_data_trace.close()
+ tencent_data_trace.close()
+
+ client.shutdown.assert_called_once()
+
+ def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
+ shutdown = AsyncMock()
+ tencent_data_trace.trace_client.shutdown = shutdown
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ tencent_data_trace.close()
+ gc.collect()
+
+ shutdown.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning)
+ and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 12b8b3d782..8f6ee796ab 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
- "graphon~=0.1.2",
+ "graphon~=0.2.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
]
diff --git a/api/services/app_service.py b/api/services/app_service.py
index afd98e2975..038c59633a 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e6f5f80a6d..894cb05687 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py
index 68ef67dec1..8b4983e5f7 100644
--- a/api/services/human_input_delivery_test_service.py
+++ b/api/services/human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 968600d1bc..9db6682e10 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index c96050ce13..1529c2b98f 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
- json_str = dumps_with_segments(result.value, ensure_ascii=False)
+ json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 5ec00ee336..96f936ff9b 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
- original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
+ original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index d71223314e..d4b9095ce5 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
- normalize_human_input_node_data_for_graph,
+ adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@@ -791,7 +791,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@@ -1096,7 +1096,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
- normalize_human_input_node_data_for_graph(node_config["data"]),
+ adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@@ -1237,9 +1237,10 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
+ node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
- id=node_config["id"],
- config=node_config,
+ node_id=node_config["id"],
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@@ -1529,7 +1530,7 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
- HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
+ HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py
index f8ae3f4b6e..2a60be7762 100644
--- a/api/tasks/mail_human_input_delivery_task.py
+++ b/api/tasks/mail_human_input_delivery_task.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool
diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
index b5318aaa2b..2392084c36 100644
--- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
+++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e3476c292b..aaa6092993 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
+from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
- id=str(uuid.uuid4()),
- config=code_config,
+ node_id=str(uuid.uuid4()),
+ config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index aa6cf1e021..b9f7b9575b 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
-from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=graph_config["nodes"][1],
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index fa5d63cfbf..3eead70163 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 52886855b8..f2eabb86c3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9e3e1a47e3..e2e0723fb8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index f9ec51ee10..a8e9422c1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 14d5740072..6524d6ce61 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index da4f8847d6..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
- id="start",
- config={"id": "start", "data": start_data.model_dump()},
+ node_id="start",
+ config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
- id="human",
- config={"id": "human", "data": human_data.model_dump()},
+ node_id="human",
+ config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
- id="end",
- config={"id": "end", "data": end_data.model_dump()},
+ node_id="end",
+ config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index 2e207ddc67..35e41035df 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
- id=str(uuid4()), # Generate new UUID for File.id
+ file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index aaf9a85d60..54b7afc018 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
- from core.workflow.human_input_compat import DeliveryMethodType
+ from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index 18c5320d0a..80f9083e81 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 21a54e909e..ed75363f3b 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index 328bdbf055..95a867dbb5 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 6ff3b19362..e91c0a0597 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index b19a1740eb..22b80b748e 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
index 14c35a9ed5..4fb8ecf784 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
+from fields._value_type_serializer import serialize_value_type
+from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
+ def test_variable_response_normalizes_string_value_type_alias(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SegmentType.INTEGER.value,
+ }
+ )
+
+ assert response.value_type == "number"
+
+ def test_variable_response_normalizes_callable_exposed_type(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
+ }
+ )
+
+ assert response.value_type == "string"
+
+ def test_serialize_value_type_supports_segments_and_mappings(self):
+ assert serialize_value_type(StringSegment(value="hello")) == "string"
+ assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
+
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
index 3ab63aed25..dd6cd0e919 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
- id=file_id,
- type=FileType.DOCUMENT,
+ file_id=file_id,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",
diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
index a04a7b7576..6104b8d6ca 100644
--- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py
+++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
- def init_node_data(self, data):
- self._node_data = _StubToolNodeData.model_validate(data)
+ def __init__(
+ self,
+ node_id: str,
+ config: _StubToolNodeData,
+ *,
+ graph_init_params,
+ graph_runtime_state,
+ **_kwargs: Any,
+ ) -> None:
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
def _get_error_strategy(self):
return self._node_data.error_strategy
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
- original_create_node = DifyNodeFactory.create_node
+ original_resolve_node_class = node_factory_module.resolve_workflow_node_class
- def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
- node_data = typed_node_config["data"]
- if node_data.type == BuiltinNodeTypes.TOOL:
- return _StubToolNode(
- id=str(typed_node_config["id"]),
- config=typed_node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
- return original_create_node(self, typed_node_config)
+ def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _StubToolNode
+ return original_resolve_node_class(node_type=node_type, node_version=node_version)
- mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
+ mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index cddd03f4b0..701863b927 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
- id="file-id",
- type=FileType.IMAGE,
+ file_id="file-id",
+ file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
- with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
+ with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index c4bfb23272..30a068f4c5 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
- def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
- self.id = id
+ def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
+ self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index 81315d2508..deeac49bbc 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",
diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
index a0b2820157..aeca2e3afd 100644
--- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py
+++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
- assert simple_provider.label.en_US == "OpenAI"
+ assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py
index bb6e40e224..8cb0938575 100644
--- a/api/tests/unit_tests/core/file/test_models.py
+++ b/api/tests/unit_tests/core/file/test_models.py
@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
- id="test-file",
+ file_id="test-file",
tenant_id="test-tenant-id",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
-def test_file_model_validate_accepts_legacy_tenant_id():
- data = {
- "id": "test-file",
- "tenant_id": "test-tenant-id",
- "type": "image",
- "transfer_method": "tool_file",
- "related_id": "test-related-id",
- "filename": "image.png",
- "extension": ".png",
- "mime_type": "image/png",
- "size": 67,
- "storage_key": "test-storage-key",
- "url": "https://example.com/image.png",
- # Extra legacy fields
- "tool_file_id": "tool-file-123",
- "upload_file_id": "upload-file-456",
- "datasource_file_id": "datasource-file-789",
- }
+def test_file_constructor_accepts_legacy_tenant_id():
+ file = File(
+ file_id="test-file",
+ tenant_id="test-tenant-id",
+ file_type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ tool_file_id="tool-file-123",
+ filename="image.png",
+ extension=".png",
+ mime_type="image/png",
+ size=67,
+ storage_key="test-storage-key",
+ url="https://example.com/image.png",
+ )
- file = File.model_validate(data)
-
- assert file.related_id == "test-related-id"
+ assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()
diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
index 3b5c5e6597..d9fed9ae2a 100644
--- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
+++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
@@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
+import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
+ SSRFProxy,
_get_user_provided_host_header,
+ _to_graphon_http_response,
+ graphon_ssrf_proxy,
make_request,
+ max_retries_exceeded_error,
+ request_error,
)
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
+
+
+def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
+ response = httpx.Response(
+ 201,
+ headers={"X-Test": "1"},
+ content=b"payload",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ wrapped = _to_graphon_http_response(response)
+
+ assert wrapped.status_code == 201
+ assert wrapped.headers == {"x-test": "1", "content-length": "7"}
+ assert wrapped.content == b"payload"
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.reason_phrase == "Created"
+ assert wrapped.text == "payload"
+
+
+def test_ssrf_proxy_exposes_expected_error_types() -> None:
+ proxy = SSRFProxy()
+
+ assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert proxy.request_error is request_error
+ assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert graphon_ssrf_proxy.request_error is request_error
+
+
+@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
+def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
+ response = httpx.Response(
+ 200,
+ headers={"X-Test": "1"},
+ content=b"ok",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
+ wrapped = getattr(graphon_ssrf_proxy, method_name)(
+ "https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+
+ mock_method.assert_called_once_with(
+ url="https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+ assert wrapped.status_code == 200
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.content == b"ok"
diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
index 249ecb5006..c4fd970562 100644
--- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
+++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
index 68aa130518..88bf555594 100644
--- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
+++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
- assert providers[0].label.en_US == "OpenAI"
+ assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
index d49b6e4b71..00a4207786 100644
--- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
+++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
@@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index 395d392127..e536c0831f 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
index 803afa54d7..28966242d8 100644
--- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
index 9f9ea33695..5308c8e7b3 100644
--- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
-# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
index 64eb89590a..0220fb6d4a 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
@@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
+import gc
import io
import os
import tempfile
+import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
-def test_del_closes_temp_file():
+def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
+ extractor._closed = False
extractor.temp_file = MagicMock()
- WordExtractor.__del__(extractor)
+ extractor.close()
extractor.temp_file.close.assert_called_once()
+def test_close_is_idempotent():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+
+ extractor.close()
+ extractor.close()
+
+ extractor.temp_file.close.assert_called_once()
+
+
+def test_close_handles_async_close_mock():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+ extractor.temp_file.close = AsyncMock()
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ extractor.close()
+ gc.collect()
+
+ extractor.temp_file.close.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
+
+
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
index 8be1ac318c..18ae9fafc8 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
index 1297a95df1..4248782d93 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py
index f17927f16b..eab0176f41 100644
--- a/api/tests/unit_tests/core/test_file.py
+++ b/api/tests/unit_tests/core/test_file.py
@@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 72052c8c05..9e07ea1b6d 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,8 +1,9 @@
import dataclasses
+from typing import Annotated
import orjson
import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
+ ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
+ BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
- SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
+from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
+
+type SegmentUnion = Annotated[
+ (
+ Annotated[NoneSegment, Tag(SegmentType.NONE)]
+ | Annotated[StringSegment, Tag(SegmentType.STRING)]
+ | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+ | Annotated[FileSegment, Tag(SegmentType.FILE)]
+ | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
+ | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+ | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
- """Test serialization/deserialization of all segment types"""
+ """Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
- loaded = _Segments.model_validate_json(json_str)
+ loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
- """Test serialization/deserialization of all variable types"""
+ """Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
- loaded = _Variables.model_validate_json(json_str)
+ loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 94e788edb2..317fe99d37 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 76b2984a4b..9f3e3b00b9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -1,12 +1,13 @@
-"""
-Mock node factory for testing workflows with third-party service dependencies.
+"""Mock node factory for third-party-service workflow tests.
-This module provides a MockNodeFactory that automatically detects and mocks nodes
-requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
+The factory follows the same config adaptation path as production
+`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
+implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
+ node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
- node_id = typed_node_config["id"]
-
# Create mock node instance
mock_class = self._mock_node_types[node_type]
+ resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
- return super().create_node(typed_node_config)
+ return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 971b9b2bbf..f9819c47ec 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
- id: str,
- config: Mapping[str, Any],
+ node_id: str,
+ config: Any,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
- ):
+ ) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 55a329eba9..75bc6d05f7 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
- id=start_config["id"],
- config=start_config,
+ node_id=start_config["id"],
+ config=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
- id=human_a_config["id"],
- config=human_a_config,
+ node_id=human_a_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
- id=human_b_config["id"],
- config=human_b_config,
+ node_id=human_b_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
- id=end_config["id"],
- config=end_config,
+ node_id=end_config["id"],
+ config=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 9c0ad25b58..76b4cd1ef4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -9,6 +9,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,20 +67,15 @@ def test_execute_answer():
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- }
-
node = AnswerNode(
- id=str(uuid.uuid4()),
+ node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
+ config=AnswerNodeData(
+ title="123",
+ type="answer",
+ answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ ),
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
index 9cceadde49..d7ef781732 100644
--- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=gp,
graph_runtime_state=gs,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index a3cadc0681..2e89a2da3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
+from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
assert default_config["retry_config"]["max_retries"] == 7
-def test_get_default_config_with_malformed_http_request_config_raises_value_error():
- with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
+def test_get_default_config_with_malformed_http_request_config_raises_type_error():
+ with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
start_at=time.perf_counter(),
)
return HttpRequestNode(
- id="http-node",
- config=node_config,
+ node_id="http-node",
+ config=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index 1d6a4da7c4..07430498e5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,4 +1,4 @@
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
from graphon.runtime import VariablePool
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index c0e21d0bf7..0659984c76 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryMethodType,
EmailDeliveryConfig,
EmailDeliveryMethod,
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.status_value = HumanInputFormStatus.SUBMITTED
+def _build_human_input_node(
+ *,
+ node_id: str,
+ node_data: HumanInputNodeData | Mapping[str, Any],
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ runtime: DifyHumanInputNodeRuntime,
+) -> HumanInputNode:
+ typed_node_data = (
+ node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
+ )
+ return HumanInputNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ runtime=runtime,
+ )
+
+
class TestDeliveryMethod:
"""Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
data[field_name] = value
with pytest.raises(ValidationError) as exc_info:
- UserAction(**data)
+ UserAction.model_validate(data)
errors = exc_info.value.errors()
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
form_repository = InMemoryHumanInputFormRepository()
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index bc98028d5b..4a9438b14f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -11,6 +11,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
+from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -25,6 +26,28 @@ class _FakeFormRepository:
return self._form
+def _create_human_input_node(
+ *,
+ config: dict,
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ repo: _FakeFormRepository,
+) -> HumanInputNode:
+ node_data = (
+ config["data"]
+ if isinstance(config["data"], HumanInputNodeData)
+ else HumanInputNodeData.model_validate(config["data"])
+ )
+ return HumanInputNode(
+ node_id=config["id"],
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ form_repository=repo,
+ runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ )
+
+
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index 82cc734274..8ffce39cd6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -5,6 +5,7 @@ import pytest
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
+from graphon.nodes.iteration.entities import IterationNodeData
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
from graphon.nodes.iteration.iteration_node import IterationNode
from graphon.runtime import (
@@ -44,17 +45,14 @@ def _build_iteration_node(
) -> IterationNode:
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
- id="iteration-node",
- config={
- "id": "iteration-node",
- "data": {
- "type": "iteration",
- "title": "Iteration",
- "iterator_selector": ["start", "items"],
- "output_selector": ["iteration-node", "output"],
- "start_node_id": start_node_id,
- },
- },
+ node_id="iteration-node",
+ config=IterationNodeData(
+ type="iteration",
+ title="Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["iteration-node", "output"],
+ start_node_id=start_node_id,
+ ),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index a6fca1bfb4..f254fc3d09 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -93,6 +93,25 @@ def sample_chunks():
}
+def _build_node(
+ *,
+ node_id: str,
+ node_data: KnowledgeIndexNodeData | dict[str, object],
+ graph_init_params,
+ graph_runtime_state,
+) -> KnowledgeIndexNode:
+ return KnowledgeIndexNode(
+ node_id=node_id,
+ config=(
+ node_data
+ if isinstance(node_data, KnowledgeIndexNodeData)
+ else KnowledgeIndexNodeData.model_validate(node_data)
+ ),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+
class TestKnowledgeIndexNode:
"""
Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
}
# Act
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index 45e8ae7d20..e923ee761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import (
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import (
+ KnowledgeRetrievalNode,
+ _normalize_metadata_filter_scalar,
+ _normalize_metadata_filter_sequence_item,
+)
from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
@@ -85,6 +89,12 @@ def sample_node_data():
)
+def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None:
+ assert _normalize_metadata_filter_scalar(3) == 3
+ assert _normalize_metadata_filter_scalar(True) == "True"
+ assert _normalize_metadata_filter_sequence_item(4) == "4"
+
+
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index eca34f05be..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,3 +1,4 @@
+from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@@ -5,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from graphon.entities import GraphInitParams
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
from graphon.nodes.list_operator.node import ListOperatorNode
from graphon.runtime import GraphRuntimeState
from graphon.variables import ArrayNumberSegment, ArrayStringSegment
@@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment
class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
+ @staticmethod
+ def _build_node(*, config, graph_init_params, graph_runtime_state):
+ return ListOperatorNode(
+ node_id="test",
+ config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ @staticmethod
+ def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+ return {
+ "enabled": True,
+ "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+ }
+
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock GraphRuntimeState."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
+ mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
mock_state.variable_pool = mock_variable_pool
return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
- return ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ return self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -64,9 +82,8 @@ class TestListOperatorNode:
"limit": {"enabled": False},
}
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -109,9 +126,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=[])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -128,11 +144,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -140,9 +152,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -157,11 +168,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "not contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("not contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -169,9 +176,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -186,11 +192,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "5",
- },
+ "filter_by": self._filter_by(">", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -198,9 +200,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -226,9 +227,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -254,9 +254,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -282,9 +281,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -299,11 +297,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "3",
- },
+ "filter_by": self._filter_by(">", "3"),
"order_by": {
"enabled": True,
"value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -341,9 +334,8 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -366,9 +358,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["first", "middle", "last"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -384,11 +375,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "start with",
- "value": "app",
- },
+ "filter_by": self._filter_by("start with", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -396,9 +383,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -413,11 +399,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "end with",
- "value": "le",
- },
+ "filter_by": self._filter_by("end with", "le"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -425,9 +407,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -442,11 +423,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "=",
- "value": "5",
- },
+ "filter_by": self._filter_by("=", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -454,9 +431,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -471,11 +447,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "≠",
- "value": "5",
- },
+ "filter_by": self._filter_by("≠", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -483,9 +455,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -511,9 +482,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index 4186bbdc93..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -71,8 +71,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
+ model_schema = mock.MagicMock()
+ model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
prompt_template = [
LLMNodeChatModelMessage(
text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
with (
mock.patch(
"graphon.nodes.llm.llm_utils.fetch_model_schema",
- return_value=mock.MagicMock(features=[]),
+ return_value=model_schema,
),
mock.patch(
"graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index b1f81b6c48..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -140,8 +140,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():
def test_fetch_files_with_file_segment():
file = File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
@@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment():
def test_fetch_files_with_array_file_segment():
files = [
File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
- id="2",
- type=FileType.IMAGE,
+ file_id="2",
+ file_type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/jpg",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
file_saver=file_saver,
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
reasoning_format="separated",
)
)
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=_build_prepared_llm_mock(),
reasoning_format="separated",
request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=model_instance,
)
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index bc44ececd8..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_template_transform_node(
+ *,
+ node_data,
+ graph_init_params,
+ graph_runtime_state,
+ node_id: str = "test_node",
+ **kwargs,
+) -> TemplateTransformNode:
+ typed_node_data = (
+ node_data
+ if isinstance(node_data, TemplateTransformNodeData)
+ else TemplateTransformNodeData.model_validate(node_data)
+ )
+ return TemplateTransformNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ **kwargs,
+ )
+
+
class TestTemplateTransformNode:
"""Comprehensive test suite for TemplateTransformNode."""
@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
}
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
- TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -198,9 +214,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Value: "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -218,9 +233,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -238,9 +252,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -260,9 +273,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "1234567890"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -302,9 +314,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -375,8 +386,8 @@ class TestTemplateTransformNode:
)
assert mapping == {
- "node_123.var1": ["sys", "input1"],
- "node_123.empty_selector": [],
+ "node_123.var1": ("sys", "input1"),
+ "node_123.empty_selector": (),
}
def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self):
@@ -409,9 +420,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a static message."
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -448,9 +458,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Total: $31.5"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -477,9 +486,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -507,9 +515,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index 636237e56e..a846efbb43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -4,6 +4,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from graphon.nodes.base.entities import VariableSelector
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import (
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
TemplateTransformNode,
@@ -37,15 +38,13 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
- id="test_node",
- config={
- "id": "test_node",
- "data": {
- "title": "Template Transform",
- "variables": [],
- "template": "hello",
- },
- },
+ node_id="test_node",
+ config=TemplateTransformNodeData(
+ title="Template Transform",
+ type="template-transform",
+ variables=[],
+ template="hello",
+ ),
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=MagicMock(),
@@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri
assert mapping == {
"node_123.validated": ["sys", "input1"],
- "node_123.raw": ["sys", "input2"],
+ "node_123.raw": ("sys", "input2"),
}
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 0522dd9d14..364408ead6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- },
- }
- )
+def _build_node_config() -> dict[str, object]:
+ return {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ ),
+ }
+
+
+def _build_node_data() -> _SampleNodeData:
+ return _build_node_config()["data"] # type: ignore[return-value]
def test_node_hydrates_data_during_initialization():
@@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization():
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum():
)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error():
def test_base_node_data_keeps_dict_style_access_compatibility():
- node_data = _SampleNodeData.model_validate(
- {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- }
- )
+ node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar")
assert node_data["foo"] == "bar"
assert node_data.get("foo") == "bar"
@@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility():
def test_node_hydration_preserves_compatibility_extra_fields():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
- node_config = NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- "compat_flag": True,
- },
- }
- )
+ node_config = {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ compat_flag=True,
+ ),
+ }
node = _SampleNode(
- id="node-1",
- config=node_config,
+ node_id="node-1",
+ config=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 87ec2d5bce..dd75b32593 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
+from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError
from graphon.nodes.document_extractor.node import (
_extract_text_from_docx,
_extract_text_from_excel,
+ _extract_text_from_file,
_extract_text_from_pdf,
_extract_text_from_plain_text,
_normalize_docx_zip,
)
-from graphon.variables import ArrayFileSegment
+from graphon.variables import ArrayFileSegment, FileSegment
from graphon.variables.segments import ArrayStringSegment
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params
@@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params):
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- node_config = {"id": "test_node_id", "data": node_data.model_dump()}
http_client = Mock()
node = DocumentExtractorNode(
- id="test_node_id",
- config=node_config,
+ node_id="test_node_id",
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
- mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
+ mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_mixed_content"
@@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
- mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")]
+ mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_all_bad_sheets"
@@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
assert mock_excel_instance.parse.call_count == 2
+@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook"))
+def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file):
+ with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"):
+ _extract_text_from_excel(b"broken")
+
+
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
"""Test extracting text from Excel file with numeric column names."""
@@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
assert expected_manual == result
+@pytest.mark.parametrize(
+ ("extension", "mime_type"),
+ [
+ (".xlsx", "text/plain"),
+ (None, "application/vnd.ms-excel"),
+ ],
+)
+def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type):
+ file = Mock(spec=File)
+ file.extension = extension
+ file.mime_type = mime_type
+
+ with (
+ patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"excel",
+ ),
+ patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_excel",
+ return_value="excel text",
+ ) as mock_extract,
+ ):
+ result = _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+ assert result == "excel text"
+ mock_extract.assert_called_once_with(b"excel")
+
+
+def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
+ file = Mock(spec=File)
+ file.extension = None
+ file.mime_type = None
+
+ with patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"unknown",
+ ):
+ with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"):
+ _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+
+def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_list = Mock(spec=ArrayFileSegment)
+ file_list.value = [Mock(spec=File)]
+ mock_graph_runtime_state.variable_pool.get.return_value = file_list
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("bad file"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "bad file"
+
+
+def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("single file failed"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "single file failed"
+
+
+def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ return_value="single file text",
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs == {"text": "single file text"}
+
+
def _make_docx_zip(use_backslash: bool) -> bytes:
"""Helper to build a minimal in-memory DOCX zip.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 782750e02e..aa9a1360b0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_if_else_node(
+ *,
+ node_data: IfElseNodeData | dict[str, object],
+ init_params,
+ graph_runtime_state,
+) -> IfElseNode:
+ return IfElseNode(
+ node_id=str(uuid.uuid4()),
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
+ )
+
+
def test_execute_if_else_result_true():
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
@@ -61,9 +75,8 @@ def test_execute_if_else_result_true():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "and",
@@ -104,13 +117,8 @@ def test_execute_if_else_result_true():
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -155,9 +163,8 @@ def test_execute_if_else_result_false():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "or",
@@ -174,13 +181,8 @@ def test_execute_if_else_result_false():
},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -222,11 +224,6 @@ def test_array_file_contains_file_name():
],
)
- node_config = {
- "id": "if-else",
- "data": node_data.model_dump(),
- }
-
# Create properly configured mock for graph_init_params
graph_init_params = Mock()
graph_init_params.workflow_id = "test_workflow"
@@ -242,17 +239,12 @@ def test_array_file_contains_file_name():
}
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=graph_init_params,
- graph_runtime_state=Mock(),
- config=node_config,
- )
+ node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock())
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
@@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
@@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions():
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={
- "id": "if-else",
- "data": node_data,
- },
)
# Mock db.session.close()
@@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure():
}
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index b217e4e8e7..465a4c0ff4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract
from graphon.variables import ArrayFileSegment
+def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
+ return ListOperatorNode(
+ node_id="test_node_id",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=MagicMock(),
+ )
+
+
@pytest.fixture
def list_operator_node():
config = {
@@ -35,10 +44,6 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData.model_validate(config)
- node_config = {
- "id": "test_node_id",
- "data": node_data.model_dump(),
- }
# Create properly configured mock for graph_init_params
graph_init_params = MagicMock()
graph_init_params.workflow_id = "test_workflow"
@@ -54,12 +59,7 @@ def list_operator_node():
}
}
- node = ListOperatorNode(
- id="test_node_id",
- config=node_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=MagicMock(),
- )
+ node = _build_list_operator_node(node_data, graph_init_params)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
@@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node):
files = [
File(
filename="image1.jpg",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
- type=FileType.AUDIO,
+ file_type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
@@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node):
def test_get_file_extract_string_func():
# Create a File object
file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename="test_file.txt",
extension=".txt",
@@ -156,7 +156,7 @@ def test_get_file_extract_string_func():
# Test with empty values
empty_file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename=None,
extension=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 543f9878de..5655f80737 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables):
inputs=user_inputs,
)
- config = {
- "id": "start",
- "data": StartNodeData(title="Start", variables=variables).model_dump(),
- }
+ node_data = StartNodeData(title="Start", variables=variables)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
@@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables):
)
return StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@@ -109,7 +106,7 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
- with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+ with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"):
node._run()
@@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot():
inputs={"profile": {"age": 20, "name": "Tom"}},
)
- config = {
- "id": "start",
- "data": StartNodeData(
- title="Start",
- variables=[
- VariableEntity(
- variable="profile",
- label="profile",
- type=VariableEntityType.JSON_OBJECT,
- required=True,
- )
- ],
- ).model_dump(),
- }
+ node_data = StartNodeData(
+ title="Start",
+ variables=[
+ VariableEntity(
+ variable="profile",
+ label="profile",
+ type=VariableEntityType.JSON_OBJECT,
+ required=True,
+ )
+ ],
+ )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index c806181340..284af68319 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variables.segments import ArrayFileSegment
@@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode:
runtime = _StubToolRuntime()
node = ToolNode(
- id="node-instance",
- config=config,
+ node_id="node-instance",
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
@@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode:
return node
-def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
+def _collect_events(generator: Generator) -> list[Any]:
events: list[Any] = []
try:
while True:
events.append(next(generator))
- except StopIteration as stop:
- return events, stop.value
+ except StopIteration:
+ return events
def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
@@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li
node_id=tool_node._node_id,
tool_runtime=ToolRuntimeHandle(raw=object()),
)
- return _collect_events(generator)
+ events = _collect_events(generator)
+ completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+ assert completed_events
+ return events, completed_events[-1].node_run_result.llm_usage
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
@@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
index c8ddc53284..e3b5e3b591 100644
--- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
@@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
+from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": TRIGGER_PLUGIN_NODE_TYPE,
- "title": "Trigger Event",
- "plugin_id": "plugin-id",
- "provider_id": "provider-id",
- "event_name": "event-name",
- "subscription_id": "subscription-id",
- "plugin_unique_identifier": "plugin-unique-identifier",
- "event_parameters": {},
- },
- }
+def _build_node_data() -> TriggerEventNodeData:
+ return TriggerEventNodeData(
+ type=TRIGGER_PLUGIN_NODE_TYPE,
+ title="Trigger Event",
+ plugin_id="plugin-id",
+ provider_id="provider-id",
+ event_name="event-name",
+ subscription_id="subscription-id",
+ plugin_unique_identifier="plugin-unique-identifier",
+ event_parameters={},
)
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index 1bbc12b23f..07d03bec05 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -30,11 +30,6 @@ def create_webhook_node(
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "webhook-node-1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="test-workflow",
graph_config={},
@@ -56,8 +51,8 @@ def create_webhook_node(
)
node = TriggerWebhookNode(
- id="webhook-node-1",
- config=node_config,
+ node_id="webhook-node-1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -66,10 +61,6 @@ def create_webhook_node(
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
- # Provide compatibility alias expected by node implementation
- # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
- node.node_id = node.id
-
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index 427afa96ec..b839490d3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool
def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="1",
graph_config={},
@@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
start_at=0,
)
node = TriggerWebhookNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
- # Compatibility alias for some nodes referencing `self.node_id`
- node.node_id = node.id
-
return node
@@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params():
"""Test webhook node execution with file parameter extraction."""
# Create mock file objects
file1 = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="image.jpg",
@@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params():
)
file2 = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file2",
filename="document.pdf",
@@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
@@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params():
def test_webhook_node_run_mixed_parameters():
"""Test webhook node execution with mixed parameter types."""
file_obj = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="test.jpg",
@@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..8b5fceeb37
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,350 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+ DeliveryMethodType,
+ EmailDeliveryConfig,
+ EmailDeliveryMethod,
+ EmailRecipients,
+ WebAppDeliveryMethod,
+ _WebAppDeliveryConfig,
+ adapt_human_input_node_data_for_graph,
+ adapt_node_config_for_graph,
+ adapt_node_data_for_graph,
+ is_human_input_webapp_enabled,
+ parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+ variable_pool = SimpleNamespace(
+ convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+ )
+
+ rendered = EmailDeliveryConfig.render_body_template(
+ body="Open {{#url#}} and use {{#node.value#}}",
+ url="https://example.com",
+ variable_pool=variable_pool,
+ )
+ sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team")
+ html = EmailDeliveryConfig.render_markdown_body(
+ "**Hello** [mail](mailto:test@example.com)"
+ )
+
+ assert rendered == "Open https://example.com and use 42"
+ assert sanitized == "Hello alert(1) Team"
+ assert "Hello" in html
+ assert " Team")
- html = EmailDeliveryConfig.render_markdown_body(
- "**Hello** [mail](mailto:test@example.com)"
- )
-
- assert rendered == "Open https://example.com and use 42"
- assert sanitized == "Hello alert(1) Team"
- assert "Hello" in html
- assert "