mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 15:20:15 +08:00
chore(api): adapt Graphon 0.2.2 upgrade (#35377)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
|
||||
class FullContentDict(TypedDict):
|
||||
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"value_type": str(variable_file.value_type.exposed_type()),
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value_type": str(v.value_type.exposed_type()),
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
|
||||
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.protocols import WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
from graphon.http.protocols import HttpResponseProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
|
||||
@@ -352,11 +352,11 @@ class DatasourceManager:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.CUSTOM,
|
||||
file_type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
|
||||
@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
@@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
from graphon.http.response import HttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -267,4 +268,47 @@ class SSRFProxy:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
|
||||
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
|
||||
return HttpResponse(
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
content=response.content,
|
||||
url=str(response.url) if response.url else None,
|
||||
reason_phrase=response.reason_phrase,
|
||||
fallback_text=response.text,
|
||||
)
|
||||
|
||||
|
||||
class GraphonSSRFProxy:
|
||||
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
graphon_ssrf_proxy = GraphonSSRFProxy()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
@@ -168,7 +170,7 @@ class ModelInstance:
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@@ -193,7 +195,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@@ -213,7 +215,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@@ -235,7 +237,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
@@ -252,7 +254,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@@ -277,7 +279,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@@ -305,7 +307,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@@ -324,7 +326,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
@@ -340,7 +342,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
@@ -357,14 +359,14 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
|
||||
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
provider_schema.icon_small_dark.zh_hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
else provider_schema.icon_small_dark.en_us
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class AgentHistoryPromptTransform(PromptTransform):
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
"""Initialize with file path."""
|
||||
self._closed = False
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError(f"File path {self.file_path} is not a valid file or url")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Best-effort cleanup for downloaded temporary files."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
temp_file = getattr(self, "temp_file", None)
|
||||
if temp_file is None:
|
||||
return
|
||||
|
||||
try:
|
||||
close_result = temp_file.close()
|
||||
if inspect.isawaitable(close_result):
|
||||
close_awaitable = getattr(close_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
self.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
|
||||
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
try:
|
||||
# Create File object directly (similar to DatasetRetrieval)
|
||||
file_obj = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.helper import parse_uuid_str_or_none
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
@@ -517,11 +517,11 @@ class DatasetRetrieval:
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Literal
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@@ -28,7 +28,7 @@ class ToolFileManager:
|
||||
def _build_graph_file_reference(tool_file: ToolFile) -> File:
|
||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||
return File(
|
||||
type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
file_type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
|
||||
@@ -1082,7 +1082,12 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.value
|
||||
if not isinstance(variable_selector, list) or not all(
|
||||
isinstance(selector_part, str) for selector_part in variable_selector
|
||||
):
|
||||
raise ToolParameterError("Variable tool input must be a variable selector")
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
|
||||
@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
transfer_method_value = file_dict.get("transfer_method")
|
||||
if not isinstance(transfer_method_value, str):
|
||||
raise ValueError("Workflow file mapping is missing a valid transfer_method")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Workflow-layer adapters for legacy human-input payload keys.
|
||||
"""Workflow-to-Graphon adapters for persisted node payloads.
|
||||
|
||||
Stored workflow graphs and editor payloads may still use Dify-specific human
|
||||
input recipient keys. Normalize them here before handing configs to
|
||||
`graphon` so graph-owned models only see graph-neutral field names.
|
||||
Stored workflow graphs and editor payloads still contain a small set of
|
||||
Dify-owned field spellings and value shapes. Adapt them here before handing the
|
||||
payload to Graphon so Graphon-owned models only see current contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
|
||||
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
|
||||
|
||||
|
||||
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
|
||||
normalized = normalize_human_input_node_data_for_graph(node_data)
|
||||
normalized = adapt_human_input_node_data_for_graph(node_data)
|
||||
raw_delivery_methods = normalized.get("delivery_methods")
|
||||
if not isinstance(raw_delivery_methods, list):
|
||||
return []
|
||||
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
|
||||
return False
|
||||
|
||||
|
||||
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
|
||||
|
||||
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return normalized
|
||||
return normalize_human_input_node_data_for_graph(normalized)
|
||||
node_type = normalized.get("type")
|
||||
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return adapt_human_input_node_data_for_graph(normalized)
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _adapt_tool_node_data_for_graph(normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_config)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
|
||||
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
|
||||
if data_mapping is None:
|
||||
return normalized
|
||||
|
||||
normalized["data"] = normalize_node_data_for_graph(data_mapping)
|
||||
normalized["data"] = adapt_node_data_for_graph(data_mapping)
|
||||
return normalized
|
||||
|
||||
|
||||
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(node_data)
|
||||
|
||||
raw_tool_configurations = normalized.get("tool_configurations")
|
||||
if not isinstance(raw_tool_configurations, Mapping):
|
||||
return normalized
|
||||
|
||||
existing_tool_parameters = normalized.get("tool_parameters")
|
||||
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
|
||||
normalized_tool_configurations: dict[str, Any] = {}
|
||||
found_legacy_tool_inputs = False
|
||||
|
||||
for name, value in raw_tool_configurations.items():
|
||||
if not isinstance(value, Mapping):
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, dict(value))
|
||||
|
||||
flattened_value = _flatten_legacy_tool_configuration_value(
|
||||
input_type=input_type,
|
||||
input_value=input_value,
|
||||
)
|
||||
if flattened_value is not None:
|
||||
normalized_tool_configurations[name] = flattened_value
|
||||
|
||||
if not found_legacy_tool_inputs:
|
||||
return normalized
|
||||
|
||||
normalized["tool_parameters"] = normalized_tool_parameters
|
||||
normalized["tool_configurations"] = normalized_tool_configurations
|
||||
return normalized
|
||||
|
||||
|
||||
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
|
||||
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
|
||||
return input_value
|
||||
|
||||
if (
|
||||
input_type == "variable"
|
||||
and isinstance(input_value, list)
|
||||
and all(isinstance(item, str) for item in input_value)
|
||||
):
|
||||
return "{{#" + ".".join(input_value) + "#}}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
@@ -291,9 +349,9 @@ __all__ = [
|
||||
"MemberRecipient",
|
||||
"WebAppDeliveryMethod",
|
||||
"_WebAppDeliveryConfig",
|
||||
"adapt_human_input_node_data_for_graph",
|
||||
"adapt_node_config_for_graph",
|
||||
"adapt_node_data_for_graph",
|
||||
"is_human_input_webapp_enabled",
|
||||
"normalize_human_input_node_data_for_graph",
|
||||
"normalize_node_config_for_graph",
|
||||
"normalize_node_data_for_graph",
|
||||
"parse_human_input_delivery_methods",
|
||||
]
|
||||
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_runtime import (
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
"""Resolve the production node class for the requested type/version."""
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_http_client = graphon_ssrf_proxy
|
||||
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
|
||||
self._dify_context,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
# Graph configs are initially validated against permissive shared node data.
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=False,
|
||||
include_llm_file_saver=False,
|
||||
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
validate_node_data = getattr(node_class, "validate_node_data", None)
|
||||
if callable(validate_node_data):
|
||||
return cast("BaseNodeData", validate_node_data(node_data))
|
||||
return node_data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.nodes.llm.runtime_protocols import (
|
||||
PreparedLLMProtocol,
|
||||
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .human_input_compat import (
|
||||
from .human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
DeliveryMethodType,
|
||||
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeExecutionType,
|
||||
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
|
||||
if value is None or isinstance(value, (str, float)):
|
||||
return value
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_sequence_item(value: object) -> str:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
|
||||
@@ -148,11 +148,11 @@ def _build_from_local_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
reference=build_file_reference(record_id=str(row.id)),
|
||||
@@ -196,11 +196,11 @@ def _build_from_remote_url(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
@@ -222,9 +222,9 @@ def _build_from_remote_url(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=filename,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
@@ -263,9 +263,9 @@ def _build_from_tool_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=tool_file.name,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
@@ -306,9 +306,9 @@ def _build_from_datasource_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
file_id=mapping.get("datasource_file_id"),
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
reference=build_file_reference(record_id=str(datasource_file.id)),
|
||||
|
||||
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
|
||||
|
||||
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type.exposed_type().value
|
||||
return str(v.value_type.exposed_type())
|
||||
else:
|
||||
value_type = v.get("value_type")
|
||||
if value_type is None:
|
||||
raise ValueError("value_type is required but not provided")
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": value.value_type.exposed_type().value,
|
||||
"value_type": str(value.value_type.exposed_type()),
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, dict):
|
||||
|
||||
@@ -6,7 +6,7 @@ import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from core.workflow.human_input_compat import DeliveryMethodType
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.helper import generate_string
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
|
||||
return tenant_resolver()
|
||||
|
||||
|
||||
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
|
||||
"""Build a graph `File` directly from serialized metadata."""
|
||||
|
||||
def _coerce_file_type(value: Any) -> FileType:
|
||||
if isinstance(value, FileType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return FileType.value_of(value)
|
||||
raise ValueError("file type is required in file mapping")
|
||||
|
||||
mapping = dict(file_mapping)
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if isinstance(transfer_method_value, FileTransferMethod):
|
||||
transfer_method = transfer_method_value
|
||||
elif isinstance(transfer_method_value, str):
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
else:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
|
||||
file_id = mapping.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
legacy_id = mapping.get("id")
|
||||
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
|
||||
|
||||
related_id = resolve_file_record_id(mapping)
|
||||
if related_id is None:
|
||||
raw_related_id = mapping.get("related_id")
|
||||
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
|
||||
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
remote_url = url if isinstance(url, str) and url else None
|
||||
|
||||
reference = mapping.get("reference")
|
||||
if not isinstance(reference, str) or not reference:
|
||||
reference = None
|
||||
|
||||
filename = mapping.get("filename")
|
||||
if not isinstance(filename, str):
|
||||
filename = None
|
||||
|
||||
extension = mapping.get("extension")
|
||||
if not isinstance(extension, str):
|
||||
extension = None
|
||||
|
||||
mime_type = mapping.get("mime_type")
|
||||
if not isinstance(mime_type, str):
|
||||
mime_type = None
|
||||
|
||||
size = mapping.get("size", -1)
|
||||
if not isinstance(size, int):
|
||||
size = -1
|
||||
|
||||
storage_key = mapping.get("storage_key")
|
||||
if not isinstance(storage_key, str):
|
||||
storage_key = None
|
||||
|
||||
tenant_id = mapping.get("tenant_id")
|
||||
if not isinstance(tenant_id, str):
|
||||
tenant_id = None
|
||||
|
||||
dify_model_identity = mapping.get("dify_model_identity")
|
||||
if not isinstance(dify_model_identity, str):
|
||||
dify_model_identity = FILE_MODEL_IDENTITY
|
||||
|
||||
tool_file_id = mapping.get("tool_file_id")
|
||||
if not isinstance(tool_file_id, str):
|
||||
tool_file_id = None
|
||||
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not isinstance(upload_file_id, str):
|
||||
upload_file_id = None
|
||||
|
||||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not isinstance(datasource_file_id, str):
|
||||
datasource_file_id = None
|
||||
|
||||
return File(
|
||||
file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=remote_url,
|
||||
reference=reference,
|
||||
related_id=related_id,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
storage_key=storage_key,
|
||||
dify_model_identity=dify_model_identity,
|
||||
url=remote_url,
|
||||
tool_file_id=tool_file_id,
|
||||
upload_file_id=upload_file_id,
|
||||
datasource_file_id=datasource_file_id,
|
||||
)
|
||||
|
||||
|
||||
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
|
||||
"""Recursively rebuild serialized graph file payloads into `File` objects.
|
||||
|
||||
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
|
||||
`model_validate_json()`. Dify keeps this recovery path at the model boundary
|
||||
so historical JSON blobs remain readable without reintroducing global graph
|
||||
patches or test-local coercion.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
if maybe_file_object(value):
|
||||
return build_file_from_mapping_without_lookup(file_mapping=value)
|
||||
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def build_file_from_stored_mapping(
|
||||
*,
|
||||
file_mapping: Mapping[str, Any],
|
||||
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
|
||||
pass
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
mapping["remote_url"] = url
|
||||
return File.model_validate(mapping)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=mapping)
|
||||
|
||||
return file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
||||
@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.variable_prefixes import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
from .utils.file_input_compat import build_file_from_stored_mapping
|
||||
from .utils.file_input_compat import (
|
||||
build_file_from_mapping_without_lookup,
|
||||
build_file_from_stored_mapping,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
|
||||
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
|
||||
return cast(Any, value)
|
||||
normalized_file = dict(value)
|
||||
normalized_file.pop("tenant_id", None)
|
||||
return File.model_validate(normalized_file)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
|
||||
elif isinstance(value, list) and value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
|
||||
for item in value_list:
|
||||
normalized_file = dict(cast(dict[str, Any], item))
|
||||
normalized_file.pop("tenant_id", None)
|
||||
file_list.append(File.model_validate(normalized_file))
|
||||
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return cast(Any, value)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Tencent APM tracing implementation with separated concerns
|
||||
"""
|
||||
"""Tencent APM tracing with idempotent client cleanup."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
"""
|
||||
Tencent APM trace implementation with single responsibility principle.
|
||||
Acts as a coordinator that delegates specific tasks to specialized classes.
|
||||
|
||||
The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
|
||||
explicitly in tests or implicitly during garbage collection, so shutdown
|
||||
must be safe to call multiple times.
|
||||
"""
|
||||
|
||||
trace_client: TencentTraceClient
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, tencent_config: TencentConfig):
|
||||
super().__init__(tencent_config)
|
||||
self._closed = False
|
||||
self.trace_client = TencentTraceClient(
|
||||
service_name=tencent_config.service_name,
|
||||
endpoint=tencent_config.endpoint,
|
||||
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record message trace duration")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure proper cleanup on garbage collection."""
|
||||
def close(self) -> None:
|
||||
"""Synchronously and idempotently shutdown the underlying trace client."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
trace_client = getattr(self, "trace_client", None)
|
||||
if trace_client is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(self, "trace_client"):
|
||||
self.trace_client.shutdown()
|
||||
shutdown_result = trace_client.shutdown()
|
||||
if inspect.isawaitable(shutdown_result):
|
||||
close_awaitable = getattr(shutdown_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure best-effort cleanup on garbage collection without retrying shutdown."""
|
||||
self.close()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import gc
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
import warnings
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dify_trace_tencent.config import TencentConfig
|
||||
@@ -632,13 +634,38 @@ class TestTencentDataTrace:
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
|
||||
def test_del(self, tencent_data_trace):
|
||||
def test_close(self, tencent_data_trace):
|
||||
client = tencent_data_trace.trace_client
|
||||
tencent_data_trace.__del__()
|
||||
tencent_data_trace.close()
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_del_exception(self, tencent_data_trace):
|
||||
def test_close_is_idempotent(self, tencent_data_trace):
|
||||
client = tencent_data_trace.trace_client
|
||||
|
||||
tencent_data_trace.close()
|
||||
tencent_data_trace.close()
|
||||
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_close_exception(self, tencent_data_trace):
|
||||
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.__del__()
|
||||
tencent_data_trace.close()
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
|
||||
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
|
||||
shutdown = AsyncMock()
|
||||
tencent_data_trace.trace_client.shutdown = shutdown
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
tencent_data_trace.close()
|
||||
gc.collect()
|
||||
|
||||
shutdown.assert_called_once()
|
||||
assert not [
|
||||
warning
|
||||
for warning in caught
|
||||
if issubclass(warning.category, RuntimeWarning)
|
||||
and "AsyncMockMixin._execute_mock_call" in str(warning.message)
|
||||
]
|
||||
|
||||
@@ -44,7 +44,7 @@ dependencies = [
|
||||
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]~=0.7.0",
|
||||
"graphon~=0.1.2",
|
||||
"graphon~=0.2.2",
|
||||
"httpx-sse~=0.4.0",
|
||||
"json-repair~=0.59.2",
|
||||
]
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created, app_was_deleted, app_was_updated
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
@@ -30,7 +30,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@@ -476,7 +476,7 @@ class RagPipelineService:
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_type_enum: NodeType = node_type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
|
||||
@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
|
||||
return TruncationResult(StringSegment(value=fallback_result.value), True)
|
||||
|
||||
# Apply final fallback - convert to JSON string and truncate
|
||||
json_str = dumps_with_segments(result.value, ensure_ascii=False)
|
||||
json_str = dumps_with_segments(result.value)
|
||||
if len(json_str) > self._max_size_bytes:
|
||||
json_str = json_str[: self._max_size_bytes] + "..."
|
||||
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
|
||||
|
||||
@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
|
||||
filename = f"{self._generate_filename(name)}.txt"
|
||||
else:
|
||||
# For other types, store as JSON
|
||||
original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
|
||||
original_content_serialized = dumps_with_segments(value_seg.value)
|
||||
content_type = "application/json"
|
||||
filename = f"{self._generate_filename(name)}.json"
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.trigger.constants import is_trigger_node_type
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
normalize_human_input_node_data_for_graph,
|
||||
adapt_human_input_node_data_for_graph,
|
||||
parse_human_input_delivery_methods,
|
||||
)
|
||||
from core.workflow.node_factory import (
|
||||
@@ -791,7 +791,7 @@ class WorkflowService:
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_type_enum: NodeType = node_type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
@@ -1096,7 +1096,7 @@ class WorkflowService:
|
||||
raise ValueError("Node type must be human-input.")
|
||||
|
||||
node_data = HumanInputNodeData.model_validate(
|
||||
normalize_human_input_node_data_for_graph(node_config["data"]),
|
||||
adapt_human_input_node_data_for_graph(node_config["data"]),
|
||||
from_attributes=True,
|
||||
)
|
||||
delivery_method = self._resolve_human_input_delivery_method(
|
||||
@@ -1237,9 +1237,10 @@ class WorkflowService:
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
|
||||
node = HumanInputNode(
|
||||
id=node_config["id"],
|
||||
config=node_config,
|
||||
node_id=node_config["id"],
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
@@ -1529,7 +1530,7 @@ class WorkflowService:
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
|
||||
try:
|
||||
HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
|
||||
HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
|
||||
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_mail import mail
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult, StreamCompletedEvent
|
||||
|
||||
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
|
||||
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
provider_type="plugin",
|
||||
provider_name="p",
|
||||
plugin_id="plug",
|
||||
datasource_name="ds",
|
||||
),
|
||||
graph_init_params=_GP(),
|
||||
graph_runtime_state=_GS(vp),
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.code.code_node import CodeNode
|
||||
from graphon.nodes.code.entities import CodeNodeData
|
||||
from graphon.nodes.code.limits import CodeNodeLimits
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=code_config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=CodeNodeData.model_validate(code_config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_executor=node_factory._code_executor,
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
|
||||
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=graph_config["nodes"][1],
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.llm.file_saver import LLMFileSaver
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
llm_file_saver = MagicMock(spec=LLMFileSaver)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=LLMNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@@ -11,6 +11,7 @@ from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
|
||||
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.template_rendering import TemplateRenderError
|
||||
@@ -86,8 +87,8 @@ def test_execute_template_transform():
|
||||
assert graph is not None
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
jinja2_template_renderer=_SimpleJinja2Renderer(),
|
||||
|
||||
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.protocols import ToolFileManagerProtocol
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from graphon.nodes.tool.tool_node import ToolNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@@ -101,8 +101,8 @@ def _build_graph(
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
node_id="start",
|
||||
config=start_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -116,8 +116,8 @@ def _build_graph(
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
node_id="human",
|
||||
config=human_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
@@ -130,8 +130,8 @@ def _build_graph(
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
node_id="end",
|
||||
config=end_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
file_id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
|
||||
@@ -271,7 +271,7 @@ def _create_recipient(
|
||||
|
||||
|
||||
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
|
||||
from core.workflow.human_input_compat import DeliveryMethodType
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from models.human_input import ConsoleDeliveryPayload
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
|
||||
@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
|
||||
|
||||
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="test_file_id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_upload_file_id",
|
||||
filename="test.jpg",
|
||||
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
|
||||
|
||||
# Create a File object with REMOTE_URL transfer method
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="test_file_id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/test.jpg",
|
||||
filename="test.jpg",
|
||||
|
||||
@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from graphon.variables import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
|
||||
assert response.created_at == int(created_at.timestamp())
|
||||
assert response.updated_at == int(created_at.timestamp())
|
||||
|
||||
def test_variable_response_normalizes_string_value_type_alias(self):
|
||||
response = ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SegmentType.INTEGER.value,
|
||||
}
|
||||
)
|
||||
|
||||
assert response.value_type == "number"
|
||||
|
||||
def test_variable_response_normalizes_callable_exposed_type(self):
|
||||
response = ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
|
||||
}
|
||||
)
|
||||
|
||||
assert response.value_type == "string"
|
||||
|
||||
def test_serialize_value_type_supports_segments_and_mappings(self):
|
||||
assert serialize_value_type(StringSegment(value="hello")) == "string"
|
||||
assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
|
||||
|
||||
def test_variable_pagination_response(self):
|
||||
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
{
|
||||
|
||||
@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
type=FileType.DOCUMENT,
|
||||
file_id=file_id,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
|
||||
@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow import node_factory as node_factory_module
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: _StubToolNodeData,
|
||||
*,
|
||||
graph_init_params,
|
||||
graph_runtime_state,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
original_create_node = DifyNodeFactory.create_node
|
||||
original_resolve_node_class = node_factory_module.resolve_workflow_node_class
|
||||
|
||||
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_data = typed_node_config["data"]
|
||||
if node_data.type == BuiltinNodeTypes.TOOL:
|
||||
return _StubToolNode(
|
||||
id=str(typed_node_config["id"]),
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
return original_create_node(self, typed_node_config)
|
||||
def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _StubToolNode
|
||||
return original_resolve_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
|
||||
mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
|
||||
|
||||
|
||||
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
|
||||
|
||||
@@ -26,8 +26,8 @@ def _build_file(
|
||||
extension: str | None = None,
|
||||
) -> File:
|
||||
return File(
|
||||
id="file-id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="file-id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
reference=reference,
|
||||
remote_url=remote_url,
|
||||
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
|
||||
|
||||
assert runtime.multimodal_send_format == "url"
|
||||
|
||||
with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
|
||||
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
|
||||
assert runtime.http_get("http://example", follow_redirects=False) == "response"
|
||||
mock_get.assert_called_once_with("http://example", follow_redirects=False)
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = id
|
||||
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = node_id
|
||||
self.config = config
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
|
||||
built = File(
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tool_file_1",
|
||||
extension=".png",
|
||||
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
file_in = File(
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tf",
|
||||
extension=".pdf",
|
||||
|
||||
@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
|
||||
|
||||
# Assert
|
||||
assert simple_provider.provider == "openai"
|
||||
assert simple_provider.label.en_US == "OpenAI"
|
||||
assert simple_provider.label.en_us == "OpenAI"
|
||||
assert simple_provider.supported_model_types == [ModelType.LLM]
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
|
||||
|
||||
def test_file():
|
||||
file = File(
|
||||
id="test-file",
|
||||
file_id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="test-related-id",
|
||||
filename="image.png",
|
||||
@@ -25,27 +25,21 @@ def test_file():
|
||||
assert file.size == 67
|
||||
|
||||
|
||||
def test_file_model_validate_accepts_legacy_tenant_id():
|
||||
data = {
|
||||
"id": "test-file",
|
||||
"tenant_id": "test-tenant-id",
|
||||
"type": "image",
|
||||
"transfer_method": "tool_file",
|
||||
"related_id": "test-related-id",
|
||||
"filename": "image.png",
|
||||
"extension": ".png",
|
||||
"mime_type": "image/png",
|
||||
"size": 67,
|
||||
"storage_key": "test-storage-key",
|
||||
"url": "https://example.com/image.png",
|
||||
# Extra legacy fields
|
||||
"tool_file_id": "tool-file-123",
|
||||
"upload_file_id": "upload-file-456",
|
||||
"datasource_file_id": "datasource-file-789",
|
||||
}
|
||||
def test_file_constructor_accepts_legacy_tenant_id():
|
||||
file = File(
|
||||
file_id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
tool_file_id="tool-file-123",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=67,
|
||||
storage_key="test-storage-key",
|
||||
url="https://example.com/image.png",
|
||||
)
|
||||
|
||||
file = File.model_validate(data)
|
||||
|
||||
assert file.related_id == "test-related-id"
|
||||
assert file.related_id == "tool-file-123"
|
||||
assert file.storage_key == "test-storage-key"
|
||||
assert "tenant_id" not in file.model_dump()
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
SSRFProxy,
|
||||
_get_user_provided_host_header,
|
||||
_to_graphon_http_response,
|
||||
graphon_ssrf_proxy,
|
||||
make_request,
|
||||
max_retries_exceeded_error,
|
||||
request_error,
|
||||
)
|
||||
|
||||
|
||||
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
|
||||
def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
|
||||
response = httpx.Response(
|
||||
201,
|
||||
headers={"X-Test": "1"},
|
||||
content=b"payload",
|
||||
request=httpx.Request("GET", "https://example.com/resource"),
|
||||
)
|
||||
|
||||
wrapped = _to_graphon_http_response(response)
|
||||
|
||||
assert wrapped.status_code == 201
|
||||
assert wrapped.headers == {"x-test": "1", "content-length": "7"}
|
||||
assert wrapped.content == b"payload"
|
||||
assert wrapped.url == "https://example.com/resource"
|
||||
assert wrapped.reason_phrase == "Created"
|
||||
assert wrapped.text == "payload"
|
||||
|
||||
|
||||
def test_ssrf_proxy_exposes_expected_error_types() -> None:
|
||||
proxy = SSRFProxy()
|
||||
|
||||
assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
|
||||
assert proxy.request_error is request_error
|
||||
assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
|
||||
assert graphon_ssrf_proxy.request_error is request_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
|
||||
def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
|
||||
response = httpx.Response(
|
||||
200,
|
||||
headers={"X-Test": "1"},
|
||||
content=b"ok",
|
||||
request=httpx.Request("GET", "https://example.com/resource"),
|
||||
)
|
||||
|
||||
with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
|
||||
wrapped = getattr(graphon_ssrf_proxy, method_name)(
|
||||
"https://example.com/resource",
|
||||
max_retries=3,
|
||||
headers={"X-Test": "1"},
|
||||
)
|
||||
|
||||
mock_method.assert_called_once_with(
|
||||
url="https://example.com/resource",
|
||||
max_retries=3,
|
||||
headers={"X-Test": "1"},
|
||||
)
|
||||
assert wrapped.status_code == 200
|
||||
assert wrapped.url == "https://example.com/resource"
|
||||
assert wrapped.content == b"ok"
|
||||
|
||||
@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
ProviderCredentialSchema,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/openai/openai"
|
||||
assert providers[0].provider_name == "openai"
|
||||
assert providers[0].label.en_US == "OpenAI"
|
||||
assert providers[0].label.en_us == "OpenAI"
|
||||
client.fetch_model_providers.assert_called_once_with("tenant")
|
||||
|
||||
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
|
||||
|
||||
@@ -466,7 +466,7 @@ class TestConverter:
|
||||
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
|
||||
file_param = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/file.png",
|
||||
storage_key="",
|
||||
@@ -499,14 +499,14 @@ class TestConverter:
|
||||
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
|
||||
file_one = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/a.txt",
|
||||
storage_key="",
|
||||
)
|
||||
file_two = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/b.txt",
|
||||
storage_key="",
|
||||
|
||||
@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
|
||||
files = [
|
||||
File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="",
|
||||
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
|
||||
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
|
||||
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
|
||||
memory = MagicMock(spec=TokenBufferMemory)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
|
||||
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import Conversation
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
|
||||
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
|
||||
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
# from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Primarily used for testing merged cell scenarios"""
|
||||
|
||||
import gc
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import UserDict
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from docx import Document
|
||||
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
|
||||
WordExtractor("not-a-file", "tenant", "user")
|
||||
|
||||
|
||||
def test_del_closes_temp_file():
|
||||
def test_close_closes_temp_file():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
|
||||
WordExtractor.__del__(extractor)
|
||||
extractor.close()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
def test_close_is_idempotent():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
|
||||
extractor.close()
|
||||
extractor.close()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
def test_close_handles_async_close_mock():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
extractor.temp_file.close = AsyncMock()
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
extractor.close()
|
||||
gc.collect()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
assert not [
|
||||
warning
|
||||
for warning in caught
|
||||
if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
|
||||
]
|
||||
|
||||
|
||||
def test_extract_images_handles_invalid_external_cases(monkeypatch):
|
||||
class FakeTargetRef:
|
||||
def __contains__(self, item):
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
|
||||
HumanInputFormSubmissionRepository,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@@ -6,9 +6,9 @@ from models.workflow import Workflow
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="storage_key",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import dataclasses
|
||||
from typing import Annotated
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentUnion,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
|
||||
StringVariable,
|
||||
Variable,
|
||||
)
|
||||
from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
|
||||
|
||||
type SegmentUnion = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringSegment, Tag(SegmentType.STRING)]
|
||||
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
|
||||
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
|
||||
def _build_variable_pool(
|
||||
@@ -123,7 +145,7 @@ def create_test_file(
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing"""
|
||||
return File(
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
|
||||
assert restored == model
|
||||
|
||||
def test_all_segments_serialization(self):
|
||||
"""Test serialization/deserialization of all segment types"""
|
||||
"""Test file-aware segment serialization through Dify's model boundary."""
|
||||
# Create one instance of each segment type
|
||||
test_file = create_test_file()
|
||||
|
||||
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
|
||||
# Test serialization and deserialization
|
||||
model = _Segments(segments=all_segments)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Segments.model_validate_json(json_str)
|
||||
loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
|
||||
|
||||
# Verify all segments are preserved
|
||||
assert len(loaded.segments) == len(all_segments)
|
||||
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
|
||||
assert loaded_segment.value == original.value
|
||||
|
||||
def test_all_variables_serialization(self):
|
||||
"""Test serialization/deserialization of all variable types"""
|
||||
"""Test file-aware variable serialization through Dify's model boundary."""
|
||||
# Create one instance of each variable type
|
||||
test_file = create_test_file()
|
||||
|
||||
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
|
||||
# Test serialization and deserialization
|
||||
model = _Variables(variables=all_variables)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Variables.model_validate_json(json_str)
|
||||
loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
|
||||
|
||||
# Verify all variables are preserved
|
||||
assert len(loaded.variables) == len(all_variables)
|
||||
|
||||
@@ -35,7 +35,7 @@ def create_test_file(
|
||||
"""Factory function to create File objects for testing."""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""
|
||||
Mock node factory for testing workflows with third-party service dependencies.
|
||||
"""Mock node factory for third-party-service workflow tests.
|
||||
|
||||
This module provides a MockNodeFactory that automatically detects and mocks nodes
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
The factory follows the same config adaptation path as production
|
||||
`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
|
||||
implementations before instantiation.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
:param node_config: Node configuration dictionary
|
||||
:return: Node instance (real or mocked)
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_type = node_data.type
|
||||
|
||||
# Check if this node type should be mocked
|
||||
if node_type in self._mock_node_types:
|
||||
node_id = typed_node_config["id"]
|
||||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
|
||||
if node_type == BuiltinNodeTypes.CODE:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
)
|
||||
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
||||
}:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
)
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
return super().create_node(typed_node_config)
|
||||
return super().create_node(node_config)
|
||||
|
||||
def should_mock_node(self, node_type: NodeType) -> bool:
|
||||
"""
|
||||
|
||||
@@ -55,13 +55,14 @@ class MockNodeMixin:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
config: Any,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> None:
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
@@ -96,7 +97,7 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("message_transformer", MagicMock())
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
node_id=start_config["id"],
|
||||
config=StartNodeData(title="Start", variables=[]),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
node_id=human_a_config["id"],
|
||||
config=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
node_id=human_b_config["id"],
|
||||
config=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
node_id=end_config["id"],
|
||||
config=end_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.answer.answer_node import AnswerNode
|
||||
from graphon.nodes.answer.entities import AnswerNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
@@ -66,20 +67,15 @@ def test_execute_answer():
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node_config = {
|
||||
"id": "answer",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
}
|
||||
|
||||
node = AnswerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
),
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
|
||||
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
provider_type="plugin",
|
||||
provider_name="p",
|
||||
plugin_id="plug",
|
||||
datasource_name="ds",
|
||||
),
|
||||
graph_init_params=gp,
|
||||
graph_runtime_state=gs,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
|
||||
from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
|
||||
from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
|
||||
assert default_config["retry_config"]["max_retries"] == 7
|
||||
|
||||
|
||||
def test_get_default_config_with_malformed_http_request_config_raises_value_error():
|
||||
with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
|
||||
def test_get_default_config_with_malformed_http_request_config_raises_type_error():
|
||||
with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
|
||||
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
|
||||
|
||||
|
||||
@@ -114,8 +114,8 @@ def _build_http_node(
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
return HttpRequestNode(
|
||||
id="http-node",
|
||||
config=node_config,
|
||||
node_id="http-node",
|
||||
config=HttpRequestNodeData.model_validate(node_data),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
|
||||
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
|
||||
from graphon.runtime import VariablePool
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryMethodType,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
entity.status_value = HumanInputFormStatus.SUBMITTED
|
||||
|
||||
|
||||
def _build_human_input_node(
|
||||
*,
|
||||
node_id: str,
|
||||
node_data: HumanInputNodeData | Mapping[str, Any],
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
runtime: DifyHumanInputNodeRuntime,
|
||||
) -> HumanInputNode:
|
||||
typed_node_data = (
|
||||
node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=node_id,
|
||||
config=typed_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
|
||||
class TestDeliveryMethod:
|
||||
"""Test DeliveryMethod entity."""
|
||||
|
||||
@@ -239,7 +259,7 @@ class TestUserAction:
|
||||
data[field_name] = value
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
UserAction(**data)
|
||||
UserAction.model_validate(data)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
|
||||
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
|
||||
|
||||
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
|
||||
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
node = _build_human_input_node(
|
||||
node_id=config["id"],
|
||||
node_data=config["data"],
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
runtime=runtime,
|
||||
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
|
||||
|
||||
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
|
||||
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
node = _build_human_input_node(
|
||||
node_id=config["id"],
|
||||
node_data=config["data"],
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
runtime=runtime,
|
||||
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
|
||||
|
||||
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
|
||||
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
node = _build_human_input_node(
|
||||
node_id=config["id"],
|
||||
node_data=config["data"],
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
runtime=runtime,
|
||||
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
|
||||
|
||||
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
|
||||
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
node = _build_human_input_node(
|
||||
node_id=config["id"],
|
||||
node_data=config["data"],
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
runtime=runtime,
|
||||
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
|
||||
form_repository = InMemoryHumanInputFormRepository()
|
||||
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
|
||||
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
node = _build_human_input_node(
|
||||
node_id=config["id"],
|
||||
node_data=config["data"],
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
runtime=runtime,
|
||||
|
||||
@@ -11,6 +11,7 @@ from graphon.graph_events import (
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.nodes.human_input.enums import HumanInputFormStatus
|
||||
from graphon.nodes.human_input.human_input_node import HumanInputNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -25,6 +26,28 @@ class _FakeFormRepository:
|
||||
return self._form
|
||||
|
||||
|
||||
def _create_human_input_node(
|
||||
*,
|
||||
config: dict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
repo: _FakeFormRepository,
|
||||
) -> HumanInputNode:
|
||||
node_data = (
|
||||
config["data"]
|
||||
if isinstance(config["data"], HumanInputNodeData)
|
||||
else HumanInputNodeData.model_validate(config["data"])
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=config["id"],
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
|
||||
)
|
||||
|
||||
|
||||
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
|
||||
system_variables = default_system_variables()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
return _create_human_input_node(
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
|
||||
repo=repo,
|
||||
)
|
||||
|
||||
|
||||
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
return _create_human_input_node(
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
|
||||
repo=repo,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.nodes.iteration.entities import IterationNodeData
|
||||
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
|
||||
from graphon.nodes.iteration.iteration_node import IterationNode
|
||||
from graphon.runtime import (
|
||||
@@ -44,17 +45,14 @@ def _build_iteration_node(
|
||||
) -> IterationNode:
|
||||
init_params = build_test_graph_init_params(graph_config=graph_config)
|
||||
return IterationNode(
|
||||
id="iteration-node",
|
||||
config={
|
||||
"id": "iteration-node",
|
||||
"data": {
|
||||
"type": "iteration",
|
||||
"title": "Iteration",
|
||||
"iterator_selector": ["start", "items"],
|
||||
"output_selector": ["iteration-node", "output"],
|
||||
"start_node_id": start_node_id,
|
||||
},
|
||||
},
|
||||
node_id="iteration-node",
|
||||
config=IterationNodeData(
|
||||
type="iteration",
|
||||
title="Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
output_selector=["iteration-node", "output"],
|
||||
start_node_id=start_node_id,
|
||||
),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@@ -93,6 +93,25 @@ def sample_chunks():
|
||||
}
|
||||
|
||||
|
||||
def _build_node(
|
||||
*,
|
||||
node_id: str,
|
||||
node_data: KnowledgeIndexNodeData | dict[str, object],
|
||||
graph_init_params,
|
||||
graph_runtime_state,
|
||||
) -> KnowledgeIndexNode:
|
||||
return KnowledgeIndexNode(
|
||||
node_id=node_id,
|
||||
config=(
|
||||
node_data
|
||||
if isinstance(node_data, KnowledgeIndexNodeData)
|
||||
else KnowledgeIndexNodeData.model_validate(node_data)
|
||||
),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
|
||||
class TestKnowledgeIndexNode:
|
||||
"""
|
||||
Test suite for KnowledgeIndexNode.
|
||||
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
|
||||
}
|
||||
|
||||
# Act
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
node = _build_node(
|
||||
node_id=node_id,
|
||||
node_data=config["data"],
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user