Merge branch 'main' into jzh

JzoNg
2026-04-14 14:05:44 +08:00
164 changed files with 4827 additions and 3925 deletions

View File

@@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: a readability/selective-run tag only; it does not change session behavior unless the hooks implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
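A minimal sketch of a step definition in this local style is shown below. The step wording, `app-list` test id, and `this.page` property are illustrative assumptions, not the suite's real steps; actual glue lives under `e2e/features/step-definitions/`.

```typescript
import { Then, When } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../support/world'

// One user-visible action per step; `this` typed as DifyWorld, async function (not arrow).
When('the user opens the {string} app', async function (this: DifyWorld, appName: string) {
  // Scope to a stable container so similar links elsewhere on the page do not match.
  const appList = this.page.getByTestId('app-list')
  await appList.getByRole('link', { name: appName }).click()
})

// One assertion per step, using a web-first retrying expectation.
Then('the app title {string} is visible', async function (this: DifyWorld, title: string) {
  await expect(this.page.getByRole('heading', { name: title })).toBeVisible()
})
```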
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
When reporting findings, lead with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

View File

@@ -0,0 +1,4 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."

View File

@@ -0,0 +1,93 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
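To make the parameter syntax concrete, here is a deliberately simplified sketch of how expressions like `{string}` and `{int}` reduce to regex capture groups. This is not the real `@cucumber/cucumber-expressions` implementation, only an approximation of its matching behavior:

```typescript
// Illustrative only: approximate how Cucumber Expression parameters map to capture groups.
function toRegExp(expr: string): RegExp {
  // Escape regex metacharacters in the literal text first ({ and } are left alone).
  const escaped = expr.replace(/[.*+?^$()[\]\\|]/g, '\\$&')
  const pattern = escaped
    .replace(/\{string\}/g, '"([^"]*)"')
    .replace(/\{int\}/g, '(-?\\d+)')
    .replace(/\{float\}/g, '(-?\\d*\\.?\\d+)')
    .replace(/\{word\}/g, '([^\\s]+)')
  return new RegExp(`^${pattern}$`)
}

const re = toRegExp('the user adds {int} items named {string}')
const m = 'the user adds 3 items named "Demo App"'.match(re)
// m?.[1] === "3", m?.[2] === "Demo App"
```

The real library also handles custom parameter types and alternation; if a step needs anything near that complexity, simplify the wording instead.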
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

View File

@@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
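The World-scoped state rule can be sketched with a simplified stand-in for the real `DifyWorld` (which lives in `e2e/features/support/world.ts`); the class and field names here are hypothetical:

```typescript
// Simplified stand-in: Cucumber constructs one World instance per scenario,
// so state stored on it cannot leak across scenarios.
class ScenarioWorld {
  createdAppName?: string

  rememberApp(name: string): void {
    this.createdAppName = name
  }
}

const scenarioA = new ScenarioWorld()
const scenarioB = new ScenarioWorld()
scenarioA.rememberApp('Demo App')
// scenarioB.createdAppName is still undefined: no cross-scenario leakage.
```

Module-level variables or singletons shared by step files break exactly this guarantee, which is why ad hoc state belongs on the World.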
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
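A hedged sketch of the preference order, using a hypothetical dataset list page (the control names and test id are assumptions, not real Dify selectors):

```typescript
import type { Page } from 'playwright'

// Hypothetical page with repeated "Edit" buttons: scope to a stable container
// first, then target the user-visible control by role and accessible name.
async function openDatasetSettings(page: Page, datasetName: string) {
  const card = page.getByRole('listitem', { name: datasetName })
  await card.getByRole('button', { name: 'Edit' }).click()
}

// Fallback when semantics are weak but the contract is intentional:
//   page.getByTestId('dataset-card-edit')
// Avoid structural selectors that break on DOM refactors:
//   page.locator('div.card > div:nth-child(3) button')
```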
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
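The contrast between retrying and one-shot checks can be sketched as follows (the heading name is illustrative):

```typescript
import { expect, type Page } from '@playwright/test'

async function assertAppCreated(page: Page, appName: string) {
  // Web-first: expect() retries until the heading appears or the timeout elapses.
  await expect(page.getByRole('heading', { name: appName })).toBeVisible()

  // Anti-pattern: isVisible() snapshots the state once, so this races against
  // rendering and is a common source of flakes.
  //   expect(await page.getByRole('heading', { name: appName }).isVisible()).toBe(true)

  // Anti-pattern: a fixed sleep either wastes time or is still too short.
  //   await page.waitForTimeout(3000)
}
```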
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

View File

@@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright

View File

@@ -69,8 +69,6 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@@ -84,7 +82,6 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@@ -93,29 +90,16 @@ ignore = [
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -1,5 +1,7 @@
"""Configuration for InterSystems IRIS vector database."""
from typing import Any
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
@@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate IRIS configuration values.
Args:

View File

@@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")

View File

@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
description: I18nObject
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
team_credentials: dict[str, Any] | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: dict | None = None
original_credentials: dict | None = None
masked_credentials: dict[str, Any] | None = None
original_credentials: dict[str, Any] | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")

View File

@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
id: str
name: str
credentials: dict
credentials: dict[str, Any]
credential_source_type: str | None = None
credential_id: str | None = None

View File

@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.

View File

@@ -735,7 +735,9 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: dict[Any, Any] | None = None,
):
"""
Update the document indexing status.
@@ -762,7 +764,7 @@ class IndexingRunner:
db.session.commit()
@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
"""
Update the document segment by document id.
"""

View File

@@ -200,7 +200,7 @@ def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict,
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
"""
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict, rules: list):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
"""
Set the appropriate response format parameter based on model rules.
@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict):
def remove_additional_properties(schema: dict[str, Any]):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.

View File

@@ -77,7 +77,7 @@ class ModelInstance:
@staticmethod
def _get_load_balancing_manager(
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
) -> Optional["LBModelManager"]:
"""
Get load balancing model credentials

View File

@@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str | None = None,
context: str | None = None,
histories: str | None = None,
) -> tuple[str, dict]:
) -> tuple[str, dict[str, Any]]:
# get prompt template
prompt_template_config = self.get_prompt_template(
app_mode=app_mode,
@@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],
@@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],

View File

@@ -856,7 +856,7 @@ class ProviderManager:
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
) -> dict[str, Any]:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,

View File

@@ -174,8 +174,8 @@ class RetrievalService:
cls,
dataset_id: str,
query: str,
external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict | None = None,
external_retrieval_model: dict[str, Any] | None = None,
metadata_filtering_conditions: dict[str, Any] | None = None,
):
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)

View File

@@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):
return embedding_results # type: ignore
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
class Embeddings(ABC):
@@ -20,7 +21,7 @@ class Embeddings(ABC):
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError

View File

@@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
return _case_routing
def __getattr__(name: str) -> dict:
def __getattr__(name: str) -> Any:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()

View File

@@ -198,7 +198,7 @@ class Tool(ABC):
message=ToolInvokeMessage.TextMessage(text=text),
)
def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
"""
create a blob message
@@ -212,7 +212,7 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""

View File

@@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict | list
json_object: dict[str, Any] | list[Any]
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: str | None = None
# MCP object and array type parameters use this field to store the schema
input_schema: dict | None = None
input_schema: dict[str, Any] | None = None
@classmethod
def get_simple_instance(
@@ -463,7 +463,7 @@ class ToolInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None
@classmethod
def empty(cls) -> ToolInvokeMeta:

View File

@@ -85,7 +85,8 @@ class ToolEngine:
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
invocation_meta_dict: dict[str, Any],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
@@ -200,7 +201,7 @@ class ToolEngine:
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
tool_parameters: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,

View File

@@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool

View File

@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""
transform the tool parameters
@@ -323,7 +323,7 @@ class WorkflowTool(Tool):
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
"""
extract files from the result
@@ -355,7 +355,7 @@ class WorkflowTool(Tool):
return result, files
def _update_file_mapping(self, file_dict: dict):
def _update_file_mapping(self, file_dict: dict[str, Any]):
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
match transfer_method:

View File

@@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol):
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
summary_index_setting: dict[str, Any] | None = None,
) -> IndexingResultDict: ...
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: dict[str, Any] | None,
) -> Preview: ...
class SummaryIndexServiceProtocol(Protocol):
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
) -> None: ...

View File

@@ -1,4 +1,4 @@
from typing import Literal, Union
from typing import Any, Literal, Union
from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType
@@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData):
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
visual_config: dict | None = Field(default=None, description="Visual configuration details")
visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details")
timezone: str = Field(default="UTC", description="Timezone for schedule execution")

View File

@@ -75,7 +75,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
def generate_file_var(self, param_name: str, file: dict[str, Any]):
file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@@ -147,7 +147,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var

View File

@@ -10,6 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any
import clickzetta
from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,

View File

@@ -65,7 +65,7 @@ class FileMetadata:
return data
@classmethod
def from_dict(cls, data: dict) -> FileMetadata:
def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
@@ -459,7 +459,7 @@ class FileLifecycleManager:
newest_file=None,
)
def _create_version_backup(self, filename: str, metadata: dict):
def _create_version_backup(self, filename: str, metadata: dict[str, Any]):
"""Create version backup"""
try:
# Read current file content
@@ -487,7 +487,7 @@ class FileLifecycleManager:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
def _save_metadata(self, metadata_dict: dict[str, Any]):
"""Save metadata file"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)

View File

@@ -3,7 +3,7 @@ import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from typing import Any, Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
@@ -221,7 +221,7 @@ class RedisSubscriptionBase(Subscription):
"""Unsubscribe from the Redis topic using the appropriate command."""
raise NotImplementedError
def _get_message(self) -> dict | None:
def _get_message(self) -> dict[str, Any] | None:
"""Get a message from Redis using the appropriate method."""
raise NotImplementedError

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Any
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -62,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
assert self._pubsub is not None
self._pubsub.unsubscribe(self._topic)
def _get_message(self) -> dict | None:
def _get_message(self) -> dict[str, Any] | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Any
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -60,7 +62,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
assert self._pubsub is not None
self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined]
def _get_message(self) -> dict | None:
def _get_message(self) -> dict[str, Any] | None:
assert self._pubsub is not None
# NOTE(QuantumGhost): this is an issue in
# upstream code. If Sharded PubSub is used with Cluster, the

View File

@@ -1,9 +1,11 @@
from typing import Any
from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
error_code: str = "unknown"
data: dict | None = None
data: dict[str, Any] | None = None
def __init__(self, description=None, response=None):
super().__init__(description, response)

View File

@@ -410,7 +410,7 @@ class TokenManager:
token_type: str,
account: "Account | None" = None,
email: str | None = None,
additional_data: dict | None = None,
additional_data: dict[str, Any] | None = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
@@ -12,7 +13,7 @@ class SendGridClient:
self.sendgrid_api_key = sendgrid_api_key
self._from = _from
def send(self, mail: dict):
def send(self, mail: dict[str, Any]):
logger.debug("Sending email with SendGrid")
_to = ""
try:

View File

@@ -2,6 +2,7 @@ import logging
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any
from configs import dify_config
@@ -20,7 +21,7 @@ class SMTPClient:
self.use_tls = use_tls
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
def send(self, mail: dict[str, Any]):
smtp: smtplib.SMTP | None = None
local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:

View File

@@ -103,10 +103,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
else:
return dialect.type_descriptor(sa.JSON())
def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
def process_bind_param(
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
) -> dict[str, Any] | list[Any] | None:
return value
def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
def process_result_value(
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
) -> dict[str, Any] | list[Any] | None:
return value

View File

@@ -35,7 +35,7 @@ class AlibabaCloudMySQLVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values.get("host"):
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
if not values.get("port"):
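The `validate_config` hooks in these vector-store configs all share one shape: a "before" validator receives the raw input dict and rejects missing required keys before the model is built. A plain-Python sketch of that check (the pydantic decorator plumbing is omitted; the key-to-environment-variable mapping is illustrative):

```python
from typing import Any


# Stand-in for the @model_validator(mode="before") hooks above: the raw
# config dict is checked for required keys, and the error message names
# the environment variable the operator must set.
def validate_config(values: dict[str, Any]) -> dict[str, Any]:
    required = {"host": "ALIBABACLOUD_MYSQL_HOST", "port": "ALIBABACLOUD_MYSQL_PORT"}
    for key, env_name in required.items():
        if not values.get(key):
            raise ValueError(f"config {env_name} is required")
    return values
```

Typing the parameter as `dict[str, Any]` matches what pydantic actually passes to a `mode="before"` validator, which is why the same one-line change repeats across every config class below.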

View File

@@ -34,7 +34,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:

View File

@@ -24,7 +24,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:

View File

@@ -59,7 +59,7 @@ class BaiduConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["endpoint"]:
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
if not values["account"]:

View File

@@ -51,7 +51,7 @@ class ClickzettaConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
"""
Validate the configuration values.
"""

View File

@@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values.get("connection_string"):
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
if not values.get("user"):

View File

@@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
use_cloud = values.get("use_cloud", False)
cloud_url = values.get("cloud_url")
@@ -258,7 +258,7 @@ class ElasticSearchVector(BaseVector):
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
index_params: dict[str, Any] | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):

View File

@@ -43,7 +43,7 @@ class HologresVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values.get("host"):
raise ValueError("config HOLOGRES_HOST is required")
if not values.get("database"):

View File

@@ -44,7 +44,7 @@ class HuaweiCloudVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["hosts"]:
raise ValueError("config HOSTS is required")
return values
@@ -169,7 +169,7 @@ class HuaweiCloudVector(BaseVector):
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
index_params: dict[str, Any] | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):

View File

@@ -44,7 +44,7 @@ class LindormVectorStoreConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
@@ -336,7 +336,10 @@ class LindormVectorStore(BaseVector):
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
self,
embeddings: list[list[float]],
metadatas: list[dict[str, Any]] | None = None,
index_params: dict[str, Any] | None = None,
):
if not embeddings:
raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'")

View File

@@ -43,7 +43,7 @@ class MatrixoneConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config host is required")
if not values["port"]:

View File

@@ -45,7 +45,7 @@ class MilvusConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
@@ -302,7 +302,10 @@ class MilvusVector(BaseVector):
)
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
self,
embeddings: list[list[float]],
metadatas: list[dict[str, Any]] | None = None,
index_params: dict[str, Any] | None = None,
):
"""
Create a new collection in Milvus with the specified schema and index parameters.

View File

@@ -49,7 +49,7 @@ class OceanBaseVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:

View File

@@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config OPENGAUSS_HOST is required")
if not values["port"]:

View File

@@ -49,7 +49,7 @@ class OpenSearchConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
@@ -252,7 +252,10 @@ class OpenSearchVector(BaseVector):
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
self,
embeddings: list[list[float]],
metadatas: list[dict[str, Any]] | None = None,
index_params: dict[str, Any] | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):

View File

@@ -36,7 +36,7 @@ class OracleVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["user"]:
raise ValueError("config ORACLE_USER is required")
if not values["password"]:

View File

@@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values["port"]:

View File

@@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
if not values["port"]:

View File

@@ -38,7 +38,7 @@ class RelytConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values["port"]:
@@ -239,7 +239,7 @@ class RelytVector(BaseVector):
self,
embedding: list[float],
k: int = 4,
filter: dict | None = None,
filter: dict[str, Any] | None = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided

View File

@@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["access_key_id"]:
raise ValueError("config ACCESS_KEY_ID is required")
if not values["access_key_secret"]:

View File

@@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values["port"]:

View File

@@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["url"]:
raise ValueError("Upstash URL is required")
if not values["token"]:

View File

@@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
if not values["host"]:
raise ValueError("config VASTBASE_HOST is required")
if not values["port"]:

View File

@@ -20,7 +20,7 @@ from pydantic import BaseModel, model_validator
from weaviate.classes.data import DataObject
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.exceptions import UnexpectedStatusCodeError
from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateQueryError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -82,7 +82,7 @@ class WeaviateConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validates that required configuration values are present."""
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
@@ -230,6 +230,8 @@ class WeaviateVector(BaseVector):
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
wc.Property(name="is_summary", data_type=wc.DataType.BOOL),
wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT),
],
vector_config=wc.Configure.Vectors.self_provided(),
)
@@ -262,6 +264,10 @@ class WeaviateVector(BaseVector):
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
if "is_summary" not in existing:
to_add.append(wc.Property(name="is_summary", data_type=wc.DataType.BOOL))
if "original_chunk_id" not in existing:
to_add.append(wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT))
for prop in to_add:
try:
@@ -400,15 +406,27 @@ class WeaviateVector(BaseVector):
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
try:
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
except WeaviateQueryError:
self._ensure_properties()
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
docs: list[Document] = []
for obj in res.objects:
@@ -446,14 +464,25 @@ class WeaviateVector(BaseVector):
top_k = int(kwargs.get("top_k", 4))
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
try:
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
except WeaviateQueryError:
self._ensure_properties()
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
docs: list[Document] = []
for obj in res.objects:
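The new try/except blocks above apply one pattern twice: run the query, and on a `WeaviateQueryError` (assumed here to indicate a missing schema property), repair the schema via `_ensure_properties` and retry exactly once. A generic sketch of that control flow, decoupled from the Weaviate client:

```python
from collections.abc import Callable
from typing import Any


# One-shot retry after repairing state, mirroring the near_vector/bm25
# handlers above: the first failure triggers schema repair, then the
# query runs once more; a second failure propagates to the caller.
def query_with_repair(
    run_query: Callable[[], Any],
    repair: Callable[[], None],
    error_type: type[Exception],
) -> Any:
    try:
        return run_query()
    except error_type:
        repair()
        return run_query()
```

Bounding the retry to a single attempt keeps the failure mode simple: if the collection is still missing properties after repair, the original exception surfaces instead of looping.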

View File

@@ -326,7 +326,7 @@ class TestWeaviateVector(unittest.TestCase):
add_calls = mock_col.config.add_property.call_args_list
added_names = [call.args[0].name for call in add_calls]
assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index", "is_summary", "original_chunk_id"]
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
@@ -346,6 +346,8 @@ class TestWeaviateVector(unittest.TestCase):
SimpleNamespace(name="doc_id"),
SimpleNamespace(name="doc_type"),
SimpleNamespace(name="chunk_index"),
SimpleNamespace(name="is_summary"),
SimpleNamespace(name="original_chunk_id"),
]
mock_cfg = MagicMock()
mock_cfg.properties = existing_props
@@ -383,7 +385,7 @@ class TestWeaviateVector(unittest.TestCase):
with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
wv._ensure_properties()
assert mock_warning.call_count == 4
assert mock_warning.call_count == 6
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
@@ -484,6 +486,56 @@ class TestWeaviateVector(unittest.TestCase):
assert wv.search_by_vector(query_vector=[0.1] * 3) == []
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_search_by_vector_retries_on_weaviate_query_error(self, mock_weaviate_module):
"""Test that search_by_vector catches WeaviateQueryError, calls _ensure_properties, and retries."""
from weaviate.exceptions import WeaviateQueryError
mock_client = MagicMock()
mock_client.is_ready.return_value = True
mock_weaviate_module.connect_to_custom.return_value = mock_client
mock_client.collections.exists.return_value = True
mock_col = MagicMock()
mock_client.collections.use.return_value = mock_col
# First call raises WeaviateQueryError, second call succeeds
mock_obj = MagicMock()
mock_obj.properties = {"text": "retry result", "document_id": "doc-1"}
mock_obj.metadata.distance = 0.2
mock_result = MagicMock()
mock_result.objects = [mock_obj]
mock_col.query.near_vector.side_effect = [
WeaviateQueryError("missing property", "gRPC"),
mock_result,
]
# Mock _ensure_properties dependencies
mock_cfg = MagicMock()
mock_cfg.properties = [
SimpleNamespace(name="text"),
SimpleNamespace(name="document_id"),
SimpleNamespace(name="doc_id"),
SimpleNamespace(name="doc_type"),
SimpleNamespace(name="chunk_index"),
SimpleNamespace(name="is_summary"),
SimpleNamespace(name="original_chunk_id"),
]
mock_col.config.get.return_value = mock_cfg
wv = WeaviateVector(
collection_name=self.collection_name,
config=self.config,
attributes=self.attributes,
)
docs = wv.search_by_vector(query_vector=[0.1] * 3, top_k=1)
assert mock_col.query.near_vector.call_count == 2
assert len(docs) == 1
assert docs[0].metadata["score"] == pytest.approx(0.8)
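The `pytest.approx(0.8)` expectation pairs with the mocked `metadata.distance = 0.2`, which suggests the vector wrapper converts Weaviate's cosine distance to a similarity score as `1 - distance`. That conversion is read off the test, not stated in the diff:

```python
# Assumed conversion behind the assertion above: distance 0.2 -> score 0.8.
def distance_to_score(distance: float) -> float:
    return 1.0 - distance
```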
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
"""Test that search_by_full_text also returns doc_type in document metadata."""
@@ -569,6 +621,56 @@ class TestWeaviateVector(unittest.TestCase):
assert wv.search_by_full_text(query="missing") == []
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_retries_on_weaviate_query_error(self, mock_weaviate_module):
"""Test that search_by_full_text catches WeaviateQueryError, calls _ensure_properties, and retries."""
from weaviate.exceptions import WeaviateQueryError
mock_client = MagicMock()
mock_client.is_ready.return_value = True
mock_weaviate_module.connect_to_custom.return_value = mock_client
mock_client.collections.exists.return_value = True
mock_col = MagicMock()
mock_client.collections.use.return_value = mock_col
# First call raises WeaviateQueryError, second call succeeds
mock_obj = MagicMock()
mock_obj.properties = {"text": "retry bm25 result", "doc_id": "segment-1"}
mock_obj.vector = {"default": [0.5, 0.6]}
mock_result = MagicMock()
mock_result.objects = [mock_obj]
mock_col.query.bm25.side_effect = [
WeaviateQueryError("missing property", "gRPC"),
mock_result,
]
# Mock _ensure_properties dependencies
mock_cfg = MagicMock()
mock_cfg.properties = [
SimpleNamespace(name="text"),
SimpleNamespace(name="document_id"),
SimpleNamespace(name="doc_id"),
SimpleNamespace(name="doc_type"),
SimpleNamespace(name="chunk_index"),
SimpleNamespace(name="is_summary"),
SimpleNamespace(name="original_chunk_id"),
]
mock_col.config.get.return_value = mock_cfg
wv = WeaviateVector(
collection_name=self.collection_name,
config=self.config,
attributes=self.attributes,
)
docs = wv.search_by_full_text(query="retry", top_k=1)
assert mock_col.query.bm25.call_count == 2
assert len(docs) == 1
assert docs[0].page_content == "retry bm25 result"
@patch("dify_vdb_weaviate.weaviate_vector.weaviate")
def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
"""Test that add_texts includes doc_type from document metadata in stored properties."""

View File

@@ -3,7 +3,7 @@ import hashlib
import logging
import uuid
from collections.abc import Mapping
from typing import cast
from typing import Any, cast
from urllib.parse import urlparse
from uuid import uuid4
@@ -400,7 +400,7 @@ class AppDslService:
self,
*,
app: App | None,
data: dict,
data: dict[str, Any],
account: Account,
name: str | None = None,
description: str | None = None,
@@ -567,7 +567,7 @@ class AppDslService:
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
cls, *, export_data: dict[str, Any], app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data
@@ -620,7 +620,7 @@ class AppDslService:
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
def _append_model_config_export_data(cls, export_data: dict[str, Any], app_model: App):
"""
Append model config export data
:param export_data: export data

View File

@@ -1,3 +1,5 @@
from typing import Any
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
@@ -6,7 +8,7 @@ from models.model import AppMode, AppModelConfigDict
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict:
match app_mode:
case AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class AppService:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
@@ -78,7 +78,7 @@ class AppService:
return app_models
def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App:
"""
Create app
:param tenant_id: tenant id
@@ -389,7 +389,7 @@ class AppService:
"""
app_mode = AppMode.value_of(app_model.mode)
meta: dict = {"tool_icons": {}}
meta: dict[str, Any] = {"tool_icons": {}}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow

View File

@@ -1,4 +1,5 @@
import json
from typing import Any
from sqlalchemy import select
@@ -19,7 +20,7 @@ class ApiKeyAuthService:
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
# Encrypt the api key

View File

@@ -2,7 +2,7 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal, NotRequired, TypedDict
from typing import Any, Literal, NotRequired, TypedDict
import httpx
from pydantic import TypeAdapter
@@ -541,7 +541,7 @@ class BillingService:
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
Returns {"notification_id": str}.
"""
payload: dict = {
payload: dict[str, Any] = {
"contents": contents,
"frequency": frequency,
"status": status,

View File

@@ -318,7 +318,7 @@ class DatasourceProviderService:
self,
tenant_id: str,
datasource_provider_id: DatasourceProviderID,
client_params: dict | None,
client_params: dict[str, Any] | None,
enabled: bool | None,
):
"""
@@ -352,7 +352,7 @@ class DatasourceProviderService:
original_params = (
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
)
new_params: dict = {
new_params: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
@@ -500,7 +500,7 @@ class DatasourceProviderService:
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
) -> None:
"""
@@ -566,7 +566,7 @@ class DatasourceProviderService:
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
credentials: dict[str, Any],
) -> None:
"""
add datasource oauth provider
@@ -634,7 +634,7 @@ class DatasourceProviderService:
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
credentials: dict[str, Any],
) -> None:
"""
validate datasource provider credentials.
@@ -947,7 +947,13 @@ class DatasourceProviderService:
return copy_credentials_list
def update_datasource_credentials(
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
self,
tenant_id: str,
auth_id: str,
provider: str,
plugin_id: str,
credentials: dict[str, Any] | None,
name: str | None,
) -> None:
"""
update datasource credentials.

View File

@@ -1,4 +1,4 @@
from typing import Literal, Union
from typing import Any, Literal, Union
from pydantic import BaseModel
@@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: dict | None = None
params: dict | None = None
headers: dict[str, Any] | None = None
params: dict[str, Any] | None = None

View File

@@ -1,4 +1,4 @@
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, field_validator
@@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
summary_index_setting: dict | None = None
summary_index_setting: dict[str, Any] | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: str | None = None

View File

@@ -1,4 +1,4 @@
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, field_validator
@@ -73,7 +73,7 @@ class KnowledgeConfiguration(BaseModel):
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
# add summary index setting
summary_index_setting: dict | None = None
summary_index_setting: dict[str, Any] | None = None
@field_validator("embedding_model_provider", mode="before")
@classmethod

View File

@@ -45,7 +45,7 @@ class HitTestingService:
query: str,
account: Account,
retrieval_model: dict[str, Any] | None,
external_retrieval_model: dict,
external_retrieval_model: dict[str, Any],
attachment_ids: list | None = None,
limit: int = 10,
):
@@ -125,8 +125,8 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict | None = None,
external_retrieval_model: dict[str, Any] | None = None,
metadata_filtering_conditions: dict[str, Any] | None = None,
):
if dataset.provider != "external":
return {

View File

@@ -502,7 +502,7 @@ class ModelLoadBalancingService:
provider: str,
model: str,
model_type: str,
credentials: dict,
credentials: dict[str, Any],
config_id: str | None = None,
):
"""
@@ -561,7 +561,7 @@ class ModelLoadBalancingService:
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
load_balancing_model_config: LoadBalancingModelConfig | None = None,
model_provider_factory: ModelProviderFactory | None = None,
validate: bool = True,

View File

@@ -1,5 +1,6 @@
import json
import uuid
from typing import Any
from core.plugin.impl.base import BasePluginClient
from extensions.ext_redis import redis_client
@@ -16,7 +17,7 @@ class OAuthProxyService(BasePluginClient):
tenant_id: str,
plugin_id: str,
provider: str,
extra_data: dict = {},
extra_data: dict[str, Any] = {},
credential_id: str | None = None,
):
"""

View File

@@ -5,7 +5,7 @@ import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from typing import cast
from typing import Any, cast
from urllib.parse import urlparse
from uuid import uuid4
@@ -526,7 +526,7 @@ class RagPipelineDslService:
self,
*,
pipeline: Pipeline | None,
data: dict,
data: dict[str, Any],
account: Account,
dependencies: list[PluginDependency] | None = None,
) -> Pipeline:
@@ -660,7 +660,9 @@ class RagPipelineDslService:
return yaml.dump(export_data, allow_unicode=True) # type: ignore
def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
def _append_workflow_export_data(
self, *, export_data: dict[str, Any], pipeline: Pipeline, include_secret: bool
) -> None:
"""
Append workflow export data
:param export_data: export data

View File

@@ -2,6 +2,7 @@ import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import yaml
@@ -154,7 +155,7 @@ class RagPipelineTransformService:
raise ValueError("Unsupported doc form")
return pipeline_yaml
def _deal_file_extensions(self, node: dict):
def _deal_file_extensions(self, node: dict[str, Any]):
file_extensions = node.get("data", {}).get("fileExtensions", [])
if not file_extensions:
return node
@@ -167,7 +168,7 @@ class RagPipelineTransformService:
dataset: Dataset,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
node: dict[str, Any],
):
knowledge_configuration_dict = node.get("data", {})
@@ -191,7 +192,7 @@ class RagPipelineTransformService:
def _create_pipeline(
self,
data: dict,
data: dict[str, Any],
) -> Pipeline:
"""Create a new app or update an existing one."""
pipeline_data = data.get("rag_pipeline", {})
@@ -258,7 +259,7 @@ class RagPipelineTransformService:
db.session.add(pipeline)
return pipeline
def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
def _deal_dependencies(self, pipeline_yaml: dict[str, Any], tenant_id: str):
installer_manager = PluginInstaller()
installed_plugins = installer_manager.list_plugins(tenant_id)

View File

@@ -1,6 +1,7 @@
import json
from os import path
from pathlib import Path
from typing import Any
from flask import current_app
@@ -13,7 +14,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
"""
builtin_data: dict | None = None
builtin_data: dict[str, Any] | None = None
def get_type(self) -> str:
return RecommendAppType.BUILDIN
@@ -53,7 +54,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict[str, Any] | None:
"""
Fetch recommended app detail from builtin.
:param app_id: App ID

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any
import httpx
@@ -35,7 +36,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.REMOTE
@classmethod
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict[str, Any] | None:
"""
Fetch recommended app detail from dify official.
:param app_id: App ID
@@ -46,7 +47,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
return None
data: dict = response.json()
data: dict[str, Any] = response.json()
return data
@classmethod
@@ -62,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
if response.status_code != 200:
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
result: dict = response.json()
result: dict[str, Any] = response.json()
if "categories" in result:
result["categories"] = sorted(result["categories"])

View File

@@ -1,3 +1,5 @@
from typing import Any
from sqlalchemy import select
from configs import dify_config
@@ -37,7 +39,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None:
"""
Get recommend app detail.
:param app_id: app id
@@ -45,7 +47,7 @@ class RecommendedAppService:
"""
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
result: dict[str, Any] = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))

View File

@@ -428,7 +428,7 @@ class ToolTransformService:
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
provider: BuiltinToolProvider, credentials: dict[str, Any]
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,

View File

@@ -3,6 +3,7 @@ import tempfile
import time
import uuid
from pathlib import Path
from typing import Any
import click
import pandas as pd
@@ -51,8 +52,8 @@ def batch_create_segment_to_index_task(
# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None
dataset_config: dict[str, Any] | None = None
document_config: dict[str, Any] | None = None
with session_factory.create_session() as session:
try:

View File

@@ -679,7 +679,7 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
)
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
def _delete_records(query_sql: str, params: dict[str, Any], delete_func: Callable, name: str) -> None:
while True:
with session_factory.create_session() as session:
rs = session.execute(sa.text(query_sql), params)

View File

@@ -7,6 +7,7 @@ improving performance by offloading storage operations to background workers.
import json
import logging
+from typing import Any
from celery import shared_task
from graphon.entities import WorkflowExecution
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
-execution_data: dict,
+execution_data: dict[str, Any],
tenant_id: str,
app_id: str,
triggered_from: str,

View File

@@ -7,6 +7,7 @@ improving performance by offloading storage operations to background workers.
import json
import logging
+from typing import Any
from celery import shared_task
from graphon.entities.workflow_node_execution import (
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
-execution_data: dict,
+execution_data: dict[str, Any],
tenant_id: str,
app_id: str,
triggered_from: str,
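These Celery task signatures take `execution_data: dict[str, Any]` because the payload crosses a JSON serialization boundary, so values are heterogeneous (strings, numbers, nested objects). A small sketch of that round trip; the helper names are invented for illustration:

```python
import json
from typing import Any

def serialize_payload(execution_data: dict[str, Any]) -> str:
    # Task payloads are JSON on the wire, so values can be strings,
    # numbers, booleans, or nested dicts -- hence dict[str, Any].
    return json.dumps(execution_data, sort_keys=True)

def deserialize_payload(raw: str) -> dict[str, Any]:
    data: dict[str, Any] = json.loads(raw)
    return data

payload = {"id": "exec-1", "elapsed": 1.5, "outputs": {"ok": True}}
assert deserialize_payload(serialize_payload(payload)) == payload
```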

View File

@@ -95,30 +95,6 @@ class TestTextToAudioPayload:
assert payload.streaming is True
# ---------------------------------------------------------------------------
# AudioService Interface Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test AudioService method interfaces exist."""
def test_transcript_asr_method_exists(self):
"""Test that AudioService.transcript_asr exists."""
assert hasattr(AudioService, "transcript_asr")
assert callable(AudioService.transcript_asr)
def test_transcript_tts_method_exists(self):
"""Test that AudioService.transcript_tts exists."""
assert hasattr(AudioService, "transcript_tts")
assert callable(AudioService.transcript_tts)
# ---------------------------------------------------------------------------
# Audio Service Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test suite for AudioService interface methods."""

View File

@@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi:
with pytest.raises(NotChatAppError):
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
def test_wrong_mode_raises(self, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
with pytest.raises(NotChatAppError):
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
msg_id = uuid4()

View File

@@ -73,11 +73,6 @@ class TestAsyncWorkflowService:
mock_dispatcher = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
with (
patch.object(

View File

@@ -0,0 +1,602 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
manager = MagicMock()
service = ModelProviderService()
service._get_provider_manager = MagicMock(return_value=manager)
return service, manager
def _build_provider_configuration(
*,
provider_name: str = "openai",
supported_model_types: list[ModelType] | None = None,
custom_models: list[Any] | None = None,
custom_config_available: bool = True,
) -> SimpleNamespace:
if supported_model_types is None:
supported_model_types = [ModelType.LLM]
return SimpleNamespace(
provider=SimpleNamespace(
provider=provider_name,
label=I18nObject(en_US=provider_name),
description=None,
icon_small=None,
icon_small_dark=None,
background=None,
help=None,
supported_model_types=supported_model_types,
configurate_methods=[],
provider_credential_schema=None,
model_credential_schema=None,
),
preferred_provider_type=ProviderType.CUSTOM,
custom_configuration=SimpleNamespace(
provider=SimpleNamespace(
current_credential_id="cred-1",
current_credential_name="Credential 1",
available_credentials=[],
),
models=custom_models,
can_added_models=[],
),
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
is_custom_configuration_available=lambda: custom_config_available,
)
class TestModelProviderServiceConfiguration:
def test__get_provider_configuration_should_return_configuration_when_provider_exists(self) -> None:
service, manager = _create_service_with_mocked_manager()
provider_configuration = SimpleNamespace(name="provider-config")
manager.get_configurations.return_value = {"openai": provider_configuration}
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
assert result is provider_configuration
def test__get_provider_configuration_should_raise_error_when_provider_is_missing(self) -> None:
service, manager = _create_service_with_mocked_manager()
manager.get_configurations.return_value = {}
with pytest.raises(ProviderNotFoundError, match="does not exist"):
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status(self) -> None:
service, manager = _create_service_with_mocked_manager()
allowed = _build_provider_configuration(
provider_name="openai",
supported_model_types=[ModelType.LLM],
custom_config_available=False,
)
filtered = _build_provider_configuration(
provider_name="embedding",
supported_model_types=[ModelType.TEXT_EMBEDDING],
custom_config_available=True,
)
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
assert len(result) == 1
assert result[0].provider == "openai"
assert result[0].custom_configuration.status.value == "no-configure"
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context(self) -> None:
service, manager = _create_service_with_mocked_manager()
class _Model:
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def model_dump(self) -> dict[str, Any]:
return {
"model": self.model_name,
"label": {"en_US": self.model_name},
"model_type": ModelType.LLM,
"features": [],
"fetch_from": FetchFrom.PREDEFINED_MODEL,
"model_properties": {},
"deprecated": False,
"status": ModelStatus.ACTIVE,
"load_balancing_enabled": False,
"has_invalid_load_balancing_configs": False,
"provider": {
"provider": "openai",
"label": {"en_US": "OpenAI"},
"icon_small": None,
"icon_small_dark": None,
"supported_model_types": [ModelType.LLM],
},
}
provider_configurations = SimpleNamespace(
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
)
manager.get_configurations.return_value = provider_configurations
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
assert len(result) == 2
assert result[0].model == "gpt-4o"
assert result[1].provider.provider == "openai"
provider_configurations.get_models.assert_called_once_with(provider="openai")
class TestModelProviderServiceDelegation:
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
[
(
"get_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"get_provider_credential",
{"credential_id": "cred-1"},
{"token": "abc"},
),
(
"validate_provider_credentials",
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
"validate_provider_credentials",
({"token": "abc"},),
None,
),
(
"create_provider_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"credentials": {"token": "abc"},
"credential_name": "A",
},
"create_provider_credential",
({"token": "abc"}, "A"),
None,
),
(
"update_provider_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"credentials": {"token": "abc"},
"credential_id": "cred-1",
"credential_name": "B",
},
"update_provider_credential",
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
None,
),
(
"remove_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"delete_provider_credential",
{"credential_id": "cred-1"},
None,
),
(
"switch_active_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"switch_active_provider_credential",
{"credential_id": "cred-1"},
None,
),
],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
self,
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
provider_call_kwargs: Any,
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
result = getattr(service, method_name)(**method_kwargs)
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
provider_method = getattr(provider_configuration, provider_method_name)
if isinstance(provider_call_kwargs, tuple):
provider_method.assert_called_once_with(*provider_call_kwargs)
elif isinstance(provider_call_kwargs, dict):
provider_method.assert_called_once_with(**provider_call_kwargs)
else:
provider_method.assert_called_once_with(provider_call_kwargs)
if method_name == "get_provider_credential":
assert result == {"token": "abc"}
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
[
(
"get_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"get_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
{"api_key": "x"},
),
(
"validate_model_credentials",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
},
"validate_custom_model_credentials",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
None,
),
(
"create_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
"create_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
None,
),
(
"update_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
"update_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
None,
),
(
"remove_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"delete_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"switch_active_custom_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"switch_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"add_model_credential_to_model_list",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"add_model_credential_to_model",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"remove_model",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
},
"delete_custom_model",
{"model_type": ModelType.LLM, "model": "gpt-4o"},
None,
),
],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
self,
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
expected_kwargs: dict[str, Any],
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
result = getattr(service, method_name)(**method_kwargs)
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
if method_name == "get_model_credential":
assert result == {"api_key": "x"}
class TestModelProviderServiceListingsAndDefaults:
def test_get_models_by_model_type_should_group_active_non_deprecated_models(self) -> None:
service, manager = _create_service_with_mocked_manager()
openai_provider = SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
icon_small_dark=None,
)
anthropic_provider = SimpleNamespace(
provider="anthropic",
label=I18nObject(en_US="Anthropic"),
icon_small=None,
icon_small_dark=None,
)
models = [
SimpleNamespace(
provider=openai_provider,
model="gpt-4o",
label=I18nObject(en_US="GPT-4o"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=False,
),
SimpleNamespace(
provider=openai_provider,
model="old-openai",
label=I18nObject(en_US="Old OpenAI"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
SimpleNamespace(
provider=anthropic_provider,
model="old-anthropic",
label=I18nObject(en_US="Old Anthropic"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
]
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
manager.get_configurations.return_value = provider_configurations
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
assert len(result) == 1
assert result[0].provider == "openai"
assert len(result[0].models) == 1
assert result[0].models[0].model == "gpt-4o"
@pytest.mark.parametrize(
("credentials", "schema", "expected_count"),
[
(None, None, 0),
({"api_key": "x"}, None, 0),
(
{"api_key": "x"},
SimpleNamespace(
parameter_rules=[
ParameterRule(
name="temperature",
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
)
]
),
1,
),
],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
self,
credentials: dict[str, Any] | None,
schema: Any,
expected_count: int,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
provider_configuration = MagicMock()
provider_configuration.get_current_credentials.return_value = credentials
provider_configuration.get_model_schema.return_value = schema
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
assert len(result) == expected_count
provider_configuration.get_current_credentials.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
)
if credentials:
provider_configuration.get_model_schema.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
credentials=credentials,
)
else:
provider_configuration.get_model_schema.assert_not_called()
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model(self) -> None:
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = SimpleNamespace(
model="gpt-4o",
model_type=ModelType.LLM,
provider=SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
supported_model_types=[ModelType.LLM],
),
)
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
assert result is not None
assert result.model == "gpt-4o"
assert result.provider.provider == "openai"
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none(self) -> None:
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = None
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
assert result is None
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception(self) -> None:
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.side_effect = RuntimeError("boom")
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
assert result is None
def test_update_default_model_of_model_type_should_delegate_to_provider_manager(self) -> None:
service, manager = _create_service_with_mocked_manager()
service.update_default_model_of_model_type(
tenant_id="tenant-1",
model_type=ModelType.LLM.value,
provider="openai",
model="gpt-4o",
)
manager.update_default_model_record.assert_called_once_with(
tenant_id="tenant-1",
model_type=ModelType.LLM,
provider="openai",
model="gpt-4o",
)
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
factory_instance = MagicMock()
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
factory_constructor = MagicMock(return_value=factory_instance)
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
result = service.get_model_provider_icon(
tenant_id="tenant-1",
provider="openai",
icon_type="icon_small",
lang="en_US",
)
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
assert result == (b"icon-bytes", "image/png")
def test_switch_preferred_provider_should_convert_enum_and_delegate(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
service.switch_preferred_provider(
tenant_id="tenant-1",
provider="openai",
preferred_provider_type=ProviderType.SYSTEM.value,
)
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
@pytest.mark.parametrize(
("method_name", "provider_method_name"),
[
("enable_model", "enable_model"),
("disable_model", "disable_model"),
],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
self,
method_name: str,
provider_method_name: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
getattr(service, method_name)(
tenant_id="tenant-1",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM.value,
)
getattr(provider_configuration, provider_method_name).assert_called_once_with(
model="gpt-4o",
model_type=ModelType.LLM,
)

View File

@@ -85,644 +85,3 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
assert len(custom_models) == 1
# The sanitizer should drop credentials in list response
assert custom_models[0].credentials is None
# === Merged from test_model_provider_service.py ===
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService
def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
manager = MagicMock()
service = ModelProviderService()
service._get_provider_manager = MagicMock(return_value=manager)
return service, manager
def _build_provider_configuration(
*,
provider_name: str = "openai",
supported_model_types: list[ModelType] | None = None,
custom_models: list[Any] | None = None,
custom_config_available: bool = True,
) -> SimpleNamespace:
if supported_model_types is None:
supported_model_types = [ModelType.LLM]
return SimpleNamespace(
provider=SimpleNamespace(
provider=provider_name,
label=I18nObject(en_US=provider_name),
description=None,
icon_small=None,
icon_small_dark=None,
background=None,
help=None,
supported_model_types=supported_model_types,
configurate_methods=[],
provider_credential_schema=None,
model_credential_schema=None,
),
preferred_provider_type=ProviderType.CUSTOM,
custom_configuration=SimpleNamespace(
provider=SimpleNamespace(
current_credential_id="cred-1",
current_credential_name="Credential 1",
available_credentials=[],
),
models=custom_models,
can_added_models=[],
),
system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
is_custom_configuration_available=lambda: custom_config_available,
)
def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
provider_configuration = SimpleNamespace(name="provider-config")
manager.get_configurations.return_value = {"openai": provider_configuration}
# Act
result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")
# Assert
assert result is provider_configuration
def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_configurations.return_value = {}
# Act / Assert
with pytest.raises(ProviderNotFoundError, match="does not exist"):
service._get_provider_configuration(tenant_id="tenant-1", provider="missing")
def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
allowed = _build_provider_configuration(
provider_name="openai",
supported_model_types=[ModelType.LLM],
custom_config_available=False,
)
filtered = _build_provider_configuration(
provider_name="embedding",
supported_model_types=[ModelType.TEXT_EMBEDDING],
custom_config_available=True,
)
manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}
# Act
result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert len(result) == 1
assert result[0].provider == "openai"
assert result[0].custom_configuration.status.value == "no-configure"
def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
class _Model:
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def model_dump(self) -> dict[str, Any]:
return {
"model": self.model_name,
"label": {"en_US": self.model_name},
"model_type": ModelType.LLM,
"features": [],
"fetch_from": FetchFrom.PREDEFINED_MODEL,
"model_properties": {},
"deprecated": False,
"status": ModelStatus.ACTIVE,
"load_balancing_enabled": False,
"has_invalid_load_balancing_configs": False,
"provider": {
"provider": "openai",
"label": {"en_US": "OpenAI"},
"icon_small": None,
"icon_small_dark": None,
"supported_model_types": [ModelType.LLM],
},
}
provider_configurations = SimpleNamespace(
get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
)
manager.get_configurations.return_value = provider_configurations
# Act
result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")
# Assert
assert len(result) == 2
assert result[0].model == "gpt-4o"
assert result[1].provider.provider == "openai"
provider_configurations.get_models.assert_called_once_with(provider="openai")
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
[
(
"get_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"get_provider_credential",
{"credential_id": "cred-1"},
{"token": "abc"},
),
(
"validate_provider_credentials",
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
"validate_provider_credentials",
({"token": "abc"},),
None,
),
(
"create_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
"create_provider_credential",
({"token": "abc"}, "A"),
None,
),
(
"update_provider_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"credentials": {"token": "abc"},
"credential_id": "cred-1",
"credential_name": "B",
},
"update_provider_credential",
{"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
None,
),
(
"remove_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"delete_provider_credential",
{"credential_id": "cred-1"},
None,
),
(
"switch_active_provider_credential",
{"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
"switch_active_provider_credential",
{"credential_id": "cred-1"},
None,
),
],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
provider_call_kwargs: Any,
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
# Act
result = getattr(service, method_name)(**method_kwargs)
# Assert
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
provider_method = getattr(provider_configuration, provider_method_name)
if isinstance(provider_call_kwargs, tuple):
provider_method.assert_called_once_with(*provider_call_kwargs)
elif isinstance(provider_call_kwargs, dict):
provider_method.assert_called_once_with(**provider_call_kwargs)
else:
provider_method.assert_called_once_with(provider_call_kwargs)
if method_name == "get_provider_credential":
assert result == {"token": "abc"}
@pytest.mark.parametrize(
("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
[
(
"get_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"get_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
{"api_key": "x"},
),
(
"validate_model_credentials",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
},
"validate_custom_model_credentials",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
None,
),
(
"create_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
"create_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_name": "cred-a",
},
None,
),
(
"update_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
"update_custom_model_credential",
{
"model_type": ModelType.LLM,
"model": "gpt-4o",
"credentials": {"api_key": "x"},
"credential_id": "cred-1",
"credential_name": "cred-b",
},
None,
),
(
"remove_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"delete_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"switch_active_custom_model_credential",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"switch_custom_model_credential",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"add_model_credential_to_model_list",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
"credential_id": "cred-1",
},
"add_model_credential_to_model",
{"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
None,
),
(
"remove_model",
{
"tenant_id": "tenant-1",
"provider": "openai",
"model_type": ModelType.LLM.value,
"model": "gpt-4o",
},
"delete_custom_model",
{"model_type": ModelType.LLM, "model": "gpt-4o"},
None,
),
],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
method_name: str,
method_kwargs: dict[str, Any],
provider_method_name: str,
expected_kwargs: dict[str, Any],
provider_return: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
getattr(provider_configuration, provider_method_name).return_value = provider_return
get_provider_config_mock = MagicMock(return_value=provider_configuration)
monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)
# Act
result = getattr(service, method_name)(**method_kwargs)
# Assert
get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
if method_name == "get_model_credential":
assert result == {"api_key": "x"}
def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
openai_provider = SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
icon_small_dark=None,
)
anthropic_provider = SimpleNamespace(
provider="anthropic",
label=I18nObject(en_US="Anthropic"),
icon_small=None,
icon_small_dark=None,
)
models = [
SimpleNamespace(
provider=openai_provider,
model="gpt-4o",
label=I18nObject(en_US="GPT-4o"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=False,
),
SimpleNamespace(
provider=openai_provider,
model="old-openai",
label=I18nObject(en_US="Old OpenAI"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
SimpleNamespace(
provider=anthropic_provider,
model="old-anthropic",
label=I18nObject(en_US="Old Anthropic"),
model_type=ModelType.LLM,
features=[],
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=ModelStatus.ACTIVE,
load_balancing_enabled=False,
deprecated=True,
),
]
provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
manager.get_configurations.return_value = provider_configurations
# Act
result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
assert len(result) == 1
assert result[0].provider == "openai"
assert len(result[0].models) == 1
assert result[0].models[0].model == "gpt-4o"
@pytest.mark.parametrize(
("credentials", "schema", "expected_count"),
[
(None, None, 0),
({"api_key": "x"}, None, 0),
(
{"api_key": "x"},
SimpleNamespace(
parameter_rules=[
ParameterRule(
name="temperature",
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
)
]
),
1,
),
],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
credentials: dict[str, Any] | None,
schema: Any,
expected_count: int,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
provider_configuration.get_current_credentials.return_value = credentials
provider_configuration.get_model_schema.return_value = schema
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")
# Assert
assert len(result) == expected_count
provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
if credentials:
provider_configuration.get_model_schema.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
credentials=credentials,
)
else:
provider_configuration.get_model_schema.assert_not_called()
def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = SimpleNamespace(
model="gpt-4o",
model_type=ModelType.LLM,
provider=SimpleNamespace(
provider="openai",
label=I18nObject(en_US="OpenAI"),
icon_small=None,
supported_model_types=[ModelType.LLM],
),
)
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is not None
assert result.model == "gpt-4o"
assert result.provider.provider == "openai"
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.return_value = None
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is None
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
manager.get_default_model.side_effect = RuntimeError("boom")
# Act
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
# Assert
assert result is None
def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
# Arrange
service, manager = _create_service_with_mocked_manager()
# Act
service.update_default_model_of_model_type(
tenant_id="tenant-1",
model_type=ModelType.LLM.value,
provider="openai",
model="gpt-4o",
)
# Assert
manager.update_default_model_record.assert_called_once_with(
tenant_id="tenant-1",
model_type=ModelType.LLM,
provider="openai",
model="gpt-4o",
)
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
service = ModelProviderService()
factory_instance = MagicMock()
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
factory_constructor = MagicMock(return_value=factory_instance)
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
# Act
result = service.get_model_provider_icon(
tenant_id="tenant-1",
provider="openai",
icon_type="icon_small",
lang="en_US",
)
# Assert
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
assert result == (b"icon-bytes", "image/png")
def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
service.switch_preferred_provider(
tenant_id="tenant-1",
provider="openai",
preferred_provider_type=ProviderType.SYSTEM.value,
)
# Assert
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
@pytest.mark.parametrize(
("method_name", "provider_method_name"),
[
("enable_model", "enable_model"),
("disable_model", "disable_model"),
],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
method_name: str,
provider_method_name: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = ModelProviderService()
provider_configuration = MagicMock()
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
# Act
getattr(service, method_name)(
tenant_id="tenant-1",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM.value,
)
# Assert
getattr(provider_configuration, provider_method_name).assert_called_once_with(
model="gpt-4o",
model_type=ModelType.LLM,
)


@@ -12,7 +12,6 @@ This test suite covers all functionality of the current VariableTruncator includ
import functools
import json
import uuid
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
@@ -674,229 +673,3 @@ def test_dummy_variable_truncator_methods():
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
# === Merged from test_variable_truncator_additional.py ===
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
def truncate(self, segment: Any) -> TruncationResult:
return super().truncate(segment) # type: ignore[misc]
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
return super().truncate_variable_mapping(v) # type: ignore[misc]
def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
# Arrange
passthrough = _AbstractPassthrough()
# Act
truncate_result = passthrough.truncate(StringSegment(value="x"))
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
# Assert
assert truncate_result is None
assert mapping_result is None
def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
# Act
truncator = VariableTruncator.default()
# Assert
assert truncator._max_size_bytes == 111
assert truncator._array_element_limit == 7
assert truncator._string_length_limit == 33
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=5)
mapping = {"very_long_key": "value"}
# Act
result, truncated = truncator.truncate_variable_mapping(mapping)
# Assert
assert result == {"very_long_key": "..."}
assert truncated is True
def test_truncate_variable_mapping_should_handle_segment_values() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=100)
mapping = {"seg": StringSegment(value="hello")}
# Act
result, truncated = truncator.truncate_variable_mapping(mapping)
# Assert
assert isinstance(result["seg"], StringSegment)
assert result["seg"].value == "hello"
assert truncated is False
@pytest.mark.parametrize(
("value", "expected"),
[
(None, False),
(True, False),
(1, False),
(1.5, False),
("x", True),
({"k": "v"}, True),
],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
# Act
result = VariableTruncator._json_value_needs_truncation(value)
# Assert
assert result is expected
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=10)
forced_result = truncator_module._PartResult(
value=StringSegment(value="this is too long"),
value_size=100,
truncated=True,
)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
# Act
result = truncator.truncate(StringSegment(value="input"))
# Assert
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert not result.result.value.startswith('"')
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator()
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
# Act / Assert
with pytest.raises(AssertionError):
truncator._truncate_segment(IntegerSegment(value=1), 10)
def test_calculate_json_size_should_unwrap_segment_values() -> None:
# Arrange
segment = StringSegment(value="abc")
# Act
size = VariableTruncator.calculate_json_size(segment)
# Assert
assert size == VariableTruncator.calculate_json_size("abc")
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
# Arrange
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
# Act
size = VariableTruncator.calculate_json_size(updated)
# Assert
assert size > 0
def test_maybe_qa_structure_should_validate_shape() -> None:
# Act / Assert
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
assert VariableTruncator._maybe_qa_structure({}) is False
def test_maybe_parent_child_structure_should_validate_shape() -> None:
# Act / Assert
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
assert (
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False
)
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
# Arrange
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
mapping = {"s": StringSegment(value="long-content")}
# Act
result = truncator._truncate_object(mapping, 20)
# Assert
assert result.truncated is True
assert isinstance(result.value["s"], StringSegment)
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=100)
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
# Act
result = truncator._truncate_json_primitives(updated, 100)
# Assert
assert isinstance(result.value, dict)
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
# Arrange
truncator = VariableTruncator()
# Act / Assert
with pytest.raises(AssertionError):
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
truncator = VariableTruncator(max_size_bytes=10)
forced_segment = ObjectSegment(value={"k": "v"})
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
# Act
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
# Assert
assert result.truncated is True
assert isinstance(result.result, StringSegment)


@@ -0,0 +1,174 @@
from collections.abc import Mapping
from typing import Any
import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType
from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator
class _AbstractPassthrough(BaseTruncator):
def truncate(self, segment: Any) -> TruncationResult:
return super().truncate(segment) # type: ignore[misc]
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
return super().truncate_variable_mapping(v) # type: ignore[misc]
class TestBaseTruncatorContract:
def test_base_truncator_methods_should_execute_abstract_placeholders(self) -> None:
passthrough = _AbstractPassthrough()
truncate_result = passthrough.truncate(StringSegment(value="x"))
mapping_result = passthrough.truncate_variable_mapping({"a": 1})
assert truncate_result is None
assert mapping_result is None
class TestVariableTruncatorAdditionalBehavior:
def test_default_should_use_dify_config_limits(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)
truncator = VariableTruncator.default()
assert truncator._max_size_bytes == 111
assert truncator._array_element_limit == 7
assert truncator._string_length_limit == 33
def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis(self) -> None:
truncator = VariableTruncator(max_size_bytes=5)
mapping = {"very_long_key": "value"}
result, truncated = truncator.truncate_variable_mapping(mapping)
assert result == {"very_long_key": "..."}
assert truncated is True
def test_truncate_variable_mapping_should_handle_segment_values(self) -> None:
truncator = VariableTruncator(max_size_bytes=100)
mapping = {"seg": StringSegment(value="hello")}
result, truncated = truncator.truncate_variable_mapping(mapping)
assert isinstance(result["seg"], StringSegment)
assert result["seg"].value == "hello"
assert truncated is False
@pytest.mark.parametrize(
("value", "expected"),
[
(None, False),
(True, False),
(1, False),
(1.5, False),
("x", True),
({"k": "v"}, True),
],
)
def test_json_value_needs_truncation_should_match_expected_rules(
self,
value: Any,
expected: bool,
) -> None:
result = VariableTruncator._json_value_needs_truncation(value)
assert result is expected
def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
truncator = VariableTruncator(max_size_bytes=10)
forced_result = truncator_module._PartResult(
value=StringSegment(value="this is too long"),
value_size=100,
truncated=True,
)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
result = truncator.truncate(StringSegment(value="input"))
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert not result.result.value.startswith('"')
def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
truncator = VariableTruncator()
monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)
with pytest.raises(AssertionError):
truncator._truncate_segment(IntegerSegment(value=1), 10)
def test_calculate_json_size_should_unwrap_segment_values(self) -> None:
segment = StringSegment(value="abc")
size = VariableTruncator.calculate_json_size(segment)
assert size == VariableTruncator.calculate_json_size("abc")
def test_calculate_json_size_should_handle_updated_variable_instances(self) -> None:
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
size = VariableTruncator.calculate_json_size(updated)
assert size > 0
def test_maybe_qa_structure_should_validate_shape(self) -> None:
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
assert VariableTruncator._maybe_qa_structure({}) is False
def test_maybe_parent_child_structure_should_validate_shape(self) -> None:
assert (
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
)
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
assert (
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"})
is False
)
def test_truncate_object_should_truncate_segment_values_inside_object(self) -> None:
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
mapping = {"s": StringSegment(value="long-content")}
result = truncator._truncate_object(mapping, 20)
assert result.truncated is True
assert isinstance(result.value["s"], StringSegment)
def test_truncate_json_primitives_should_handle_updated_variable_input(self) -> None:
truncator = VariableTruncator(max_size_bytes=100)
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
result = truncator._truncate_json_primitives(updated, 100)
assert isinstance(result.value, dict)
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type(self) -> None:
truncator = VariableTruncator()
with pytest.raises(AssertionError):
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
truncator = VariableTruncator(max_size_bytes=10)
forced_segment = ObjectSegment(value={"k": "v"})
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
assert result.truncated is True
assert isinstance(result.result, StringSegment)


@@ -559,771 +559,3 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# === Merged from test_webhook_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1", is_debug=True
)
# Assert
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/vnd.custom"},
data="plain content",
):
result = WebhookService.extract_webhook_data(webhook_trigger)
# Assert
assert result["body"] == {"raw": "plain content"}
warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
# Act / Assert
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
with pytest.raises(RequestEntityTooLarge):
WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
# Arrange
webhook_trigger = MagicMock()
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b""):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = MagicMock()
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
# Assert
assert body == {"raw": None}
assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
# Act
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
body, files = WebhookService._extract_text_body()
# Assert
assert body == {"raw": ""}
assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
monkeypatch.setattr(service_module, "magic", fake_magic)
# Act
result = WebhookService._detect_binary_mimetype(b"binary")
# Assert
assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
file_obj = MagicMock()
file_obj.to_dict.return_value = {"id": "f-1"}
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
uploaded = MagicMock()
uploaded.filename = "file.unknown"
uploaded.content_type = None
uploaded.read.return_value = b"content"
# Act
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
# Assert
assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
manager = MagicMock()
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
expected_file = MagicMock()
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
# Act
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
# Assert
assert result is expected_file
manager.create_file_by_raw.assert_called_once()
@pytest.mark.parametrize(
("raw_value", "param_type", "expected"),
[
("42", SegmentType.NUMBER, 42),
("3.14", SegmentType.NUMBER, 3.14),
("yes", SegmentType.BOOLEAN, True),
("no", SegmentType.BOOLEAN, False),
],
)
def test_convert_form_value_should_convert_supported_types(
raw_value: str,
param_type: SegmentType,
expected: Any,
) -> None:
# Act
result = WebhookService._convert_form_value("param", raw_value, param_type)
# Assert
assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
# Act / Assert
with pytest.raises(ValueError, match="Unsupported type"):
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
# Act
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
# Assert
assert result == {"x": 1}
warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
# Act / Assert
with pytest.raises(ValueError, match="validation failed"):
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
# Arrange
raw_params = {"optional": "x"}
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required parameter missing"):
WebhookService._process_parameters(raw_params, config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters() -> None:
# Arrange
raw_params = {"known": "1", "unknown": "x"}
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
# Act
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
# Assert
assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
# Arrange
# Act / Assert
with pytest.raises(ValueError, match="Required body content missing"):
WebhookService._process_body_parameters(
raw_body={"raw": ""},
body_configs=[WebhookBodyParameter(name="raw", required=True)],
content_type=ContentType.TEXT,
)
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
# Arrange
raw_body = {"message": "hello", "extra": "x"}
body_configs = [
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
]
# Act
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
# Assert
assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
# Arrange
headers = {"x_api_key": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act
WebhookService._validate_required_headers(headers, configs)
    # Assert: no exception raised means the sanitized header name matched
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
# Arrange
headers = {"x-other": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
# Act / Assert
with pytest.raises(ValueError, match="Required header missing"):
WebhookService._validate_required_headers(headers, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
# Arrange
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
# Act
result = WebhookService._validate_http_metadata(webhook_data, node_data)
# Assert
assert result["valid"] is False
assert "Content-type mismatch" in result["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
# Arrange
headers = {"content-type": "application/json; charset=utf-8"}
# Act
result = WebhookService._extract_content_type(headers)
# Assert
assert result == "application/json"
def test_build_workflow_inputs_should_include_expected_keys() -> None:
# Arrange
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
# Act
result = WebhookService.build_workflow_inputs(webhook_data)
# Assert
assert result["webhook_data"] == webhook_data
assert result["webhook_headers"] == {"h": "v"}
assert result["webhook_query_params"] == {"q": 1}
assert result["webhook_body"] == {"b": 2}
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
# Act
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Assert
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
quota_type = SimpleNamespace(
TRIGGER=SimpleNamespace(
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
)
)
monkeypatch.setattr(service_module, "QuotaType", quota_type)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)
# Act / Assert
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
# Act / Assert
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
# Act / Assert
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
# Act / Assert
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
# Act
WebhookService.sync_webhook_relationships(app, workflow)
# Assert
assert logger_exception_mock.call_count == 1
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
# Arrange
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
# Act
body, status = WebhookService.generate_webhook_response(node_config)
# Assert
assert status == 200
assert "message" in body
def test_generate_webhook_id_should_return_24_character_identifier() -> None:
# Arrange
# Act
webhook_id = WebhookService.generate_webhook_id()
# Assert
assert isinstance(webhook_id, str)
assert len(webhook_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
# Arrange
# Act
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
# Assert
assert result == 123


@@ -0,0 +1,671 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from flask import Flask
from graphon.variables.types import SegmentType
from werkzeug.exceptions import RequestEntityTooLarge
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
def _workflow(**kwargs: Any) -> Workflow:
return cast(Workflow, SimpleNamespace(**kwargs))
def _app(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
class TestWebhookServiceLookup:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_session = MagicMock()
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"key": "value"}}
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
"webhook-1",
is_debug=True,
)
assert got_trigger is webhook_trigger
assert got_workflow is workflow
assert got_node_config == {"data": {"mode": "debug"}}
class TestWebhookServiceExtractionFallbacks:
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
self,
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
webhook_trigger = MagicMock()
with flask_app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/vnd.custom"},
data="plain content",
):
result = WebhookService.extract_webhook_data(webhook_trigger)
assert result["body"] == {"raw": "plain content"}
warning_mock.assert_called_once()
def test_extract_webhook_data_should_raise_for_request_too_large(
self,
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
with pytest.raises(RequestEntityTooLarge):
WebhookService.extract_webhook_data(MagicMock())
def test_extract_octet_stream_body_should_return_none_when_empty_payload(self, flask_app: Flask) -> None:
webhook_trigger = MagicMock()
with flask_app.test_request_context("/webhook", method="POST", data=b""):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
assert body == {"raw": None}
assert files == {}
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
self,
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = MagicMock()
monkeypatch.setattr(
WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")
)
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
assert body == {"raw": None}
assert files == {}
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
self,
flask_app: Flask,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
body, files = WebhookService._extract_text_body()
assert body == {"raw": ""}
assert files == {}
def test_detect_binary_mimetype_should_fallback_when_magic_raises(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
monkeypatch.setattr(service_module, "magic", fake_magic)
result = WebhookService._detect_binary_mimetype(b"binary")
assert result == "application/octet-stream"
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
file_obj = MagicMock()
file_obj.to_dict.return_value = {"id": "f-1"}
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
uploaded = MagicMock()
uploaded.filename = "file.unknown"
uploaded.content_type = None
uploaded.read.return_value = b"content"
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
assert result == {"f": {"id": "f-1"}}
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
manager = MagicMock()
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
expected_file = MagicMock()
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
assert result is expected_file
manager.create_file_by_raw.assert_called_once()
class TestWebhookServiceValidationAndConversion:
@pytest.mark.parametrize(
("raw_value", "param_type", "expected"),
[
("42", SegmentType.NUMBER, 42),
("3.14", SegmentType.NUMBER, 3.14),
("yes", SegmentType.BOOLEAN, True),
("no", SegmentType.BOOLEAN, False),
],
)
def test_convert_form_value_should_convert_supported_types(
self,
raw_value: str,
        param_type: SegmentType,
expected: Any,
) -> None:
result = WebhookService._convert_form_value("param", raw_value, param_type)
assert result == expected
def test_convert_form_value_should_raise_for_unsupported_type(self) -> None:
with pytest.raises(ValueError, match="Unsupported type"):
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
warning_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
assert result == {"x": 1}
warning_mock.assert_called_once()
def test_validate_and_convert_value_should_wrap_conversion_errors(self) -> None:
with pytest.raises(ValueError, match="validation failed"):
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
def test_process_parameters_should_raise_when_required_parameter_missing(self) -> None:
raw_params = {"optional": "x"}
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
with pytest.raises(ValueError, match="Required parameter missing"):
WebhookService._process_parameters(raw_params, config, is_form_data=True)
def test_process_parameters_should_include_unconfigured_parameters(self) -> None:
raw_params = {"known": "1", "unknown": "x"}
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
assert result == {"known": 1, "unknown": "x"}
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing(self) -> None:
with pytest.raises(ValueError, match="Required body content missing"):
WebhookService._process_body_parameters(
raw_body={"raw": ""},
body_configs=[WebhookBodyParameter(name="raw", required=True)],
content_type=ContentType.TEXT,
)
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data(self) -> None:
raw_body = {"message": "hello", "extra": "x"}
body_configs = [
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
]
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
assert result == {"message": "hello", "extra": "x"}
def test_validate_required_headers_should_accept_sanitized_header_names(self) -> None:
headers = {"x_api_key": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
        WebhookService._validate_required_headers(headers, configs)  # should not raise
def test_validate_required_headers_should_raise_when_required_header_missing(self) -> None:
headers = {"x-other": "123"}
configs = [WebhookParameter(name="x-api-key", required=True)]
with pytest.raises(ValueError, match="Required header missing"):
WebhookService._validate_required_headers(headers, configs)
def test_validate_http_metadata_should_return_content_type_mismatch_error(self) -> None:
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
result = WebhookService._validate_http_metadata(webhook_data, node_data)
assert result["valid"] is False
assert "Content-type mismatch" in result["error"]
def test_extract_content_type_should_fallback_to_lowercase_header_key(self) -> None:
headers = {"content-type": "application/json; charset=utf-8"}
assert WebhookService._extract_content_type(headers) == "application/json"
def test_build_workflow_inputs_should_include_expected_keys(self) -> None:
webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}
result = WebhookService.build_workflow_inputs(webhook_data)
assert result["webhook_data"] == webhook_data
assert result["webhook_headers"] == {"h": "v"}
assert result["webhook_query_params"] == {"q": 1}
assert result["webhook_body"] == {"b": 2}
class TestWebhookServiceExecutionAndSync:
def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
webhook_data = {"body": {"x": 1}}
session = MagicMock()
_patch_session(monkeypatch, session)
end_user = SimpleNamespace(id="end-user-1")
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=end_user),
)
quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
monkeypatch.setattr(service_module, "QuotaType", quota_type)
trigger_async_mock = MagicMock()
monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
trigger_async_mock.assert_called_once()
def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(return_value=SimpleNamespace(id="end-user-1")),
)
quota_type = SimpleNamespace(
TRIGGER=SimpleNamespace(
consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
)
)
monkeypatch.setattr(service_module, "QuotaType", quota_type)
mark_rate_limited_mock = MagicMock()
monkeypatch.setattr(
service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
)
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
mark_rate_limited_mock.assert_called_once_with("tenant-1")
def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
webhook_trigger = _workflow_trigger(
app_id="app-1",
node_id="node-1",
tenant_id="tenant-1",
webhook_id="webhook-1",
)
workflow = _workflow(id="wf-1")
session = MagicMock()
_patch_session(monkeypatch, session)
monkeypatch.setattr(
service_module.EndUserService,
"get_or_create_end_user_by_type",
MagicMock(side_effect=RuntimeError("boom")),
)
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
logger_exception_mock.assert_called_once()
def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(
walk_nodes=lambda _node_type: [
(f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
]
)
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])
lock = MagicMock()
lock.acquire.return_value = False
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])
class _WorkflowWebhookTrigger:
app_id = "app_id"
tenant_id = "tenant_id"
webhook_id = "webhook_id"
node_id = "node_id"
def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
self.id = None
self.app_id = app_id
self.tenant_id = tenant_id
self.node_id = node_id
self.webhook_id = webhook_id
self.created_by = created_by
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def __init__(self) -> None:
self.added: list[Any] = []
self.deleted: list[Any] = []
self.commit_count = 0
self.existing_records = [SimpleNamespace(node_id="node-stale")]
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: self.existing_records)
def add(self, obj: Any) -> None:
self.added.append(obj)
def flush(self) -> None:
for idx, obj in enumerate(self.added, start=1):
if obj.id is None:
obj.id = f"rec-{idx}"
def commit(self) -> None:
self.commit_count += 1
def delete(self, obj: Any) -> None:
self.deleted.append(obj)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.return_value = None
fake_session = _Session()
monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
redis_set_mock = MagicMock()
redis_delete_mock = MagicMock()
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
_patch_session(monkeypatch, fake_session)
WebhookService.sync_webhook_relationships(app, workflow)
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
def test_sync_webhook_relationships_should_log_when_lock_release_fails(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
workflow = _workflow(walk_nodes=lambda _node_type: [])
class _Select:
def where(self, *args: Any, **kwargs: Any) -> "_Select":
return self
class _Session:
def scalars(self, _stmt: Any) -> Any:
return SimpleNamespace(all=lambda: [])
def commit(self) -> None:
return None
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
logger_exception_mock = MagicMock()
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
_patch_session(monkeypatch, _Session())
WebhookService.sync_webhook_relationships(app, workflow)
assert logger_exception_mock.call_count == 1
class TestWebhookServiceUtilities:
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
body, status = WebhookService.generate_webhook_response(node_config)
assert status == 200
assert "message" in body
def test_generate_webhook_id_should_return_24_character_identifier(self) -> None:
webhook_id = WebhookService.generate_webhook_id()
assert isinstance(webhook_id, str)
assert len(webhook_id) == 24
def test_sanitize_key_should_return_original_value_for_non_string_input(self) -> None:
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
assert result == 123


@@ -0,0 +1,262 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from sqlalchemy import Engine
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
node_repo = MagicMock()
workflow_run_repo = MagicMock()
factory = SimpleNamespace(
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
return node_repo, workflow_run_repo, factory
def _app_model(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))
class TestWorkflowRunServiceInitialization:
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
self,
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
service = WorkflowRunService()
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_create_sessionmaker_when_engine_is_provided(
self,
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
class FakeEngine:
pass
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "Engine", FakeEngine)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
engine = cast(Engine, FakeEngine())
service = WorkflowRunService(session_factory=engine)
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
node_repo, workflow_run_repo, factory = repository_factory_mocks
session_factory = MagicMock(name="session_factory")
service = WorkflowRunService(session_factory=session_factory)
assert service._session_factory is session_factory
assert service._node_execution_service_repo is node_repo
assert service._workflow_run_repo is workflow_run_repo
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
class TestWorkflowRunServiceQueries:
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="pagination")
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
result = service.get_paginate_workflow_runs(
app_model=app_model,
args=args,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert result is expected
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
limit=7,
last_id="last-1",
status="succeeded",
)
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
run_with_message = SimpleNamespace(
id="run-1",
status="running",
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
)
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
assert result is pagination
assert len(result.data) == 2
assert result.data[0].message_id == "msg-1"
assert result.data[0].conversation_id == "conv-1"
assert result.data[0].status == "running"
assert not hasattr(result.data[1], "message_id")
assert result.data[1].id == "run-2"
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="workflow_run")
workflow_run_repo.get_workflow_run_by_id.return_value = expected
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
assert result is expected
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
run_id="run-1",
)
def test_get_workflow_runs_count_should_forward_optional_filters(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = {"total": 3, "succeeded": 2}
workflow_run_repo.get_workflow_runs_count.return_value = expected
result = service.get_workflow_runs_count(
app_model=app_model,
status="succeeded",
time_range="7d",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert result == expected
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
status="succeeded",
time_range="7d",
)
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-1")
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
assert result == []
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
class FakeEndUser:
def __init__(self, tenant_id: str) -> None:
self.tenant_id = tenant_id
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
app_model = _app_model(id="app-1")
expected = [SimpleNamespace(id="exec-1")]
node_repo.get_executions_by_workflow_run.return_value = expected
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-end-user",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-account")
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
node_repo.get_executions_by_workflow_run.return_value = expected
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-account",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id=None)
with pytest.raises(ValueError, match="tenant_id cannot be None"):
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)


@@ -176,300 +176,3 @@ class TestWorkflowRunService:
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
# === Merged from test_workflow_run_service.py ===
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService
@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
# Arrange
node_repo = MagicMock()
workflow_run_repo = MagicMock()
factory = SimpleNamespace(
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
return node_repo, workflow_run_repo, factory
def _app_model(**kwargs: Any) -> App:
return cast(App, SimpleNamespace(**kwargs))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))
# Act
service = WorkflowRunService()
# Assert
sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_create_sessionmaker_when_engine_is_provided(
monkeypatch: pytest.MonkeyPatch,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
class FakeEngine:
pass
session_factory = MagicMock(name="session_factory")
sessionmaker_mock = MagicMock(return_value=session_factory)
monkeypatch.setattr(service_module, "Engine", FakeEngine)
monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
engine = cast(Engine, FakeEngine())
# Act
service = WorkflowRunService(session_factory=engine)
# Assert
sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
assert service._session_factory is session_factory
def test___init___should_keep_provided_sessionmaker_and_create_repositories(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
node_repo, workflow_run_repo, factory = repository_factory_mocks
session_factory = MagicMock(name="session_factory")
# Act
service = WorkflowRunService(session_factory=session_factory)
# Assert
assert service._session_factory is session_factory
assert service._node_execution_service_repo is node_repo
assert service._workflow_run_repo is workflow_run_repo
factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="pagination")
workflow_run_repo.get_paginated_workflow_runs.return_value = expected
args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}
# Act
result = service.get_paginate_workflow_runs(
app_model=app_model,
args=args,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Assert
assert result is expected
workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
limit=7,
last_id="last-1",
status="succeeded",
)
def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
run_with_message = SimpleNamespace(
id="run-1",
status="running",
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
)
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
# Act
result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
# Assert
assert result is pagination
assert len(result.data) == 2
assert result.data[0].message_id == "msg-1"
assert result.data[0].conversation_id == "conv-1"
assert result.data[0].status == "running"
assert not hasattr(result.data[1], "message_id")
assert result.data[1].id == "run-2"
def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = MagicMock(name="workflow_run")
workflow_run_repo.get_workflow_run_by_id.return_value = expected
# Act
result = service.get_workflow_run(app_model=app_model, run_id="run-1")
# Assert
assert result is expected
workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
run_id="run-1",
)
def test_get_workflow_runs_count_should_forward_optional_filters(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
) -> None:
# Arrange
_, workflow_run_repo, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
expected = {"total": 3, "succeeded": 2}
workflow_run_repo.get_workflow_runs_count.return_value = expected
# Act
result = service.get_workflow_runs_count(
app_model=app_model,
status="succeeded",
time_range="7d",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Assert
assert result == expected
workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
tenant_id="tenant-1",
app_id="app-1",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
status="succeeded",
time_range="7d",
)
def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-1")
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == []
def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
class FakeEndUser:
def __init__(self, tenant_id: str) -> None:
self.tenant_id = tenant_id
monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
app_model = _app_model(id="app-1")
expected = [SimpleNamespace(id="exec-1")]
node_repo.get_executions_by_workflow_run.return_value = expected
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-end-user",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
node_repo, _, _ = repository_factory_mocks
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id="tenant-account")
expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
node_repo.get_executions_by_workflow_run.return_value = expected
# Act
result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
# Assert
assert result == expected
node_repo.get_executions_by_workflow_run.assert_called_once_with(
tenant_id="tenant-account",
app_id="app-1",
workflow_run_id="run-1",
)
def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
app_model = _app_model(id="app-1")
user = _account(current_tenant_id=None)
# Act / Assert
with pytest.raises(ValueError, match="tenant_id cannot be None"):
service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)


@@ -3,7 +3,6 @@ import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
import pytest
@@ -223,577 +222,3 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected
# === Merged from test_workflow_event_snapshot_service_additional.py ===
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"query": "hello"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=1.2,
total_tokens=5,
total_steps=2,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.outputs = {"answer": "ok"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
def __init__(self, session: Any) -> None:
self._session = session
def __call__(self) -> _SessionContext:
return _SessionContext(self._session)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def subscribe(self) -> _SubscriptionContext:
return _SubscriptionContext(self._subscription)
class _StaticSubscription:
def receive(self, timeout: int = 1) -> None:
return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
state: bytes
@property
def id(self) -> str:
return "pause-1"
@property
def workflow_execution_id(self) -> str:
return "run-1"
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return datetime(2024, 1, 1, tzinfo=UTC)
def get_state(self) -> bytes:
return self.state
def get_pause_reasons(self) -> list[Any]:
return []
def test_get_message_context_should_return_none_when_no_message() -> None:
# Arrange
session = SimpleNamespace(scalar=MagicMock(return_value=None))
session_maker = _SessionMaker(session)
# Act
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
# Assert
assert result is None
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
# Arrange
message = SimpleNamespace(
id="msg-1",
conversation_id="conv-1",
created_at=None,
answer="answer",
)
session = SimpleNamespace(scalar=MagicMock(return_value=message))
session_maker = _SessionMaker(session)
# Act
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
# Assert
assert result is not None
assert result.created_at == 0
assert result.message_id == "msg-1"
assert result.conversation_id == "conv-1"
assert result.answer == "answer"
def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
# Act
result = service_module._load_resumption_context(None)
# Assert
assert result is None
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
# Arrange
pause_entity = _PauseEntity(state=b"not-a-valid-state")
# Act
result = service_module._load_resumption_context(pause_entity)
# Assert
assert result is None
def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
# Arrange
context = _build_resumption_context_additional(task_id="task-ctx")
pause_entity = _PauseEntity(state=context.dumps().encode())
# Act
result = service_module._load_resumption_context(pause_entity)
# Assert
assert result is not None
assert result.get_generate_entity().task_id == "task-ctx"
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
# Act
result = service_module._resolve_task_id(
resumption_context=None,
buffer_state=None,
workflow_run_id="run-1",
)
# Assert
assert result == "run-1"
@pytest.mark.parametrize(
("payload", "expected"),
[
(b'{"event":"node_started"}', {"event": "node_started"}),
(b"invalid-json", None),
(b"[]", None),
],
)
def test_parse_event_message_should_parse_only_json_object(
payload: bytes,
expected: dict[str, Any] | None,
) -> None:
# Act
result = service_module._parse_event_message(payload)
# Assert
assert result == expected
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
# Arrange
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
# Act
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
# Assert
assert is_finished is True
assert paused_without_flag is False
assert paused_with_flag is True
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
def test_apply_message_context_should_update_payload_when_context_exists() -> None:
# Arrange
payload: dict[str, Any] = {"event": "workflow_started"}
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
# Act
service_module._apply_message_context(payload, context)
# Assert
assert payload["conversation_id"] == "conv-1"
assert payload["message_id"] == "msg-1"
assert payload["created_at"] == 1700000000
def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
# Arrange
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-1"}'
return None
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
event = buffer_state.queue.get(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert ready is True
assert finished is True
assert buffer_state.task_id_hint == "task-1"
assert event["event"] == "node_started"
def test_start_buffering_should_drop_old_event_when_queue_is_full(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
class QueueWithSingleFull:
def __init__(self) -> None:
self._first_put = True
self.items: list[dict[str, Any]] = [{"event": "old"}]
def put_nowait(self, item: dict[str, Any]) -> None:
if self._first_put:
self._first_put = False
raise queue.Full
self.items.append(item)
def get_nowait(self) -> dict[str, Any]:
if not self.items:
raise queue.Empty
return self.items.pop(0)
def empty(self) -> bool:
return len(self.items) == 0
fake_queue = QueueWithSingleFull()
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-2"}'
return None
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert ready is True
assert finished is True
assert fake_queue.items[-1]["task_id"] == "task-2"
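The `QueueWithSingleFull` fake above simulates a bounded buffer with a drop-oldest policy: on `queue.Full`, the oldest event is discarded and the put retried. A small sketch of that policy with the stdlib queue (helper name is illustrative):

```python
import queue
from typing import Any

def put_drop_oldest(q: "queue.Queue[Any]", item: Any) -> None:
    # Bounded-buffer strategy the queue-full test simulates: when the
    # queue is full, discard the oldest buffered event and retry once.
    try:
        q.put_nowait(item)
    except queue.Full:
        try:
            q.get_nowait()  # drop the oldest event
        except queue.Empty:
            pass
        q.put_nowait(item)

q: "queue.Queue[dict]" = queue.Queue(maxsize=1)
put_drop_oldest(q, {"event": "old"})
put_drop_oldest(q, {"event": "new"})
```

With `maxsize=1`, the second put evicts the first, so the newest event always survives.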
def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
# Arrange
class Subscription:
def receive(self, timeout: int = 1) -> bytes | None:
raise RuntimeError("subscription failure")
subscription = Subscription()
# Act
buffer_state = service_module._start_buffering(subscription)
finished = buffer_state.done_event.wait(timeout=1)
# Assert
assert finished is True
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(
service_module,
"_get_message_context",
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
)
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
monkeypatch.setattr(
service_module,
"_build_snapshot_events",
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
)
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.ADVANCED_CHAT,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events[0] == StreamEvent.PING.value
finished_event = cast(Mapping[str, Any], events[1])
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
assert buffer_state.stop_event.is_set() is True
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
assert called_kwargs["workflow_run_id"] == "run-1"
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
class AlwaysEmptyQueue:
def empty(self) -> bool:
return False
def get(self, timeout: int = 1) -> None:
raise queue.Empty
buffer_state = BufferState(
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
time_values = cycle([0.0, 6.0, 21.0, 26.0])
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
idle_timeout=20.0,
ping_interval=5.0,
)
)
# Assert
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
buffer_state.done_event.set()
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events == [StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Arrange
workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
# Act
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
# Assert
assert events[0] == StreamEvent.PING.value
assert snapshot_builder.call_args.kwargs["pause_entity"] is None


@@ -0,0 +1,505 @@
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream
def _build_workflow_run(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"query": "hello"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=1.2,
total_tokens=5,
total_steps=2,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.outputs = {"answer": "ok"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class _SessionContext:
def __init__(self, session: Any) -> None:
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _SessionMaker:
def __init__(self, session: Any) -> None:
self._session = session
def __call__(self) -> _SessionContext:
return _SessionContext(self._session)
class _SubscriptionContext:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def __enter__(self) -> Any:
return self._subscription
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
class _Topic:
def __init__(self, subscription: Any) -> None:
self._subscription = subscription
def subscribe(self) -> _SubscriptionContext:
return _SubscriptionContext(self._subscription)
class _StaticSubscription:
def receive(self, timeout: int = 1) -> None:
return None
@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
state: bytes
@property
def id(self) -> str:
return "pause-1"
@property
def workflow_execution_id(self) -> str:
return "run-1"
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return datetime(2024, 1, 1, tzinfo=UTC)
def get_state(self) -> bytes:
return self.state
def get_pause_reasons(self) -> list[Any]:
return []
class TestWorkflowEventSnapshotHelpers:
def test_get_message_context_should_return_none_when_no_message(self) -> None:
session = SimpleNamespace(scalar=MagicMock(return_value=None))
session_maker = _SessionMaker(session)
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
assert result is None
def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp(self) -> None:
message = SimpleNamespace(
id="msg-1",
conversation_id="conv-1",
created_at=None,
answer="answer",
)
session = SimpleNamespace(scalar=MagicMock(return_value=message))
session_maker = _SessionMaker(session)
result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")
assert result is not None
assert result.created_at == 0
assert result.message_id == "msg-1"
assert result.conversation_id == "conv-1"
assert result.answer == "answer"
def test_load_resumption_context_should_return_none_when_pause_entity_missing(self) -> None:
assert service_module._load_resumption_context(None) is None
def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid(self) -> None:
pause_entity = _PauseEntity(state=b"not-a-valid-state")
assert service_module._load_resumption_context(pause_entity) is None
def test_load_resumption_context_should_parse_valid_state_into_context(self) -> None:
context = _build_resumption_context(task_id="task-ctx")
pause_entity = _PauseEntity(state=context.dumps().encode())
result = service_module._load_resumption_context(pause_entity)
assert result is not None
assert result.get_generate_entity().task_id == "task-ctx"
def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing(self) -> None:
result = service_module._resolve_task_id(
resumption_context=None,
buffer_state=None,
workflow_run_id="run-1",
)
assert result == "run-1"
@pytest.mark.parametrize(
("payload", "expected"),
[
(b'{"event":"node_started"}', {"event": "node_started"}),
(b"invalid-json", None),
(b"[]", None),
],
)
def test_parse_event_message_should_parse_only_json_object(
self,
payload: bytes,
expected: dict[str, Any] | None,
) -> None:
result = service_module._parse_event_message(payload)
assert result == expected
def test_is_terminal_event_should_recognize_finished_and_optional_paused_events(self) -> None:
finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}
is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)
assert is_finished is True
assert paused_without_flag is False
assert paused_with_flag is True
assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False
def test_apply_message_context_should_update_payload_when_context_exists(self) -> None:
payload: dict[str, Any] = {"event": "workflow_started"}
context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)
service_module._apply_message_context(payload, context)
assert payload["conversation_id"] == "conv-1"
assert payload["message_id"] == "msg-1"
assert payload["created_at"] == 1700000000
def test_start_buffering_should_capture_task_id_and_enqueue_event(self) -> None:
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-1"}'
return None
subscription = Subscription()
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
event = buffer_state.queue.get(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
assert ready is True
assert finished is True
assert buffer_state.task_id_hint == "task-1"
assert event["event"] == "node_started"
def test_start_buffering_should_drop_old_event_when_queue_is_full(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
class QueueWithSingleFull:
def __init__(self) -> None:
self._first_put = True
self.items: list[dict[str, Any]] = [{"event": "old"}]
def put_nowait(self, item: dict[str, Any]) -> None:
if self._first_put:
self._first_put = False
raise queue.Full
self.items.append(item)
def get_nowait(self) -> dict[str, Any]:
if not self.items:
raise queue.Empty
return self.items.pop(0)
def empty(self) -> bool:
return len(self.items) == 0
fake_queue = QueueWithSingleFull()
monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)
class Subscription:
def __init__(self) -> None:
self._calls = 0
def receive(self, timeout: int = 1) -> bytes | None:
self._calls += 1
if self._calls == 1:
return b'{"event":"node_started","task_id":"task-2"}'
return None
subscription = Subscription()
buffer_state = service_module._start_buffering(subscription)
ready = buffer_state.task_id_ready.wait(timeout=1)
buffer_state.stop_event.set()
finished = buffer_state.done_event.wait(timeout=1)
assert ready is True
assert finished is True
assert fake_queue.items[-1]["task_id"] == "task-2"
def test_start_buffering_should_set_done_event_when_subscription_raises(self) -> None:
class Subscription:
def receive(self, timeout: int = 1) -> bytes | None:
raise RuntimeError("subscription failure")
subscription = Subscription()
buffer_state = service_module._start_buffering(subscription)
assert buffer_state.done_event.wait(timeout=1) is True
class TestBuildWorkflowEventStream:
def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(
service_module,
"_get_message_context",
MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
)
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
monkeypatch.setattr(
service_module,
"_build_snapshot_events",
MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
)
events = list(
build_workflow_event_stream(
app_mode=AppMode.ADVANCED_CHAT,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
assert events[0] == StreamEvent.PING.value
finished_event = cast(Mapping[str, Any], events[1])
assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
assert buffer_state.stop_event.is_set() is True
node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
assert called_kwargs["workflow_run_id"] == "run-1"
def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
class AlwaysEmptyQueue:
def empty(self) -> bool:
return False
def get(self, timeout: int = 1) -> None:
raise queue.Empty
buffer_state = BufferState(
queue=AlwaysEmptyQueue(), # type: ignore[arg-type]
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
time_values = cycle([0.0, 6.0, 21.0, 26.0])
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
idle_timeout=20.0,
ping_interval=5.0,
)
)
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
buffer_state.done_event.set()
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
assert events == [StreamEvent.PING.value]
assert buffer_state.stop_event.is_set() is True
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.PAUSED)
topic = _Topic(_StaticSubscription())
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
factory = SimpleNamespace(
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
)
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint="task-1",
)
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
events = list(
build_workflow_event_stream(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=MagicMock(),
)
)
assert events[0] == StreamEvent.PING.value
assert snapshot_builder.call_args.kwargs["pause_entity"] is None

api/uv.lock generated

@@ -4704,21 +4704,21 @@ wheels = [
[[package]]
name = "pillow"
version = "12.1.1"
version = "12.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/1f/42/5c74462b4fd957fcd7b13b04fb3205ff8349236ea74c7c375766d6c82288/pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4", size = 46980264, upload-time = "2026-02-11T04:23:07.146Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/d3/8df65da0d4df36b094351dce696f2989bec731d4f10e743b1c5f4da4d3bf/pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052", size = 5262803, upload-time = "2026-02-11T04:20:47.653Z" },
{ url = "https://files.pythonhosted.org/packages/d6/71/5026395b290ff404b836e636f51d7297e6c83beceaa87c592718747e670f/pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984", size = 4657601, upload-time = "2026-02-11T04:20:49.328Z" },
{ url = "https://files.pythonhosted.org/packages/b1/2e/1001613d941c67442f745aff0f7cc66dd8df9a9c084eb497e6a543ee6f7e/pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79", size = 6234995, upload-time = "2026-02-11T04:20:51.032Z" },
{ url = "https://files.pythonhosted.org/packages/07/26/246ab11455b2549b9233dbd44d358d033a2f780fa9007b61a913c5b2d24e/pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293", size = 8045012, upload-time = "2026-02-11T04:20:52.882Z" },
{ url = "https://files.pythonhosted.org/packages/b2/8b/07587069c27be7535ac1fe33874e32de118fbd34e2a73b7f83436a88368c/pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397", size = 6349638, upload-time = "2026-02-11T04:20:54.444Z" },
{ url = "https://files.pythonhosted.org/packages/ff/79/6df7b2ee763d619cda2fb4fea498e5f79d984dae304d45a8999b80d6cf5c/pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0", size = 7041540, upload-time = "2026-02-11T04:20:55.97Z" },
{ url = "https://files.pythonhosted.org/packages/2c/5e/2ba19e7e7236d7529f4d873bdaf317a318896bac289abebd4bb00ef247f0/pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3", size = 6462613, upload-time = "2026-02-11T04:20:57.542Z" },
{ url = "https://files.pythonhosted.org/packages/03/03/31216ec124bb5c3dacd74ce8efff4cc7f52643653bad4825f8f08c697743/pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35", size = 7166745, upload-time = "2026-02-11T04:20:59.196Z" },
{ url = "https://files.pythonhosted.org/packages/1f/e7/7c4552d80052337eb28653b617eafdef39adfb137c49dd7e831b8dc13bc5/pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a", size = 6328823, upload-time = "2026-02-11T04:21:01.385Z" },
{ url = "https://files.pythonhosted.org/packages/3d/17/688626d192d7261bbbf98846fc98995726bddc2c945344b65bec3a29d731/pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6", size = 7033367, upload-time = "2026-02-11T04:21:03.536Z" },
{ url = "https://files.pythonhosted.org/packages/ed/fe/a0ef1f73f939b0eca03ee2c108d0043a87468664770612602c63266a43c4/pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523", size = 2453811, upload-time = "2026-02-11T04:21:05.116Z" },
{ url = "https://files.pythonhosted.org/packages/58/be/7482c8a5ebebbc6470b3eb791812fff7d5e0216c2be3827b30b8bb6603ed/pillow-12.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2d192a155bbcec180f8564f693e6fd9bccff5a7af9b32e2e4bf8c9c69dbad6b5", size = 5308279, upload-time = "2026-04-01T14:43:13.246Z" },
{ url = "https://files.pythonhosted.org/packages/d8/95/0a351b9289c2b5cbde0bacd4a83ebc44023e835490a727b2a3bd60ddc0f4/pillow-12.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3f40b3c5a968281fd507d519e444c35f0ff171237f4fdde090dd60699458421", size = 4695490, upload-time = "2026-04-01T14:43:15.584Z" },
{ url = "https://files.pythonhosted.org/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987", size = 6284462, upload-time = "2026-04-01T14:43:18.268Z" },
{ url = "https://files.pythonhosted.org/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76", size = 8094744, upload-time = "2026-04-01T14:43:20.716Z" },
{ url = "https://files.pythonhosted.org/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005", size = 6398371, upload-time = "2026-04-01T14:43:23.443Z" },
{ url = "https://files.pythonhosted.org/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780", size = 7087215, upload-time = "2026-04-01T14:43:26.758Z" },
{ url = "https://files.pythonhosted.org/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5", size = 6509783, upload-time = "2026-04-01T14:43:29.56Z" },
{ url = "https://files.pythonhosted.org/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5", size = 7212112, upload-time = "2026-04-01T14:43:32.091Z" },
{ url = "https://files.pythonhosted.org/packages/be/42/025cfe05d1be22dbfdb4f264fe9de1ccda83f66e4fc3aac94748e784af04/pillow-12.2.0-cp312-cp312-win32.whl", hash = "sha256:58f62cc0f00fd29e64b29f4fd923ffdb3859c9f9e6105bfc37ba1d08994e8940", size = 6378489, upload-time = "2026-04-01T14:43:34.601Z" },
{ url = "https://files.pythonhosted.org/packages/5d/7b/25a221d2c761c6a8ae21bfa3874988ff2583e19cf8a27bf2fee358df7942/pillow-12.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f84204dee22a783350679a0333981df803dac21a0190d706a50475e361c93f5", size = 7084129, upload-time = "2026-04-01T14:43:37.213Z" },
{ url = "https://files.pythonhosted.org/packages/10/e1/542a474affab20fd4a0f1836cb234e8493519da6b76899e30bcc5d990b8b/pillow-12.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:af73337013e0b3b46f175e79492d96845b16126ddf79c438d7ea7ff27783a414", size = 2463612, upload-time = "2026-04-01T14:43:39.421Z" },
]
[[package]]


@@ -165,3 +165,132 @@ Open the HTML report locally with:
```bash
open cucumber-report/report.html
```
## Writing new scenarios
### Workflow
1. Create a `.feature` file under `features/<capability>/`
2. Add step definitions under `features/step-definitions/<capability>/`
3. Reuse existing steps from `common/` and other definition files before writing new ones
4. Run with `pnpm -C e2e e2e -- --tags @your-tag` to verify
5. Run `pnpm -C e2e check` before committing
### Feature file conventions
Tag every feature or scenario with a capability tag. Add auth tags only when they clarify intent or change the browser session behavior:
```gherkin
@datasets @authenticated
Feature: Create dataset
Scenario: Create a new empty dataset
Given I am signed in as the default E2E admin
When I open the datasets page
...
```
- Capability tags (`@apps`, `@auth`, `@datasets`, …) group related scenarios for selective runs
- Auth/session tags:
- default behavior — scenarios run with the shared authenticated storageState unless marked otherwise
- `@unauthenticated` — uses a clean BrowserContext with no cookies or storage
- `@authenticated` — optional intent tag for readability or selective runs; it does not currently change hook behavior on its own
- `@fresh` — only runs in `e2e:full` mode (requires uninitialized instance)
- `@skip` — excluded from all runs
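For example, a scenario that must start logged out can combine a capability tag with `@unauthenticated` (the feature and steps below are hypothetical, shown only to illustrate tag placement):

```gherkin
@auth @unauthenticated
Feature: Sign in

  Scenario: Visiting the apps page while signed out redirects to sign-in
    When I open the apps page
    Then I am redirected to the sign-in page
```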
Keep scenarios short and declarative. Each step should describe **what** the user does, not **how** the UI works.
### Step definition conventions
```typescript
import { Then, When } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../../support/world'

When('I open the datasets page', async function (this: DifyWorld) {
  await this.getPage().goto('/datasets')
})

Then('I am on the datasets page', async function (this: DifyWorld) {
  await expect(this.getPage()).toHaveURL(/\/datasets/)
})
```
Rules:
- Always type `this` as `DifyWorld` for proper context access
- Use `async function` (not arrow functions — Cucumber binds `this`)
- One step = one user-visible action or one assertion
- Keep steps stateless across scenarios; use `DifyWorld` properties for in-scenario state
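The `async function` rule exists because Cucumber invokes each step body with the World instance as `this`; an arrow function would capture the enclosing `this` instead and never see the World. A minimal sketch in plain TypeScript (a stand-in class, not the real `DifyWorld` or Cucumber runtime) illustrates the binding:

```typescript
// Stand-in for DifyWorld: holds in-scenario state.
class World {
  createdAppName = ''
}

// A regular function receives whatever `this` the caller binds,
// which is how Cucumber passes the World to each step.
const regularStep = function (this: World, name: string) {
  this.createdAppName = name
}

const world = new World()
regularStep.call(world, 'My App')
console.log(world.createdAppName) // 'My App'

// An arrow function would ignore the `this` passed via .call(),
// so it could never reach the World; that is why steps must use `function`.
```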
### Locator priority
Follow the Playwright recommended locator strategy, in order of preference:
| Priority | Locator | Example | When to use |
| -------- | ------------------ | ----------------------------------------- | ----------------------------------------- |
| 1 | `getByRole` | `getByRole('button', { name: 'Create' })` | Default choice — accessible and resilient |
| 2 | `getByLabel` | `getByLabel('App name')` | Form inputs with visible labels |
| 3 | `getByPlaceholder` | `getByPlaceholder('Enter name')` | Inputs without visible labels |
| 4 | `getByText` | `getByText('Welcome')` | Static text content |
| 5 | `getByTestId` | `getByTestId('workflow-canvas')` | Only when no semantic locator works |
Avoid raw CSS/XPath selectors. They break when the DOM structure changes.
### Assertions
Use `@playwright/test` `expect` — it auto-waits and retries until the condition is met or the timeout expires:
```typescript
// URL assertion
await expect(page).toHaveURL(/\/datasets\/[a-f0-9-]+\/documents/)
// Element visibility
await expect(page.getByRole('button', { name: 'Save' })).toBeVisible()
// Element state
await expect(page.getByRole('button', { name: 'Submit' })).toBeEnabled()
// Negation
await expect(page.getByText('Loading')).not.toBeVisible()
```
Do not use manual `waitForTimeout` or polling loops. If you need a longer wait for a specific assertion, pass `{ timeout: 30_000 }` to the assertion.
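As a rough mental model (not Playwright's actual implementation), an auto-retrying assertion is a poll-until-timeout loop: the condition is re-checked until it passes or the deadline expires, which is why sprinkling manual waits around assertions adds nothing.

```typescript
// Sketch of auto-retrying assertion semantics. Playwright's real `expect`
// is more sophisticated; this only illustrates why manual waitForTimeout
// calls before an assertion are unnecessary.
async function expectEventually(
  check: () => boolean,
  timeoutMs = 5_000,
  intervalMs = 100,
): Promise<void> {
  const deadline = Date.now() + timeoutMs
  while (true) {
    if (check()) return
    if (Date.now() > deadline) throw new Error('Condition not met before timeout')
    await new Promise(resolve => setTimeout(resolve, intervalMs))
  }
}

// Usage sketch: a condition that flips to true after a short delay.
let visible = false
setTimeout(() => { visible = true }, 50)
void expectEventually(() => visible, 1_000)
```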
### Cucumber expressions
Use Cucumber expression parameter types to extract values from Gherkin steps:
| Type | Pattern | Example step |
| ---------- | ------------- | ---------------------------------- |
| `{string}` | Quoted string | `I select the "Workflow" app type` |
| `{int}` | Integer | `I should see {int} items` |
| `{float}` | Decimal | `the progress is {float} percent` |
| `{word}` | Single word | `I click the {word} tab` |
Prefer `{string}` for UI labels, names, and text content — it maps naturally to Gherkin's quoted values.
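Under the hood, a Cucumber expression compiles to a regular expression whose capture groups become the step arguments. A simplified sketch (not Cucumber's real compiler, which also handles escaping, optionals, and custom parameter types) shows the idea for `{string}` and `{int}`:

```typescript
// Simplified illustration of how Cucumber expression parameter types
// capture step arguments as regex groups.
function compile(expression: string): RegExp {
  const source = expression
    .replace(/\{string\}/g, '"([^"]*)"') // {string} matches a quoted value
    .replace(/\{int\}/g, '(-?\\d+)')     // {int} matches an integer
  return new RegExp(`^${source}$`)
}

const match = 'I select the "Workflow" app type'.match(
  compile('I select the {string} app type'),
)
console.log(match?.[1]) // 'Workflow'
```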
### Scoping locators
When the page has multiple similar elements, scope locators to a container:
```typescript
When('I fill in the app name in the dialog', async function (this: DifyWorld) {
const dialog = this.getPage().getByRole('dialog')
await dialog.getByPlaceholder('Give your app a name').fill('My App')
})
```
### Failure diagnostics
On failure, the `After` hook automatically captures:
- Full-page screenshot (PNG)
- Page HTML dump
- Console errors and page errors
Artifacts are saved to `cucumber-report/artifacts/` and attached to the HTML report. No extra code needed in step definitions.
## Reusing existing steps
Before writing a new step definition, inspect the existing step definition files first. Reuse a matching step when the wording and behavior already fit, and only add a new step when the scenario needs a genuinely new user action or assertion. Steps in `common/` are designed for broad reuse across all features.
The step definitions are organized as follows:
- `features/step-definitions/common/` — auth guards and navigation assertions shared by all features
- `features/step-definitions/<capability>/` — domain-specific steps scoped to a single feature area
