mirror of https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-30 17:50:29 +08:00
Merge branch 'main' into jzh
This commit is contained in:
79  .agents/skills/e2e-cucumber-playwright/SKILL.md  Normal file
@@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---

# Dify E2E Cucumber + Playwright

Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.

## Scope

- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.

## Read Order

1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
   - target `.feature` files under `e2e/features/`
   - related step files under `e2e/features/step-definitions/`
   - `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
   - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.

## Local Rules

- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
  - default: authenticated session with shared storage state
  - `@unauthenticated`: clean browser context
  - `@authenticated`: a readability and selective-run tag only; it changes no behavior unless the hook implementation changes
  - `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
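
As a sketch of how that tag model reads in a feature file (the feature name and step wording below are hypothetical, not steps guaranteed to exist in the suite):

```gherkin
Feature: Console sign-in

  # Default: runs with the shared authenticated storage state.
  Scenario: Workspace header is visible after sign-in
    Then I should see the workspace header

  # @unauthenticated tells hooks.ts to use a clean browser context.
  @unauthenticated
  Scenario: Visiting the console redirects to sign-in
    When I open the console home page
    Then I should be redirected to the sign-in page
```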

## Workflow

1. Rebuild local context.
   - Inspect the target feature area.
   - Reuse an existing step when wording and behavior already match.
   - Add a new step only for a genuinely new user action or assertion.
   - Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
   - Describe user-observable behavior, not DOM mechanics.
   - Keep each scenario focused on one workflow or outcome.
   - Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
   - Keep one step to one user-visible action or one assertion.
   - Prefer Cucumber Expressions such as `{string}` and `{int}`.
   - Scope locators to stable containers when the page has repeated elements.
   - Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
   - Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
   - Use web-first `expect(...)` assertions.
   - Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
   - Run the narrowest tagged scenario or flow that exercises the change.
   - Run `pnpm -C e2e check`.
   - Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.

## Review Checklist

- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?

Lead findings with correctness, flake risk, and architecture drift.

## References

- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

@@ -0,0 +1,4 @@
interface:
  display_name: "E2E Cucumber + Playwright"
  short_description: "Write and review Dify E2E scenarios."
  default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."
@@ -0,0 +1,93 @@
|
||||
# Cucumber Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://cucumber.io/docs/guides/10-minute-tutorial/
|
||||
- https://cucumber.io/docs/cucumber/step-definitions/
|
||||
- https://cucumber.io/docs/cucumber/cucumber-expressions/
|
||||

## What Matters Most

### 1. Treat scenarios as executable specifications

Cucumber scenarios should describe examples of behavior, not test implementation recipes.

Apply it like this:

- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
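
As a sketch of the difference, compare a behavior-first scenario with an implementation-flavored one; the app name and step wording are hypothetical:

```gherkin
# Behavior-first: reads like living documentation.
Scenario: Create a chat app from the app list
  When I create a chat app named "Support Bot"
  Then I should see "Support Bot" in the app list

# Implementation-flavored: leaks selectors and DOM mechanics.
Scenario: Create app via create button
  When I click the "div.app-card > button.create" element
  Then the element "#app-list li:first-child" contains "Support Bot"
```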

### 2. Keep scenarios focused

A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.

In Dify's suite, this means:

- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects

### 3. Reuse steps, but only when behavior really matches

Good reuse reduces duplication. Bad reuse hides meaning.

Prefer reuse when:

- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features

Write a new step when:

- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper

### 4. Prefer Cucumber Expressions

Use Cucumber Expressions for parameters unless regex is clearly necessary.

Common examples:

- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token

Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
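
A minimal sketch of these parameter types in step definitions, assuming the suite's usual `@cucumber/cucumber` imports and that `DifyWorld` (exported from `features/support/world.ts`) exposes the Playwright `page`; the step wording and locators are illustrative only:

```typescript
import { Then, When } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../support/world'

// {string} captures quoted visible text; {int} captures a bare number.
When('I rename the app to {string}', async function (this: DifyWorld, name: string) {
  await this.page.getByLabel('App name').fill(name)
})

Then('I should see {int} apps in the list', async function (this: DifyWorld, count: number) {
  await expect(this.page.getByRole('listitem')).toHaveCount(count)
})
```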

### 5. Keep step definitions thin and meaningful

Step definitions are glue between Gherkin and automation, not a second abstraction language.

For Dify:

- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
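
The `async function` rule matters because an arrow function would not let Cucumber bind `this` to the scenario's `DifyWorld`. A sketch, with a hypothetical step and on the assumption that `DifyWorld` exposes the Playwright `page`:

```typescript
import { When } from '@cucumber/cucumber'
import type { DifyWorld } from '../support/world'

// Correct: `async function` lets Cucumber bind `this` to the scenario's world.
When('I open the datasets page', async function (this: DifyWorld) {
  await this.page.getByRole('link', { name: 'Datasets' }).click()
})

// Wrong: an arrow function captures the enclosing `this`,
// so `this.page` would not be the scenario's DifyWorld page.
// When('I open the datasets page', async () => { await this.page ... })
```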

### 6. Use tags intentionally

Tags should communicate run scope or session semantics, not become ad hoc metadata.

In Dify's current suite:

- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only

If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.

## Review Questions

- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

@@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E

Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.

Official sources:

- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts

## What Matters Most

### 1. Keep scenarios isolated

Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.

Apply it like this:

- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes

### 2. Prefer user-facing locators

Playwright recommends built-in locators that reflect what users perceive on the page.

Preferred order in this repository:

1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option

Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.

Also remember:

- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
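
A sketch of the preference order and container scoping; the role names and test id are hypothetical and not guaranteed to exist in Dify's UI:

```typescript
import type { Page } from '@playwright/test'

async function openAppSettings(page: Page, appName: string) {
  // Scope to a stable container first, so repeated cards don't collide.
  const card = page.getByRole('listitem').filter({ hasText: appName })

  // Prefer role + accessible name over CSS selectors.
  await card.getByRole('button', { name: 'Settings' }).click()
}

// Fallback when semantics are weak but the contract is intentional:
// page.getByTestId('app-card-settings')
```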

### 3. Use web-first assertions

Playwright assertions auto-wait and retry. Prefer them over manual state inspection.

Prefer:

- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`

Avoid:

- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization

If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
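
The difference can be sketched as follows; `expect.poll` is shown as one of Playwright's deliberate retry tools, and the text and URL are illustrative:

```typescript
import { expect, type Page } from '@playwright/test'

async function assertSaved(page: Page) {
  // Web-first: retries until the toast appears or the timeout elapses.
  await expect(page.getByText('Saved successfully')).toBeVisible()

  // Avoid: checks once with no retry, so it flakes on slow renders.
  // expect(await page.getByText('Saved successfully').isVisible()).toBe(true)

  // When a non-locator condition needs retrying, prefer expect.poll
  // over a hand-rolled loop or waitForTimeout.
  await expect.poll(() => page.url()).toContain('/apps')
}
```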

### 4. Let actions wait for actionability

Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.

Good pattern:

- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs

Bad pattern:

- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about

### 5. Match debugging to the current suite

Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:

- full-page screenshots
- page HTML
- console errors
- page errors

Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.

## Review Questions

- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

1  .claude/skills/e2e-cucumber-playwright  Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright

@@ -69,8 +69,6 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@@ -84,7 +82,6 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@@ -93,29 +90,16 @@ ignore = [
]

[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)

]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]

[lint.flake8-tidy-imports]

[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

@@ -1,5 +1,7 @@
"""Configuration for InterSystems IRIS vector database."""

from typing import Any

from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings

@@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate IRIS configuration values.

Args:

@@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:

class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")


class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")


@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
description: I18nObject
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None


ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
team_credentials: dict[str, Any] | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: dict | None = None
original_credentials: dict | None = None
masked_credentials: dict[str, Any] | None = None
original_credentials: dict[str, Any] | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")

@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):

id: str
name: str
credentials: dict
credentials: dict[str, Any]
credential_source_type: str | None = None
credential_id: str | None = None

@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension


class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)

@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.


@@ -735,7 +735,9 @@ class IndexingRunner:

@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: dict[Any, Any] | None = None,
):
"""
Update the document indexing status.
@@ -762,7 +764,7 @@ class IndexingRunner:
db.session.commit()

@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
"""
Update the document segment by document id.
"""

@@ -200,7 +200,7 @@ def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict,
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
"""
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters


def _set_response_format(model_parameters: dict, rules: list):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
"""
Set the appropriate response format parameter based on model rules.

@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}


def remove_additional_properties(schema: dict):
def remove_additional_properties(schema: dict[str, Any]):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.

@@ -77,7 +77,7 @@ class ModelInstance:

@staticmethod
def _get_load_balancing_manager(
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
) -> Optional["LBModelManager"]:
"""
Get load balancing model credentials

@@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str | None = None,
context: str | None = None,
histories: str | None = None,
) -> tuple[str, dict]:
) -> tuple[str, dict[str, Any]]:
# get prompt template
prompt_template_config = self.get_prompt_template(
app_mode=app_mode,
@@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],
@@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],

@@ -856,7 +856,7 @@ class ProviderManager:
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
) -> dict[str, Any]:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,

@@ -174,8 +174,8 @@ class RetrievalService:
cls,
dataset_id: str,
query: str,
external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict | None = None,
external_retrieval_model: dict[str, Any] | None = None,
metadata_filtering_conditions: dict[str, Any] | None = None,
):
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)

@@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):

return embedding_results  # type: ignore

def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any


class Embeddings(ABC):
@@ -20,7 +21,7 @@ class Embeddings(ABC):
raise NotImplementedError

@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError

@@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
return _case_routing


def __getattr__(name: str) -> dict:
def __getattr__(name: str) -> Any:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()

@@ -198,7 +198,7 @@ class Tool(ABC):
message=ToolInvokeMessage.TextMessage(text=text),
)

def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
"""
create a blob message

@@ -212,7 +212,7 @@ class Tool(ABC):
meta=meta,
)

def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""

@@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
text: str

class JsonMessage(BaseModel):
json_object: dict | list
json_object: dict[str, Any] | list[Any]
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")

class BlobMessage(BaseModel):
@@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: str | None = None
# MCP object and array type parameters use this field to store the schema
input_schema: dict | None = None
input_schema: dict[str, Any] | None = None

@classmethod
def get_simple_instance(
@@ -463,7 +463,7 @@ class ToolInvokeMeta(BaseModel):

time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None

@classmethod
def empty(cls) -> ToolInvokeMeta:

@@ -85,7 +85,8 @@ class ToolEngine:
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}

def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
invocation_meta_dict: dict[str, Any],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
@@ -200,7 +201,7 @@ class ToolEngine:
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
tool_parameters: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,

@@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool

@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
session.expunge(app)
return app

def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""
transform the tool parameters

@@ -323,7 +323,7 @@ class WorkflowTool(Tool):

return parameters_result, files

def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
"""
extract files from the result

@@ -355,7 +355,7 @@ class WorkflowTool(Tool):

return result, files

def _update_file_mapping(self, file_dict: dict):
def _update_file_mapping(self, file_dict: dict[str, Any]):
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
match transfer_method:
@@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol):
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: dict[str, Any] | None = None,
|
||||
) -> IndexingResultDict: ...
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: dict[str, Any] | None,
|
||||
) -> Preview: ...
|
||||
|
||||
|
||||
class SummaryIndexServiceProtocol(Protocol):
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
|
||||
) -> None: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import NodeType
|
||||
@@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData):
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
visual_config: dict | None = Field(default=None, description="Visual configuration details")
|
||||
visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details")
|
||||
timezone: str = Field(default="UTC", description="Timezone for schedule execution")
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class TriggerWebhookNode(Node[WebhookData]):
             outputs=outputs,
         )

-    def generate_file_var(self, param_name: str, file: dict):
+    def generate_file_var(self, param_name: str, file: dict[str, Any]):
         file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
         transfer_method_value = file.get("transfer_method")
         if transfer_method_value:
@@ -147,7 +147,7 @@ class TriggerWebhookNode(Node[WebhookData]):
                 outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                 continue
             elif self.node_data.content_type == ContentType.BINARY:
-                raw_data: dict = webhook_data.get("body", {}).get("raw", {})
+                raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {})
                 file_var = self.generate_file_var(param_name, raw_data)
                 if file_var:
                     outputs[param_name] = file_var

@@ -10,6 +10,7 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
+from typing import Any

 import clickzetta
 from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         """Validate the configuration values.

         This method will first try to use CLICKZETTA_VOLUME_* environment variables,

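The many `validate_config` changes in this diff all touch the same shape of validator: a `mode="before"` hook receiving the raw values mapping, checking required keys, and raising a descriptive `ValueError`. A dependency-free sketch of that contract (hypothetical keys, not taken from any one config class):

```python
from typing import Any


def validate_config(values: dict[str, Any]) -> dict[str, Any]:
    # Mirrors the repeated pattern: each required key must be present and
    # truthy in the raw mapping, otherwise raise a descriptive ValueError.
    for key in ("host", "port"):
        if not values.get(key):
            raise ValueError(f"config {key.upper()} is required")
    return values
```

In the real code this logic lives in a Pydantic `@model_validator(mode="before")` classmethod, which receives the raw input before field parsing.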
@@ -65,7 +65,7 @@ class FileMetadata:
         return data

     @classmethod
-    def from_dict(cls, data: dict) -> FileMetadata:
+    def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
         """Create instance from dictionary"""
         data = data.copy()
         data["created_at"] = datetime.fromisoformat(data["created_at"])
@@ -459,7 +459,7 @@ class FileLifecycleManager:
             newest_file=None,
         )

-    def _create_version_backup(self, filename: str, metadata: dict):
+    def _create_version_backup(self, filename: str, metadata: dict[str, Any]):
         """Create version backup"""
         try:
             # Read current file content
@@ -487,7 +487,7 @@ class FileLifecycleManager:
             logger.warning("Failed to load metadata: %s", e)
             return {}

-    def _save_metadata(self, metadata_dict: dict):
+    def _save_metadata(self, metadata_dict: dict[str, Any]):
         """Save metadata file"""
         try:
             metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)

@@ -3,7 +3,7 @@ import queue
 import threading
 import types
 from collections.abc import Generator, Iterator
-from typing import Self
+from typing import Any, Self

 from libs.broadcast_channel.channel import Subscription
 from libs.broadcast_channel.exc import SubscriptionClosedError
@@ -221,7 +221,7 @@ class RedisSubscriptionBase(Subscription):
         """Unsubscribe from the Redis topic using the appropriate command."""
         raise NotImplementedError

-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         """Get a message from Redis using the appropriate method."""
         raise NotImplementedError

@@ -1,5 +1,7 @@
 from __future__ import annotations

+from typing import Any
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis, RedisCluster
@@ -62,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
         assert self._pubsub is not None
         self._pubsub.unsubscribe(self._topic)

-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         assert self._pubsub is not None
         return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)

@@ -1,5 +1,7 @@
 from __future__ import annotations

+from typing import Any
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis, RedisCluster
@@ -60,7 +62,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
         assert self._pubsub is not None
         self._pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]

-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         assert self._pubsub is not None
         # NOTE(QuantumGhost): this is an issue in
         # upstream code. If Sharded PubSub is used with Cluster, the

@@ -1,9 +1,11 @@
+from typing import Any
+
 from werkzeug.exceptions import HTTPException


 class BaseHTTPException(HTTPException):
     error_code: str = "unknown"
-    data: dict | None = None
+    data: dict[str, Any] | None = None

     def __init__(self, description=None, response=None):
         super().__init__(description, response)

@@ -410,7 +410,7 @@ class TokenManager:
         token_type: str,
         account: "Account | None" = None,
         email: str | None = None,
-        additional_data: dict | None = None,
+        additional_data: dict[str, Any] | None = None,
     ) -> str:
         if account is None and email is None:
             raise ValueError("Account or email must be provided")

@@ -1,4 +1,5 @@
 import logging
+from typing import Any

 import sendgrid
 from python_http_client.exceptions import ForbiddenError, UnauthorizedError
@@ -12,7 +13,7 @@ class SendGridClient:
         self.sendgrid_api_key = sendgrid_api_key
         self._from = _from

-    def send(self, mail: dict):
+    def send(self, mail: dict[str, Any]):
         logger.debug("Sending email with SendGrid")
         _to = ""
         try:

@@ -2,6 +2,7 @@ import logging
 import smtplib
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
+from typing import Any

 from configs import dify_config
@@ -20,7 +21,7 @@ class SMTPClient:
         self.use_tls = use_tls
         self.opportunistic_tls = opportunistic_tls

-    def send(self, mail: dict):
+    def send(self, mail: dict[str, Any]):
         smtp: smtplib.SMTP | None = None
         local_host = dify_config.SMTP_LOCAL_HOSTNAME
         try:

@@ -103,10 +103,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
         else:
             return dialect.type_descriptor(sa.JSON())

-    def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
+    def process_bind_param(
+        self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
+    ) -> dict[str, Any] | list[Any] | None:
         return value

-    def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
+    def process_result_value(
+        self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
+    ) -> dict[str, Any] | list[Any] | None:
         return value

@@ -35,7 +35,7 @@ class AlibabaCloudMySQLVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values.get("host"):
             raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
         if not values.get("port"):

@@ -34,7 +34,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["access_key_id"]:
             raise ValueError("config ANALYTICDB_KEY_ID is required")
         if not values["access_key_secret"]:

@@ -24,7 +24,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config ANALYTICDB_HOST is required")
         if not values["port"]:

@@ -59,7 +59,7 @@ class BaiduConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["endpoint"]:
             raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
         if not values["account"]:

@@ -51,7 +51,7 @@ class ClickzettaConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         """
         Validate the configuration values.
         """

@@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values.get("connection_string"):
             raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
         if not values.get("user"):

@@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         use_cloud = values.get("use_cloud", False)
         cloud_url = values.get("cloud_url")
@@ -258,7 +258,7 @@ class ElasticSearchVector(BaseVector):
         self,
         embeddings: list[list[float]],
         metadatas: list[dict[Any, Any]] | None = None,
-        index_params: dict | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):

@@ -43,7 +43,7 @@ class HologresVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values.get("host"):
             raise ValueError("config HOLOGRES_HOST is required")
         if not values.get("database"):

@@ -44,7 +44,7 @@ class HuaweiCloudVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["hosts"]:
             raise ValueError("config HOSTS is required")
         return values
@@ -169,7 +169,7 @@ class HuaweiCloudVector(BaseVector):
         self,
         embeddings: list[list[float]],
         metadatas: list[dict[Any, Any]] | None = None,
-        index_params: dict | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):

@@ -44,7 +44,7 @@ class LindormVectorStoreConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["hosts"]:
             raise ValueError("config URL is required")
         if not values["username"]:
@@ -336,7 +336,10 @@ class LindormVectorStore(BaseVector):
         return docs

     def create_collection(
-        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
+        self,
+        embeddings: list[list[float]],
+        metadatas: list[dict[str, Any]] | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         if not embeddings:
             raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'")

@@ -43,7 +43,7 @@ class MatrixoneConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config host is required")
         if not values["port"]:

@@ -45,7 +45,7 @@ class MilvusConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         """
         Validate the configuration values.
         Raises ValueError if required fields are missing.
@@ -302,7 +302,10 @@ class MilvusVector(BaseVector):
         )

     def create_collection(
-        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
+        self,
+        embeddings: list[list[float]],
+        metadatas: list[dict[str, Any]] | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         """
         Create a new collection in Milvus with the specified schema and index parameters.

@@ -49,7 +49,7 @@ class OceanBaseVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config OCEANBASE_VECTOR_HOST is required")
         if not values["port"]:

@@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config OPENGAUSS_HOST is required")
         if not values["port"]:

@@ -49,7 +49,7 @@ class OpenSearchConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values.get("host"):
             raise ValueError("config OPENSEARCH_HOST is required")
         if not values.get("port"):
@@ -252,7 +252,10 @@ class OpenSearchVector(BaseVector):
         return docs

     def create_collection(
-        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
+        self,
+        embeddings: list[list[float]],
+        metadatas: list[dict[str, Any]] | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
         with redis_client.lock(lock_name, timeout=20):

@@ -36,7 +36,7 @@ class OracleVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["user"]:
             raise ValueError("config ORACLE_USER is required")
         if not values["password"]:

@@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config PGVECTO_RS_HOST is required")
         if not values["port"]:

@@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config PGVECTOR_HOST is required")
         if not values["port"]:

@@ -38,7 +38,7 @@ class RelytConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config RELYT_HOST is required")
         if not values["port"]:
@@ -239,7 +239,7 @@ class RelytVector(BaseVector):
         self,
         embedding: list[float],
         k: int = 4,
-        filter: dict | None = None,
+        filter: dict[str, Any] | None = None,
     ) -> list[tuple[Document, float]]:
         # Add the filter if provided

@@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["access_key_id"]:
             raise ValueError("config ACCESS_KEY_ID is required")
         if not values["access_key_secret"]:

@@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config TIDB_VECTOR_HOST is required")
         if not values["port"]:

@@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["url"]:
             raise ValueError("Upstash URL is required")
         if not values["token"]:

@@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config VASTBASE_HOST is required")
         if not values["port"]:

@@ -20,7 +20,7 @@ from pydantic import BaseModel, model_validator
 from weaviate.classes.data import DataObject
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter, MetadataQuery
-from weaviate.exceptions import UnexpectedStatusCodeError
+from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateQueryError

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -82,7 +82,7 @@ class WeaviateConfig(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict) -> dict:
+    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validates that required configuration values are present."""
         if not values["endpoint"]:
             raise ValueError("config WEAVIATE_ENDPOINT is required")

@@ -230,6 +230,8 @@ class WeaviateVector(BaseVector):
                 wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
                 wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
                 wc.Property(name="chunk_index", data_type=wc.DataType.INT),
+                wc.Property(name="is_summary", data_type=wc.DataType.BOOL),
+                wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT),
             ],
             vector_config=wc.Configure.Vectors.self_provided(),
         )
@@ -262,6 +264,10 @@ class WeaviateVector(BaseVector):
             to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
         if "chunk_index" not in existing:
             to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
+        if "is_summary" not in existing:
+            to_add.append(wc.Property(name="is_summary", data_type=wc.DataType.BOOL))
+        if "original_chunk_id" not in existing:
+            to_add.append(wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT))

         for prop in to_add:
             try:

@@ -400,15 +406,27 @@ class WeaviateVector(BaseVector):
         top_k = int(kwargs.get("top_k", 4))
         score_threshold = float(kwargs.get("score_threshold") or 0.0)

-        res = col.query.near_vector(
-            near_vector=query_vector,
-            limit=top_k,
-            return_properties=props,
-            return_metadata=MetadataQuery(distance=True),
-            include_vector=False,
-            filters=where,
-            target_vector="default",
-        )
+        try:
+            res = col.query.near_vector(
+                near_vector=query_vector,
+                limit=top_k,
+                return_properties=props,
+                return_metadata=MetadataQuery(distance=True),
+                include_vector=False,
+                filters=where,
+                target_vector="default",
+            )
+        except WeaviateQueryError:
+            self._ensure_properties()
+            res = col.query.near_vector(
+                near_vector=query_vector,
+                limit=top_k,
+                return_properties=props,
+                return_metadata=MetadataQuery(distance=True),
+                include_vector=False,
+                filters=where,
+                target_vector="default",
+            )

         docs: list[Document] = []
         for obj in res.objects:

@@ -446,14 +464,25 @@ class WeaviateVector(BaseVector):

         top_k = int(kwargs.get("top_k", 4))

-        res = col.query.bm25(
-            query=query,
-            query_properties=[Field.TEXT_KEY.value],
-            limit=top_k,
-            return_properties=props,
-            include_vector=True,
-            filters=where,
-        )
+        try:
+            res = col.query.bm25(
+                query=query,
+                query_properties=[Field.TEXT_KEY.value],
+                limit=top_k,
+                return_properties=props,
+                include_vector=True,
+                filters=where,
+            )
+        except WeaviateQueryError:
+            self._ensure_properties()
+            res = col.query.bm25(
+                query=query,
+                query_properties=[Field.TEXT_KEY.value],
+                limit=top_k,
+                return_properties=props,
+                include_vector=True,
+                filters=where,
+            )

         docs: list[Document] = []
         for obj in res.objects:

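The Weaviate hunks above both wrap a query in the same catch-repair-retry shape: on `WeaviateQueryError`, missing schema properties are added via `_ensure_properties()` and the query is reissued exactly once. A library-free sketch of that shape (hypothetical names; the real code catches the Weaviate client's exception):

```python
from typing import Any, Callable


class QueryError(Exception):
    """Stands in for WeaviateQueryError in this sketch."""


def query_with_repair(query: Callable[[], Any], repair: Callable[[], None]) -> Any:
    # Try once; on a schema-related failure, run the repair step and retry
    # exactly once. A second failure propagates to the caller unchanged.
    try:
        return query()
    except QueryError:
        repair()
        return query()
```

Retrying only once keeps the failure mode bounded: if the repair did not fix the schema, the caller still sees the original exception type rather than an infinite loop.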
@@ -326,7 +326,7 @@ class TestWeaviateVector(unittest.TestCase):

         add_calls = mock_col.config.add_property.call_args_list
         added_names = [call.args[0].name for call in add_calls]
-        assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
+        assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index", "is_summary", "original_chunk_id"]

     @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
     def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
@@ -346,6 +346,8 @@ class TestWeaviateVector(unittest.TestCase):
             SimpleNamespace(name="doc_id"),
             SimpleNamespace(name="doc_type"),
             SimpleNamespace(name="chunk_index"),
+            SimpleNamespace(name="is_summary"),
+            SimpleNamespace(name="original_chunk_id"),
         ]
         mock_cfg = MagicMock()
         mock_cfg.properties = existing_props

@@ -383,7 +385,7 @@ class TestWeaviateVector(unittest.TestCase):
         with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
             wv._ensure_properties()

-        assert mock_warning.call_count == 4
+        assert mock_warning.call_count == 6

     @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
     def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):

@@ -484,6 +486,56 @@ class TestWeaviateVector(unittest.TestCase):

         assert wv.search_by_vector(query_vector=[0.1] * 3) == []

+    @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
+    def test_search_by_vector_retries_on_weaviate_query_error(self, mock_weaviate_module):
+        """Test that search_by_vector catches WeaviateQueryError, calls _ensure_properties, and retries."""
+        from weaviate.exceptions import WeaviateQueryError
+
+        mock_client = MagicMock()
+        mock_client.is_ready.return_value = True
+        mock_weaviate_module.connect_to_custom.return_value = mock_client
+
+        mock_client.collections.exists.return_value = True
+        mock_col = MagicMock()
+        mock_client.collections.use.return_value = mock_col
+
+        # First call raises WeaviateQueryError, second call succeeds
+        mock_obj = MagicMock()
+        mock_obj.properties = {"text": "retry result", "document_id": "doc-1"}
+        mock_obj.metadata.distance = 0.2
+
+        mock_result = MagicMock()
+        mock_result.objects = [mock_obj]
+
+        mock_col.query.near_vector.side_effect = [
+            WeaviateQueryError("missing property", "gRPC"),
+            mock_result,
+        ]
+
+        # Mock _ensure_properties dependencies
+        mock_cfg = MagicMock()
+        mock_cfg.properties = [
+            SimpleNamespace(name="text"),
+            SimpleNamespace(name="document_id"),
+            SimpleNamespace(name="doc_id"),
+            SimpleNamespace(name="doc_type"),
+            SimpleNamespace(name="chunk_index"),
+            SimpleNamespace(name="is_summary"),
+            SimpleNamespace(name="original_chunk_id"),
+        ]
+        mock_col.config.get.return_value = mock_cfg
+
+        wv = WeaviateVector(
+            collection_name=self.collection_name,
+            config=self.config,
+            attributes=self.attributes,
+        )
+        docs = wv.search_by_vector(query_vector=[0.1] * 3, top_k=1)
+
+        assert mock_col.query.near_vector.call_count == 2
+        assert len(docs) == 1
+        assert docs[0].metadata["score"] == pytest.approx(0.8)
+
     @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
     def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
         """Test that search_by_full_text also returns doc_type in document metadata."""

@@ -569,6 +621,56 @@ class TestWeaviateVector(unittest.TestCase):

         assert wv.search_by_full_text(query="missing") == []

+    @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
+    def test_search_by_full_text_retries_on_weaviate_query_error(self, mock_weaviate_module):
+        """Test that search_by_full_text catches WeaviateQueryError, calls _ensure_properties, and retries."""
+        from weaviate.exceptions import WeaviateQueryError
+
+        mock_client = MagicMock()
+        mock_client.is_ready.return_value = True
+        mock_weaviate_module.connect_to_custom.return_value = mock_client
+
+        mock_client.collections.exists.return_value = True
+        mock_col = MagicMock()
+        mock_client.collections.use.return_value = mock_col
+
+        # First call raises WeaviateQueryError, second call succeeds
+        mock_obj = MagicMock()
+        mock_obj.properties = {"text": "retry bm25 result", "doc_id": "segment-1"}
+        mock_obj.vector = {"default": [0.5, 0.6]}
+
+        mock_result = MagicMock()
+        mock_result.objects = [mock_obj]
+
+        mock_col.query.bm25.side_effect = [
+            WeaviateQueryError("missing property", "gRPC"),
+            mock_result,
+        ]
+
+        # Mock _ensure_properties dependencies
+        mock_cfg = MagicMock()
+        mock_cfg.properties = [
+            SimpleNamespace(name="text"),
+            SimpleNamespace(name="document_id"),
+            SimpleNamespace(name="doc_id"),
+            SimpleNamespace(name="doc_type"),
+            SimpleNamespace(name="chunk_index"),
+            SimpleNamespace(name="is_summary"),
+            SimpleNamespace(name="original_chunk_id"),
+        ]
+        mock_col.config.get.return_value = mock_cfg
+
+        wv = WeaviateVector(
+            collection_name=self.collection_name,
+            config=self.config,
+            attributes=self.attributes,
+        )
+        docs = wv.search_by_full_text(query="retry", top_k=1)
+
+        assert mock_col.query.bm25.call_count == 2
+        assert len(docs) == 1
+        assert docs[0].page_content == "retry bm25 result"
+
     @patch("dify_vdb_weaviate.weaviate_vector.weaviate")
     def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
         """Test that add_texts includes doc_type from document metadata in stored properties."""

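The new tests above drive the retry path by giving the mocked query method a `side_effect` list: an exception instance first, then a normal result. A minimal stdlib-only illustration of that `unittest.mock` behavior:

```python
from unittest.mock import MagicMock

mock_query = MagicMock()
# An iterable side_effect is consumed call by call; exception instances in
# the list are raised, other values are returned.
mock_query.side_effect = [RuntimeError("first call fails"), "second call result"]

try:
    mock_query()
except RuntimeError:
    recovered = mock_query()  # the next item in side_effect is returned

# Two calls total: one raised, one returned a value.
assert mock_query.call_count == 2
assert recovered == "second call result"
```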
@@ -3,7 +3,7 @@ import hashlib
 import logging
 import uuid
 from collections.abc import Mapping
-from typing import cast
+from typing import Any, cast
 from urllib.parse import urlparse
 from uuid import uuid4

@@ -400,7 +400,7 @@ class AppDslService:
         self,
         *,
         app: App | None,
-        data: dict,
+        data: dict[str, Any],
         account: Account,
         name: str | None = None,
         description: str | None = None,
@@ -567,7 +567,7 @@ class AppDslService:

     @classmethod
     def _append_workflow_export_data(
-        cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
+        cls, *, export_data: dict[str, Any], app_model: App, include_secret: bool, workflow_id: str | None = None
     ):
         """
         Append workflow export data
@@ -620,7 +620,7 @@ class AppDslService:
     ]

     @classmethod
-    def _append_model_config_export_data(cls, export_data: dict, app_model: App):
+    def _append_model_config_export_data(cls, export_data: dict[str, Any], app_model: App):
         """
         Append model config export data
         :param export_data: export data

@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.chat.app_config_manager import ChatAppConfigManager
 from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
@@ -6,7 +8,7 @@ from models.model import AppMode, AppModelConfigDict

 class AppModelConfigService:
     @classmethod
-    def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
+    def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict:
         match app_mode:
             case AppMode.CHAT:
                 return ChatAppConfigManager.config_validate(tenant_id, config)

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)


 class AppService:
-    def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
+    def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None:
         """
         Get app list with pagination
         :param user_id: user id
@@ -78,7 +78,7 @@ class AppService:

         return app_models

-    def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
+    def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App:
         """
         Create app
         :param tenant_id: tenant id
@@ -389,7 +389,7 @@ class AppService:
         """
         app_mode = AppMode.value_of(app_model.mode)

-        meta: dict = {"tool_icons": {}}
+        meta: dict[str, Any] = {"tool_icons": {}}

         if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
             workflow = app_model.workflow

@@ -1,4 +1,5 @@
 import json
+from typing import Any

 from sqlalchemy import select

@@ -19,7 +20,7 @@ class ApiKeyAuthService:
         return data_source_api_key_bindings

     @staticmethod
-    def create_provider_auth(tenant_id: str, args: dict):
+    def create_provider_auth(tenant_id: str, args: dict[str, Any]):
         auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
         if auth_result:
             # Encrypt the api key

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from collections.abc import Sequence
-from typing import Literal, NotRequired, TypedDict
+from typing import Any, Literal, NotRequired, TypedDict

 import httpx
 from pydantic import TypeAdapter
@@ -541,7 +541,7 @@ class BillingService:
         start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
         Returns {"notification_id": str}.
         """
-        payload: dict = {
+        payload: dict[str, Any] = {
             "contents": contents,
             "frequency": frequency,
             "status": status,

@@ -318,7 +318,7 @@ class DatasourceProviderService:
         self,
         tenant_id: str,
         datasource_provider_id: DatasourceProviderID,
-        client_params: dict | None,
+        client_params: dict[str, Any] | None,
         enabled: bool | None,
     ):
         """
@@ -352,7 +352,7 @@ class DatasourceProviderService:
         original_params = (
             encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
         )
-        new_params: dict = {
+        new_params: dict[str, Any] = {
             key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
             for key, value in client_params.items()
         }
@@ -500,7 +500,7 @@ class DatasourceProviderService:
         provider_id: DatasourceProviderID,
         avatar_url: str | None,
         expire_at: int,
-        credentials: dict,
+        credentials: dict[str, Any],
         credential_id: str,
     ) -> None:
         """
@@ -566,7 +566,7 @@ class DatasourceProviderService:
         provider_id: DatasourceProviderID,
         avatar_url: str | None,
         expire_at: int,
-        credentials: dict,
+        credentials: dict[str, Any],
     ) -> None:
         """
         add datasource oauth provider
@@ -634,7 +634,7 @@ class DatasourceProviderService:
         name: str | None,
         tenant_id: str,
         provider_id: DatasourceProviderID,
-        credentials: dict,
+        credentials: dict[str, Any],
     ) -> None:
         """
         validate datasource provider credentials.
@@ -947,7 +947,13 @@ class DatasourceProviderService:
         return copy_credentials_list

     def update_datasource_credentials(
-        self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
+        self,
+        tenant_id: str,
+        auth_id: str,
+        provider: str,
+        plugin_id: str,
+        credentials: dict[str, Any] | None,
+        name: str | None,
     ) -> None:
         """
         update datasource credentials.

@@ -1,4 +1,4 @@
-from typing import Literal, Union
+from typing import Any, Literal, Union

 from pydantic import BaseModel

@@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
 class ExternalKnowledgeApiSetting(BaseModel):
     url: str
     request_method: str
-    headers: dict | None = None
-    params: dict | None = None
+    headers: dict[str, Any] | None = None
+    params: dict[str, Any] | None = None

@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Any, Literal

 from pydantic import BaseModel, field_validator

@@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
     data_source: DataSource | None = None
     process_rule: ProcessRule | None = None
     retrieval_model: RetrievalModel | None = None
-    summary_index_setting: dict | None = None
+    summary_index_setting: dict[str, Any] | None = None
     doc_form: str = "text_model"
     doc_language: str = "English"
     embedding_model: str | None = None

@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Any, Literal

 from pydantic import BaseModel, field_validator

@@ -73,7 +73,7 @@ class KnowledgeConfiguration(BaseModel):
     keyword_number: int | None = 10
     retrieval_model: RetrievalSetting
     # add summary index setting
-    summary_index_setting: dict | None = None
+    summary_index_setting: dict[str, Any] | None = None

     @field_validator("embedding_model_provider", mode="before")
     @classmethod

@@ -45,7 +45,7 @@ class HitTestingService:
         query: str,
         account: Account,
         retrieval_model: dict[str, Any] | None,
-        external_retrieval_model: dict,
+        external_retrieval_model: dict[str, Any],
         attachment_ids: list | None = None,
         limit: int = 10,
     ):
@@ -125,8 +125,8 @@ class HitTestingService:
         dataset: Dataset,
         query: str,
         account: Account,
-        external_retrieval_model: dict | None = None,
-        metadata_filtering_conditions: dict | None = None,
+        external_retrieval_model: dict[str, Any] | None = None,
+        metadata_filtering_conditions: dict[str, Any] | None = None,
     ):
         if dataset.provider != "external":
             return {

@@ -502,7 +502,7 @@ class ModelLoadBalancingService:
         provider: str,
         model: str,
         model_type: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         config_id: str | None = None,
     ):
         """
@@ -561,7 +561,7 @@ class ModelLoadBalancingService:
         provider_configuration: ProviderConfiguration,
         model_type: ModelType,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         load_balancing_model_config: LoadBalancingModelConfig | None = None,
         model_provider_factory: ModelProviderFactory | None = None,
         validate: bool = True,

@@ -1,5 +1,6 @@
 import json
 import uuid
+from typing import Any

 from core.plugin.impl.base import BasePluginClient
 from extensions.ext_redis import redis_client
@@ -16,7 +17,7 @@ class OAuthProxyService(BasePluginClient):
         tenant_id: str,
         plugin_id: str,
         provider: str,
-        extra_data: dict = {},
+        extra_data: dict[str, Any] = {},
         credential_id: str | None = None,
     ):
         """

@@ -5,7 +5,7 @@ import logging
 import uuid
 from collections.abc import Mapping
 from datetime import UTC, datetime
-from typing import cast
+from typing import Any, cast
 from urllib.parse import urlparse
 from uuid import uuid4

@@ -526,7 +526,7 @@ class RagPipelineDslService:
         self,
         *,
         pipeline: Pipeline | None,
-        data: dict,
+        data: dict[str, Any],
         account: Account,
         dependencies: list[PluginDependency] | None = None,
     ) -> Pipeline:
@@ -660,7 +660,9 @@ class RagPipelineDslService:

         return yaml.dump(export_data, allow_unicode=True)  # type: ignore

-    def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
+    def _append_workflow_export_data(
+        self, *, export_data: dict[str, Any], pipeline: Pipeline, include_secret: bool
+    ) -> None:
         """
         Append workflow export data
         :param export_data: export data

@@ -2,6 +2,7 @@ import json
 import logging
 from datetime import UTC, datetime
 from pathlib import Path
+from typing import Any
 from uuid import uuid4

 import yaml
@@ -154,7 +155,7 @@ class RagPipelineTransformService:
             raise ValueError("Unsupported doc form")
         return pipeline_yaml

-    def _deal_file_extensions(self, node: dict):
+    def _deal_file_extensions(self, node: dict[str, Any]):
         file_extensions = node.get("data", {}).get("fileExtensions", [])
         if not file_extensions:
             return node
@@ -167,7 +168,7 @@ class RagPipelineTransformService:
         dataset: Dataset,
         indexing_technique: str | None,
         retrieval_model: RetrievalSetting | None,
-        node: dict,
+        node: dict[str, Any],
     ):
         knowledge_configuration_dict = node.get("data", {})

@@ -191,7 +192,7 @@ class RagPipelineTransformService:

     def _create_pipeline(
         self,
-        data: dict,
+        data: dict[str, Any],
     ) -> Pipeline:
         """Create a new app or update an existing one."""
         pipeline_data = data.get("rag_pipeline", {})
@@ -258,7 +259,7 @@ class RagPipelineTransformService:
         db.session.add(pipeline)
         return pipeline

-    def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
+    def _deal_dependencies(self, pipeline_yaml: dict[str, Any], tenant_id: str):
         installer_manager = PluginInstaller()
         installed_plugins = installer_manager.list_plugins(tenant_id)

@@ -1,6 +1,7 @@
 import json
 from os import path
 from pathlib import Path
+from typing import Any

 from flask import current_app

@@ -13,7 +14,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
     Retrieval recommended app from buildin, the location is constants/recommended_apps.json
     """

-    builtin_data: dict | None = None
+    builtin_data: dict[str, Any] | None = None

     def get_type(self) -> str:
         return RecommendAppType.BUILDIN
@@ -53,7 +54,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
         return builtin_data.get("recommended_apps", {}).get(language, {})

     @classmethod
-    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
+    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict[str, Any] | None:
         """
         Fetch recommended app detail from builtin.
         :param app_id: App ID

@@ -1,4 +1,5 @@
 import logging
+from typing import Any

 import httpx

@@ -35,7 +36,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         return RecommendAppType.REMOTE

     @classmethod
-    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
+    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict[str, Any] | None:
         """
         Fetch recommended app detail from dify official.
         :param app_id: App ID
@@ -46,7 +47,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
         if response.status_code != 200:
             return None
-        data: dict = response.json()
+        data: dict[str, Any] = response.json()
         return data

     @classmethod
@@ -62,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         if response.status_code != 200:
             raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")

-        result: dict = response.json()
+        result: dict[str, Any] = response.json()

         if "categories" in result:
             result["categories"] = sorted(result["categories"])

@@ -1,3 +1,5 @@
+from typing import Any
+
 from sqlalchemy import select

 from configs import dify_config
@@ -37,7 +39,7 @@ class RecommendedAppService:
         return result

     @classmethod
-    def get_recommend_app_detail(cls, app_id: str) -> dict | None:
+    def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None:
         """
         Get recommend app detail.
         :param app_id: app id
@@ -45,7 +47,7 @@ class RecommendedAppService:
         """
         mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
         retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
-        result: dict = retrieval_instance.get_recommend_app_detail(app_id)
+        result: dict[str, Any] = retrieval_instance.get_recommend_app_detail(app_id)
         if FeatureService.get_system_features().enable_trial_app:
             app_id = result["id"]
             trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))

@@ -428,7 +428,7 @@ class ToolTransformService:

     @staticmethod
     def convert_builtin_provider_to_credential_entity(
-        provider: BuiltinToolProvider, credentials: dict
+        provider: BuiltinToolProvider, credentials: dict[str, Any]
     ) -> ToolProviderCredentialApiEntity:
         return ToolProviderCredentialApiEntity(
             id=provider.id,

@@ -3,6 +3,7 @@ import tempfile
 import time
 import uuid
 from pathlib import Path
+from typing import Any

 import click
 import pandas as pd
@@ -51,8 +52,8 @@ def batch_create_segment_to_index_task(

     # Initialize variables with default values
     upload_file_key: str | None = None
-    dataset_config: dict | None = None
-    document_config: dict | None = None
+    dataset_config: dict[str, Any] | None = None
+    document_config: dict[str, Any] | None = None

     with session_factory.create_session() as session:
         try:

@@ -679,7 +679,7 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
     )


-def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
+def _delete_records(query_sql: str, params: dict[str, Any], delete_func: Callable, name: str) -> None:
     while True:
         with session_factory.create_session() as session:
             rs = session.execute(sa.text(query_sql), params)

@@ -7,6 +7,7 @@ improving performance by offloading storage operations to background workers.

 import json
 import logging
+from typing import Any

 from celery import shared_task
 from graphon.entities import WorkflowExecution
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
 @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
 def save_workflow_execution_task(
     self,
-    execution_data: dict,
+    execution_data: dict[str, Any],
     tenant_id: str,
     app_id: str,
     triggered_from: str,

@@ -7,6 +7,7 @@ improving performance by offloading storage operations to background workers.

 import json
 import logging
+from typing import Any

 from celery import shared_task
 from graphon.entities.workflow_node_execution import (
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
 @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
 def save_workflow_node_execution_task(
     self,
-    execution_data: dict,
+    execution_data: dict[str, Any],
     tenant_id: str,
     app_id: str,
     triggered_from: str,

@@ -95,30 +95,6 @@ class TestTextToAudioPayload:
         assert payload.streaming is True


-# ---------------------------------------------------------------------------
-# AudioService Interface Tests
-# ---------------------------------------------------------------------------
-
-
-class TestAudioServiceInterface:
-    """Test AudioService method interfaces exist."""
-
-    def test_transcript_asr_method_exists(self):
-        """Test that AudioService.transcript_asr exists."""
-        assert hasattr(AudioService, "transcript_asr")
-        assert callable(AudioService.transcript_asr)
-
-    def test_transcript_tts_method_exists(self):
-        """Test that AudioService.transcript_tts exists."""
-        assert hasattr(AudioService, "transcript_tts")
-        assert callable(AudioService.transcript_tts)
-
-
 # ---------------------------------------------------------------------------
 # Audio Service Tests
 # ---------------------------------------------------------------------------


 class TestAudioServiceInterface:
     """Test suite for AudioService interface methods."""

@@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi:
             with pytest.raises(NotChatAppError):
                 MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)

-    def test_wrong_mode_raises(self, app: Flask) -> None:
-        msg_id = uuid4()
-        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
-            with pytest.raises(NotChatAppError):
-                MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
-
     @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
     def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
         msg_id = uuid4()

@@ -73,11 +73,6 @@ class TestAsyncWorkflowService:

         mock_dispatcher = MagicMock()
-        quota_workflow = MagicMock()
         mock_get_workflow = MagicMock()

-        mock_professional_task = MagicMock()
-        mock_team_task = MagicMock()
-        mock_sandbox_task = MagicMock()
-
         with (
             patch.object(

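The new `test_model_provider_service.py` below leans on one pattern throughout: replace `_get_provider_configuration` with a `MagicMock` and assert on the delegated call. A standalone sketch of that pattern, with an illustrative class rather than the real service:

```python
from unittest.mock import MagicMock


class Service:
    """Minimal stand-in for a service that delegates to a provider configuration."""

    def _get_provider_configuration(self, tenant_id: str, provider: str):
        raise NotImplementedError  # replaced by a mock in tests

    def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
        config = self._get_provider_configuration(tenant_id, provider)
        config.delete_provider_credential(credential_id=credential_id)


service = Service()
config = MagicMock()
# Shadow the lookup on the instance so the test controls what it returns.
service._get_provider_configuration = MagicMock(return_value=config)

service.remove_provider_credential("tenant-1", "openai", "cred-1")

# The mock records the delegated call so the test can assert on it.
config.delete_provider_credential.assert_called_once_with(credential_id="cred-1")
```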
api/tests/unit_tests/services/test_model_provider_service.py  (new file, 602 lines)
@@ -0,0 +1,602 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType

from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService


def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
    manager = MagicMock()
    service = ModelProviderService()
    service._get_provider_manager = MagicMock(return_value=manager)
    return service, manager


def _build_provider_configuration(
    *,
    provider_name: str = "openai",
    supported_model_types: list[ModelType] | None = None,
    custom_models: list[Any] | None = None,
    custom_config_available: bool = True,
) -> SimpleNamespace:
    if supported_model_types is None:
        supported_model_types = [ModelType.LLM]

    return SimpleNamespace(
        provider=SimpleNamespace(
            provider=provider_name,
            label=I18nObject(en_US=provider_name),
            description=None,
            icon_small=None,
            icon_small_dark=None,
            background=None,
            help=None,
            supported_model_types=supported_model_types,
            configurate_methods=[],
            provider_credential_schema=None,
            model_credential_schema=None,
        ),
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=SimpleNamespace(
            provider=SimpleNamespace(
                current_credential_id="cred-1",
                current_credential_name="Credential 1",
                available_credentials=[],
            ),
            models=custom_models,
            can_added_models=[],
        ),
        system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
        is_custom_configuration_available=lambda: custom_config_available,
    )


class TestModelProviderServiceConfiguration:
    def test__get_provider_configuration_should_return_configuration_when_provider_exists(self) -> None:
        service, manager = _create_service_with_mocked_manager()
        provider_configuration = SimpleNamespace(name="provider-config")
        manager.get_configurations.return_value = {"openai": provider_configuration}

        result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")

        assert result is provider_configuration

    def test__get_provider_configuration_should_raise_error_when_provider_is_missing(self) -> None:
        service, manager = _create_service_with_mocked_manager()
        manager.get_configurations.return_value = {}

        with pytest.raises(ProviderNotFoundError, match="does not exist"):
            service._get_provider_configuration(tenant_id="tenant-1", provider="missing")

    def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status(self) -> None:
        service, manager = _create_service_with_mocked_manager()
        allowed = _build_provider_configuration(
            provider_name="openai",
            supported_model_types=[ModelType.LLM],
            custom_config_available=False,
        )
        filtered = _build_provider_configuration(
            provider_name="embedding",
            supported_model_types=[ModelType.TEXT_EMBEDDING],
            custom_config_available=True,
        )
        manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}

        result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)

        assert len(result) == 1
        assert result[0].provider == "openai"
        assert result[0].custom_configuration.status.value == "no-configure"

    def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context(self) -> None:
        service, manager = _create_service_with_mocked_manager()

        class _Model:
            def __init__(self, model_name: str) -> None:
                self.model_name = model_name

            def model_dump(self) -> dict[str, Any]:
                return {
                    "model": self.model_name,
                    "label": {"en_US": self.model_name},
                    "model_type": ModelType.LLM,
                    "features": [],
                    "fetch_from": FetchFrom.PREDEFINED_MODEL,
                    "model_properties": {},
                    "deprecated": False,
                    "status": ModelStatus.ACTIVE,
                    "load_balancing_enabled": False,
                    "has_invalid_load_balancing_configs": False,
                    "provider": {
                        "provider": "openai",
                        "label": {"en_US": "OpenAI"},
                        "icon_small": None,
                        "icon_small_dark": None,
                        "supported_model_types": [ModelType.LLM],
                    },
                }

        provider_configurations = SimpleNamespace(
            get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
        )
        manager.get_configurations.return_value = provider_configurations

        result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")

        assert len(result) == 2
        assert result[0].model == "gpt-4o"
        assert result[1].provider.provider == "openai"
        provider_configurations.get_models.assert_called_once_with(provider="openai")


class TestModelProviderServiceDelegation:
    @pytest.mark.parametrize(
        ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
        [
            (
                "get_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "get_provider_credential",
                {"credential_id": "cred-1"},
                {"token": "abc"},
            ),
            (
                "validate_provider_credentials",
                {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
                "validate_provider_credentials",
                ({"token": "abc"},),
                None,
            ),
            (
                "create_provider_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "credentials": {"token": "abc"},
                    "credential_name": "A",
                },
                "create_provider_credential",
                ({"token": "abc"}, "A"),
                None,
            ),
            (
                "update_provider_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "credentials": {"token": "abc"},
                    "credential_id": "cred-1",
                    "credential_name": "B",
                },
                "update_provider_credential",
                {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
                None,
            ),
            (
                "remove_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "delete_provider_credential",
                {"credential_id": "cred-1"},
                None,
            ),
            (
                "switch_active_provider_credential",
                {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
                "switch_active_provider_credential",
                {"credential_id": "cred-1"},
                None,
            ),
        ],
    )
    def test_provider_credential_methods_should_delegate_to_provider_configuration(
        self,
        method_name: str,
        method_kwargs: dict[str, Any],
        provider_method_name: str,
        provider_call_kwargs: Any,
        provider_return: Any,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = ModelProviderService()
        provider_configuration = MagicMock()
        getattr(provider_configuration, provider_method_name).return_value = provider_return
        get_provider_config_mock = MagicMock(return_value=provider_configuration)
        monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)

        result = getattr(service, method_name)(**method_kwargs)

        get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
        provider_method = getattr(provider_configuration, provider_method_name)
        if isinstance(provider_call_kwargs, tuple):
            provider_method.assert_called_once_with(*provider_call_kwargs)
        elif isinstance(provider_call_kwargs, dict):
            provider_method.assert_called_once_with(**provider_call_kwargs)
        else:
            provider_method.assert_called_once_with(provider_call_kwargs)
        if method_name == "get_provider_credential":
            assert result == {"token": "abc"}

    @pytest.mark.parametrize(
        ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
        [
            (
                "get_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "get_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                {"api_key": "x"},
            ),
            (
                "validate_model_credentials",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                },
                "validate_custom_model_credentials",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
                None,
            ),
            (
                "create_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_name": "cred-a",
                },
                "create_custom_model_credential",
                {
                    "model_type": ModelType.LLM,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_name": "cred-a",
                },
                None,
            ),
            (
                "update_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_id": "cred-1",
                    "credential_name": "cred-b",
                },
                "update_custom_model_credential",
                {
                    "model_type": ModelType.LLM,
                    "model": "gpt-4o",
                    "credentials": {"api_key": "x"},
                    "credential_id": "cred-1",
                    "credential_name": "cred-b",
                },
                None,
            ),
            (
                "remove_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "delete_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "switch_active_custom_model_credential",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "switch_custom_model_credential",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "add_model_credential_to_model_list",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                    "credential_id": "cred-1",
                },
                "add_model_credential_to_model",
                {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
                None,
            ),
            (
                "remove_model",
                {
                    "tenant_id": "tenant-1",
                    "provider": "openai",
                    "model_type": ModelType.LLM.value,
                    "model": "gpt-4o",
                },
                "delete_custom_model",
                {"model_type": ModelType.LLM, "model": "gpt-4o"},
                None,
            ),
        ],
    )
    def test_custom_model_methods_should_convert_model_type_and_delegate(
        self,
        method_name: str,
        method_kwargs: dict[str, Any],
        provider_method_name: str,
        expected_kwargs: dict[str, Any],
        provider_return: Any,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = ModelProviderService()
        provider_configuration = MagicMock()
        getattr(provider_configuration, provider_method_name).return_value = provider_return
        get_provider_config_mock = MagicMock(return_value=provider_configuration)
        monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)

        result = getattr(service, method_name)(**method_kwargs)

        get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
        getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
        if method_name == "get_model_credential":
            assert result == {"api_key": "x"}


class TestModelProviderServiceListingsAndDefaults:
    def test_get_models_by_model_type_should_group_active_non_deprecated_models(self) -> None:
        service, manager = _create_service_with_mocked_manager()
        openai_provider = SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            icon_small_dark=None,
        )
        anthropic_provider = SimpleNamespace(
            provider="anthropic",
            label=I18nObject(en_US="Anthropic"),
            icon_small=None,
            icon_small_dark=None,
        )
        models = [
            SimpleNamespace(
                provider=openai_provider,
                model="gpt-4o",
                label=I18nObject(en_US="GPT-4o"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=False,
            ),
            SimpleNamespace(
                provider=openai_provider,
                model="old-openai",
                label=I18nObject(en_US="Old OpenAI"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=True,
            ),
            SimpleNamespace(
                provider=anthropic_provider,
                model="old-anthropic",
                label=I18nObject(en_US="Old Anthropic"),
                model_type=ModelType.LLM,
                features=[],
                fetch_from=FetchFrom.PREDEFINED_MODEL,
                model_properties={},
                status=ModelStatus.ACTIVE,
                load_balancing_enabled=False,
                deprecated=True,
            ),
        ]
        provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
        manager.get_configurations.return_value = provider_configurations

        result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

        provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
        assert len(result) == 1
        assert result[0].provider == "openai"
        assert len(result[0].models) == 1
        assert result[0].models[0].model == "gpt-4o"

    @pytest.mark.parametrize(
        ("credentials", "schema", "expected_count"),
        [
            (None, None, 0),
            ({"api_key": "x"}, None, 0),
            (
                {"api_key": "x"},
                SimpleNamespace(
                    parameter_rules=[
                        ParameterRule(
                            name="temperature",
                            label=I18nObject(en_US="Temperature"),
                            type=ParameterType.FLOAT,
                        )
                    ]
                ),
                1,
            ),
        ],
    )
    def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
        self,
        credentials: dict[str, Any] | None,
        schema: Any,
        expected_count: int,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = ModelProviderService()
        provider_configuration = MagicMock()
        provider_configuration.get_current_credentials.return_value = credentials
        provider_configuration.get_model_schema.return_value = schema
        monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))

        result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")

        assert len(result) == expected_count
        provider_configuration.get_current_credentials.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
        )
        if credentials:
            provider_configuration.get_model_schema.assert_called_once_with(
                model_type=ModelType.LLM,
                model="gpt-4o",
                credentials=credentials,
            )
        else:
            provider_configuration.get_model_schema.assert_not_called()

    def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model(self) -> None:
        service, manager = _create_service_with_mocked_manager()
        manager.get_default_model.return_value = SimpleNamespace(
            model="gpt-4o",
            model_type=ModelType.LLM,
            provider=SimpleNamespace(
                provider="openai",
                label=I18nObject(en_US="OpenAI"),
                icon_small=None,
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4o"
|
||||
assert result.provider.provider == "openai"
|
||||
manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.return_value = None
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
manager.get_default_model.side_effect = RuntimeError("boom")
|
||||
|
||||
result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_update_default_model_of_model_type_should_delegate_to_provider_manager(self) -> None:
|
||||
service, manager = _create_service_with_mocked_manager()
|
||||
|
||||
service.update_default_model_of_model_type(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM.value,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
manager.update_default_model_record.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.LLM,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
|
||||
factory_constructor = MagicMock(return_value=factory_instance)
|
||||
monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)
|
||||
|
||||
result = service.get_model_provider_icon(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
icon_type="icon_small",
|
||||
lang="en_US",
|
||||
)
|
||||
|
||||
factory_constructor.assert_called_once_with(tenant_id="tenant-1")
|
||||
factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
|
||||
assert result == (b"icon-bytes", "image/png")
|
||||
|
||||
def test_switch_preferred_provider_should_convert_enum_and_delegate(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
service.switch_preferred_provider(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
preferred_provider_type=ProviderType.SYSTEM.value,
|
||||
)
|
||||
|
||||
provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method_name", "provider_method_name"),
|
||||
[
|
||||
("enable_model", "enable_model"),
|
||||
("disable_model", "disable_model"),
|
||||
],
|
||||
)
|
||||
def test_model_enablement_methods_should_convert_model_type_and_delegate(
|
||||
self,
|
||||
method_name: str,
|
||||
provider_method_name: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
service = ModelProviderService()
|
||||
provider_configuration = MagicMock()
|
||||
monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))
|
||||
|
||||
getattr(service, method_name)(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM.value,
|
||||
)
|
||||
|
||||
getattr(provider_configuration, provider_method_name).assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
@@ -85,644 +85,3 @@ def test_get_provider_list_strips_credentials(service_with_fake_configurations:
    assert len(custom_models) == 1
    # The sanitizer should drop credentials in list response
    assert custom_models[0].credentials is None


# === Merged from test_model_provider_service.py ===


from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType

from core.entities.model_entities import ModelStatus
from models.provider import ProviderType
from services import model_provider_service as service_module
from services.errors.app_model_config import ProviderNotFoundError
from services.model_provider_service import ModelProviderService


def _create_service_with_mocked_manager() -> tuple[ModelProviderService, MagicMock]:
    manager = MagicMock()
    service = ModelProviderService()
    service._get_provider_manager = MagicMock(return_value=manager)
    return service, manager


def _build_provider_configuration(
    *,
    provider_name: str = "openai",
    supported_model_types: list[ModelType] | None = None,
    custom_models: list[Any] | None = None,
    custom_config_available: bool = True,
) -> SimpleNamespace:
    if supported_model_types is None:
        supported_model_types = [ModelType.LLM]
    return SimpleNamespace(
        provider=SimpleNamespace(
            provider=provider_name,
            label=I18nObject(en_US=provider_name),
            description=None,
            icon_small=None,
            icon_small_dark=None,
            background=None,
            help=None,
            supported_model_types=supported_model_types,
            configurate_methods=[],
            provider_credential_schema=None,
            model_credential_schema=None,
        ),
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=SimpleNamespace(
            provider=SimpleNamespace(
                current_credential_id="cred-1",
                current_credential_name="Credential 1",
                available_credentials=[],
            ),
            models=custom_models,
            can_added_models=[],
        ),
        system_configuration=SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]),
        is_custom_configuration_available=lambda: custom_config_available,
    )


def test__get_provider_configuration_should_return_configuration_when_provider_exists() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    provider_configuration = SimpleNamespace(name="provider-config")
    manager.get_configurations.return_value = {"openai": provider_configuration}

    # Act
    result = service._get_provider_configuration(tenant_id="tenant-1", provider="openai")

    # Assert
    assert result is provider_configuration


def test__get_provider_configuration_should_raise_error_when_provider_is_missing() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_configurations.return_value = {}

    # Act / Assert
    with pytest.raises(ProviderNotFoundError, match="does not exist"):
        service._get_provider_configuration(tenant_id="tenant-1", provider="missing")


def test_get_provider_list_should_filter_by_model_type_and_build_no_configure_status() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    allowed = _build_provider_configuration(
        provider_name="openai",
        supported_model_types=[ModelType.LLM],
        custom_config_available=False,
    )
    filtered = _build_provider_configuration(
        provider_name="embedding",
        supported_model_types=[ModelType.TEXT_EMBEDDING],
        custom_config_available=True,
    )
    manager.get_configurations.return_value = {"openai": allowed, "embedding": filtered}

    # Act
    result = service.get_provider_list(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert result[0].custom_configuration.status.value == "no-configure"


def test_get_models_by_provider_should_wrap_model_entities_with_tenant_context() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()

    class _Model:
        def __init__(self, model_name: str) -> None:
            self.model_name = model_name

        def model_dump(self) -> dict[str, Any]:
            return {
                "model": self.model_name,
                "label": {"en_US": self.model_name},
                "model_type": ModelType.LLM,
                "features": [],
                "fetch_from": FetchFrom.PREDEFINED_MODEL,
                "model_properties": {},
                "deprecated": False,
                "status": ModelStatus.ACTIVE,
                "load_balancing_enabled": False,
                "has_invalid_load_balancing_configs": False,
                "provider": {
                    "provider": "openai",
                    "label": {"en_US": "OpenAI"},
                    "icon_small": None,
                    "icon_small_dark": None,
                    "supported_model_types": [ModelType.LLM],
                },
            }

    provider_configurations = SimpleNamespace(
        get_models=MagicMock(return_value=[_Model("gpt-4o"), _Model("gpt-4o-mini")])
    )
    manager.get_configurations.return_value = provider_configurations

    # Act
    result = service.get_models_by_provider(tenant_id="tenant-1", provider="openai")

    # Assert
    assert len(result) == 2
    assert result[0].model == "gpt-4o"
    assert result[1].provider.provider == "openai"
    provider_configurations.get_models.assert_called_once_with(provider="openai")


@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "provider_call_kwargs", "provider_return"),
    [
        (
            "get_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "get_provider_credential",
            {"credential_id": "cred-1"},
            {"token": "abc"},
        ),
        (
            "validate_provider_credentials",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}},
            "validate_provider_credentials",
            ({"token": "abc"},),
            None,
        ),
        (
            "create_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credentials": {"token": "abc"}, "credential_name": "A"},
            "create_provider_credential",
            ({"token": "abc"}, "A"),
            None,
        ),
        (
            "update_provider_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "credentials": {"token": "abc"},
                "credential_id": "cred-1",
                "credential_name": "B",
            },
            "update_provider_credential",
            {"credential_id": "cred-1", "credentials": {"token": "abc"}, "credential_name": "B"},
            None,
        ),
        (
            "remove_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "delete_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_provider_credential",
            {"tenant_id": "tenant-1", "provider": "openai", "credential_id": "cred-1"},
            "switch_active_provider_credential",
            {"credential_id": "cred-1"},
            None,
        ),
    ],
)
def test_provider_credential_methods_should_delegate_to_provider_configuration(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    provider_call_kwargs: Any,
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)

    # Act
    result = getattr(service, method_name)(**method_kwargs)

    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    provider_method = getattr(provider_configuration, provider_method_name)
    if isinstance(provider_call_kwargs, tuple):
        provider_method.assert_called_once_with(*provider_call_kwargs)
    elif isinstance(provider_call_kwargs, dict):
        provider_method.assert_called_once_with(**provider_call_kwargs)
    else:
        provider_method.assert_called_once_with(provider_call_kwargs)
    if method_name == "get_provider_credential":
        assert result == {"token": "abc"}


@pytest.mark.parametrize(
    ("method_name", "method_kwargs", "provider_method_name", "expected_kwargs", "provider_return"),
    [
        (
            "get_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "get_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            {"api_key": "x"},
        ),
        (
            "validate_model_credentials",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
            },
            "validate_custom_model_credentials",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credentials": {"api_key": "x"}},
            None,
        ),
        (
            "create_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            "create_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_name": "cred-a",
            },
            None,
        ),
        (
            "update_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            "update_custom_model_credential",
            {
                "model_type": ModelType.LLM,
                "model": "gpt-4o",
                "credentials": {"api_key": "x"},
                "credential_id": "cred-1",
                "credential_name": "cred-b",
            },
            None,
        ),
        (
            "remove_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "delete_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "switch_active_custom_model_credential",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "switch_custom_model_credential",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "add_model_credential_to_model_list",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
                "credential_id": "cred-1",
            },
            "add_model_credential_to_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o", "credential_id": "cred-1"},
            None,
        ),
        (
            "remove_model",
            {
                "tenant_id": "tenant-1",
                "provider": "openai",
                "model_type": ModelType.LLM.value,
                "model": "gpt-4o",
            },
            "delete_custom_model",
            {"model_type": ModelType.LLM, "model": "gpt-4o"},
            None,
        ),
    ],
)
def test_custom_model_methods_should_convert_model_type_and_delegate(
    method_name: str,
    method_kwargs: dict[str, Any],
    provider_method_name: str,
    expected_kwargs: dict[str, Any],
    provider_return: Any,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    getattr(provider_configuration, provider_method_name).return_value = provider_return
    get_provider_config_mock = MagicMock(return_value=provider_configuration)
    monkeypatch.setattr(service, "_get_provider_configuration", get_provider_config_mock)

    # Act
    result = getattr(service, method_name)(**method_kwargs)

    # Assert
    get_provider_config_mock.assert_called_once_with("tenant-1", "openai")
    getattr(provider_configuration, provider_method_name).assert_called_once_with(**expected_kwargs)
    if method_name == "get_model_credential":
        assert result == {"api_key": "x"}


def test_get_models_by_model_type_should_group_active_non_deprecated_models() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    openai_provider = SimpleNamespace(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        icon_small=None,
        icon_small_dark=None,
    )
    anthropic_provider = SimpleNamespace(
        provider="anthropic",
        label=I18nObject(en_US="Anthropic"),
        icon_small=None,
        icon_small_dark=None,
    )
    models = [
        SimpleNamespace(
            provider=openai_provider,
            model="gpt-4o",
            label=I18nObject(en_US="GPT-4o"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=False,
        ),
        SimpleNamespace(
            provider=openai_provider,
            model="old-openai",
            label=I18nObject(en_US="Old OpenAI"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
        SimpleNamespace(
            provider=anthropic_provider,
            model="old-anthropic",
            label=I18nObject(en_US="Old Anthropic"),
            model_type=ModelType.LLM,
            features=[],
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            status=ModelStatus.ACTIVE,
            load_balancing_enabled=False,
            deprecated=True,
        ),
    ]
    provider_configurations = SimpleNamespace(get_models=MagicMock(return_value=models))
    manager.get_configurations.return_value = provider_configurations

    # Act
    result = service.get_models_by_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
    assert len(result) == 1
    assert result[0].provider == "openai"
    assert len(result[0].models) == 1
    assert result[0].models[0].model == "gpt-4o"


@pytest.mark.parametrize(
    ("credentials", "schema", "expected_count"),
    [
        (None, None, 0),
        ({"api_key": "x"}, None, 0),
        (
            {"api_key": "x"},
            SimpleNamespace(
                parameter_rules=[
                    ParameterRule(
                        name="temperature",
                        label=I18nObject(en_US="Temperature"),
                        type=ParameterType.FLOAT,
                    )
                ]
            ),
            1,
        ),
    ],
)
def test_get_model_parameter_rules_should_handle_missing_credentials_and_schema(
    credentials: dict[str, Any] | None,
    schema: Any,
    expected_count: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    provider_configuration.get_current_credentials.return_value = credentials
    provider_configuration.get_model_schema.return_value = schema
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))

    # Act
    result = service.get_model_parameter_rules(tenant_id="tenant-1", provider="openai", model="gpt-4o")

    # Assert
    assert len(result) == expected_count
    provider_configuration.get_current_credentials.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4o")
    if credentials:
        provider_configuration.get_model_schema.assert_called_once_with(
            model_type=ModelType.LLM,
            model="gpt-4o",
            credentials=credentials,
        )
    else:
        provider_configuration.get_model_schema.assert_not_called()


def test_get_default_model_of_model_type_should_return_response_when_manager_returns_model() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = SimpleNamespace(
        model="gpt-4o",
        model_type=ModelType.LLM,
        provider=SimpleNamespace(
            provider="openai",
            label=I18nObject(en_US="OpenAI"),
            icon_small=None,
            supported_model_types=[ModelType.LLM],
        ),
    )

    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    assert result is not None
    assert result.model == "gpt-4o"
    assert result.provider.provider == "openai"
    manager.get_default_model.assert_called_once_with(tenant_id="tenant-1", model_type=ModelType.LLM)


def test_get_default_model_of_model_type_should_return_none_when_manager_returns_none() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.return_value = None

    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    assert result is None


def test_get_default_model_of_model_type_should_return_none_when_manager_raises_exception() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()
    manager.get_default_model.side_effect = RuntimeError("boom")

    # Act
    result = service.get_default_model_of_model_type(tenant_id="tenant-1", model_type=ModelType.LLM.value)

    # Assert
    assert result is None


def test_update_default_model_of_model_type_should_delegate_to_provider_manager() -> None:
    # Arrange
    service, manager = _create_service_with_mocked_manager()

    # Act
    service.update_default_model_of_model_type(
        tenant_id="tenant-1",
        model_type=ModelType.LLM.value,
        provider="openai",
        model="gpt-4o",
    )

    # Assert
    manager.update_default_model_record.assert_called_once_with(
        tenant_id="tenant-1",
        model_type=ModelType.LLM,
        provider="openai",
        model="gpt-4o",
    )


def test_get_model_provider_icon_should_fetch_icon_bytes_from_factory(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    service = ModelProviderService()
    factory_instance = MagicMock()
    factory_instance.get_provider_icon.return_value = (b"icon-bytes", "image/png")
    factory_constructor = MagicMock(return_value=factory_instance)
    monkeypatch.setattr(service_module, "create_plugin_model_provider_factory", factory_constructor)

    # Act
    result = service.get_model_provider_icon(
        tenant_id="tenant-1",
        provider="openai",
        icon_type="icon_small",
        lang="en_US",
    )

    # Assert
    factory_constructor.assert_called_once_with(tenant_id="tenant-1")
    factory_instance.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
    assert result == (b"icon-bytes", "image/png")


def test_switch_preferred_provider_should_convert_enum_and_delegate(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))

    # Act
    service.switch_preferred_provider(
        tenant_id="tenant-1",
        provider="openai",
        preferred_provider_type=ProviderType.SYSTEM.value,
    )

    # Assert
    provider_configuration.switch_preferred_provider_type.assert_called_once_with(ProviderType.SYSTEM)


@pytest.mark.parametrize(
    ("method_name", "provider_method_name"),
    [
        ("enable_model", "enable_model"),
        ("disable_model", "disable_model"),
    ],
)
def test_model_enablement_methods_should_convert_model_type_and_delegate(
    method_name: str,
    provider_method_name: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    service = ModelProviderService()
    provider_configuration = MagicMock()
    monkeypatch.setattr(service, "_get_provider_configuration", MagicMock(return_value=provider_configuration))

    # Act
    getattr(service, method_name)(
        tenant_id="tenant-1",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM.value,
    )

    # Assert
    getattr(provider_configuration, provider_method_name).assert_called_once_with(
        model="gpt-4o",
        model_type=ModelType.LLM,
    )

@@ -12,7 +12,6 @@ This test suite covers all functionality of the current VariableTruncator includ
import functools
import json
import uuid
from collections.abc import Mapping
from typing import Any
from uuid import uuid4

@@ -674,229 +673,3 @@ def test_dummy_variable_truncator_methods():
    assert isinstance(result, TruncationResult)
    assert result.result == segment
    assert result.truncated is False


# === Merged from test_variable_truncator_additional.py ===


from typing import Any

import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType

from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator


class _AbstractPassthrough(BaseTruncator):
    def truncate(self, segment: Any) -> TruncationResult:
        # Arrange / Act
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        # Arrange / Act
        return super().truncate_variable_mapping(v)  # type: ignore[misc]


def test_base_truncator_methods_should_execute_abstract_placeholders() -> None:
    # Arrange
    passthrough = _AbstractPassthrough()

    # Act
    truncate_result = passthrough.truncate(StringSegment(value="x"))
    mapping_result = passthrough.truncate_variable_mapping({"a": 1})

    # Assert
    assert truncate_result is None
    assert mapping_result is None


def test_default_should_use_dify_config_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
    monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)

    # Act
    truncator = VariableTruncator.default()

    # Assert
    assert truncator._max_size_bytes == 111
    assert truncator._array_element_limit == 7
    assert truncator._string_length_limit == 33


def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis() -> None:
    # Arrange
    truncator = VariableTruncator(max_size_bytes=5)
    mapping = {"very_long_key": "value"}

    # Act
    result, truncated = truncator.truncate_variable_mapping(mapping)

    # Assert
    assert result == {"very_long_key": "..."}
    assert truncated is True


def test_truncate_variable_mapping_should_handle_segment_values() -> None:
    # Arrange
    truncator = VariableTruncator(max_size_bytes=100)
    mapping = {"seg": StringSegment(value="hello")}

    # Act
    result, truncated = truncator.truncate_variable_mapping(mapping)

    # Assert
    assert isinstance(result["seg"], StringSegment)
    assert result["seg"].value == "hello"
    assert truncated is False


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (None, False),
        (True, False),
        (1, False),
        (1.5, False),
        ("x", True),
        ({"k": "v"}, True),
    ],
)
def test_json_value_needs_truncation_should_match_expected_rules(value: Any, expected: bool) -> None:
    # Arrange

    # Act
    result = VariableTruncator._json_value_needs_truncation(value)

    # Assert
    assert result is expected


def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    truncator = VariableTruncator(max_size_bytes=10)
    forced_result = truncator_module._PartResult(
        value=StringSegment(value="this is too long"),
        value_size=100,
        truncated=True,
    )
    monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)

    # Act
    result = truncator.truncate(StringSegment(value="input"))

    # Assert
    assert result.truncated is True
    assert isinstance(result.result, StringSegment)
    assert not result.result.value.startswith('"')


def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    truncator = VariableTruncator()
    monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)

    # Act / Assert
    with pytest.raises(AssertionError):
        truncator._truncate_segment(IntegerSegment(value=1), 10)


def test_calculate_json_size_should_unwrap_segment_values() -> None:
    # Arrange
    segment = StringSegment(value="abc")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(segment)
|
||||
|
||||
# Assert
|
||||
assert size == VariableTruncator.calculate_json_size("abc")
|
||||
|
||||
|
||||
def test_calculate_json_size_should_handle_updated_variable_instances() -> None:
|
||||
# Arrange
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
size = VariableTruncator.calculate_json_size(updated)
|
||||
|
||||
# Assert
|
||||
assert size > 0
|
||||
|
||||
|
||||
def test_maybe_qa_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
|
||||
assert VariableTruncator._maybe_qa_structure({}) is False
|
||||
|
||||
|
||||
def test_maybe_parent_child_structure_should_validate_shape() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
|
||||
assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
|
||||
assert (
|
||||
VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"}) is False
|
||||
)
|
||||
|
||||
|
||||
def test_truncate_object_should_truncate_segment_values_inside_object() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
|
||||
mapping = {"s": StringSegment(value="long-content")}
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_object(mapping, 20)
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value["s"], StringSegment)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_handle_updated_variable_input() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=100)
|
||||
updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")
|
||||
|
||||
# Act
|
||||
result = truncator._truncate_json_primitives(updated, 100)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
|
||||
def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type() -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator()
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(AssertionError):
|
||||
truncator._truncate_json_primitives(object(), 100) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
truncator = VariableTruncator(max_size_bytes=10)
|
||||
forced_segment = ObjectSegment(value={"k": "v"})
|
||||
forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
|
||||
monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)
|
||||
|
||||
# Act
|
||||
result = truncator.truncate(ObjectSegment(value={"a": "b"}))
|
||||
|
||||
# Assert
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
|
||||
@@ -0,0 +1,174 @@
from collections.abc import Mapping
from typing import Any

import pytest
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
from graphon.variables.segments import IntegerSegment, ObjectSegment, StringSegment
from graphon.variables.types import SegmentType

from services import variable_truncator as truncator_module
from services.variable_truncator import BaseTruncator, TruncationResult, VariableTruncator


class _AbstractPassthrough(BaseTruncator):
    def truncate(self, segment: Any) -> TruncationResult:
        return super().truncate(segment)  # type: ignore[misc]

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        return super().truncate_variable_mapping(v)  # type: ignore[misc]


class TestBaseTruncatorContract:
    def test_base_truncator_methods_should_execute_abstract_placeholders(self) -> None:
        passthrough = _AbstractPassthrough()

        truncate_result = passthrough.truncate(StringSegment(value="x"))
        mapping_result = passthrough.truncate_variable_mapping({"a": 1})

        assert truncate_result is None
        assert mapping_result is None


class TestVariableTruncatorAdditionalBehavior:
    def test_default_should_use_dify_config_limits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE", 111)
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH", 7)
        monkeypatch.setattr(truncator_module.dify_config, "WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH", 33)

        truncator = VariableTruncator.default()

        assert truncator._max_size_bytes == 111
        assert truncator._array_element_limit == 7
        assert truncator._string_length_limit == 33

    def test_truncate_variable_mapping_should_mark_over_budget_keys_with_ellipsis(self) -> None:
        truncator = VariableTruncator(max_size_bytes=5)
        mapping = {"very_long_key": "value"}

        result, truncated = truncator.truncate_variable_mapping(mapping)

        assert result == {"very_long_key": "..."}
        assert truncated is True

    def test_truncate_variable_mapping_should_handle_segment_values(self) -> None:
        truncator = VariableTruncator(max_size_bytes=100)
        mapping = {"seg": StringSegment(value="hello")}

        result, truncated = truncator.truncate_variable_mapping(mapping)

        assert isinstance(result["seg"], StringSegment)
        assert result["seg"].value == "hello"
        assert truncated is False

    @pytest.mark.parametrize(
        ("value", "expected"),
        [
            (None, False),
            (True, False),
            (1, False),
            (1.5, False),
            ("x", True),
            ({"k": "v"}, True),
        ],
    )
    def test_json_value_needs_truncation_should_match_expected_rules(
        self,
        value: Any,
        expected: bool,
    ) -> None:
        result = VariableTruncator._json_value_needs_truncation(value)
        assert result is expected

    def test_truncate_should_use_string_fallback_when_truncated_value_size_exceeds_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        truncator = VariableTruncator(max_size_bytes=10)
        forced_result = truncator_module._PartResult(
            value=StringSegment(value="this is too long"),
            value_size=100,
            truncated=True,
        )
        monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)

        result = truncator.truncate(StringSegment(value="input"))

        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        assert not result.result.value.startswith('"')

    def test_truncate_segment_should_raise_assertion_for_unexpected_truncatable_segment(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        truncator = VariableTruncator()
        monkeypatch.setattr(VariableTruncator, "_segment_need_truncation", lambda _segment: True)

        with pytest.raises(AssertionError):
            truncator._truncate_segment(IntegerSegment(value=1), 10)

    def test_calculate_json_size_should_unwrap_segment_values(self) -> None:
        segment = StringSegment(value="abc")

        size = VariableTruncator.calculate_json_size(segment)

        assert size == VariableTruncator.calculate_json_size("abc")

    def test_calculate_json_size_should_handle_updated_variable_instances(self) -> None:
        updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")

        size = VariableTruncator.calculate_json_size(updated)

        assert size > 0

    def test_maybe_qa_structure_should_validate_shape(self) -> None:
        assert VariableTruncator._maybe_qa_structure({"qa_chunks": []}) is True
        assert VariableTruncator._maybe_qa_structure({"qa_chunks": "not-list"}) is False
        assert VariableTruncator._maybe_qa_structure({}) is False

    def test_maybe_parent_child_structure_should_validate_shape(self) -> None:
        assert (
            VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": []}) is True
        )
        assert VariableTruncator._maybe_parent_child_structure({"parent_mode": 1, "parent_child_chunks": []}) is False
        assert (
            VariableTruncator._maybe_parent_child_structure({"parent_mode": "full", "parent_child_chunks": "bad"})
            is False
        )

    def test_truncate_object_should_truncate_segment_values_inside_object(self) -> None:
        truncator = VariableTruncator(string_length_limit=8, max_size_bytes=30)
        mapping = {"s": StringSegment(value="long-content")}

        result = truncator._truncate_object(mapping, 20)

        assert result.truncated is True
        assert isinstance(result.value["s"], StringSegment)

    def test_truncate_json_primitives_should_handle_updated_variable_input(self) -> None:
        truncator = VariableTruncator(max_size_bytes=100)
        updated = UpdatedVariable(name="n", selector=["node", "var"], value_type=SegmentType.STRING, new_value="v")

        result = truncator._truncate_json_primitives(updated, 100)

        assert isinstance(result.value, dict)

    def test_truncate_json_primitives_should_raise_assertion_for_unsupported_value_type(self) -> None:
        truncator = VariableTruncator()

        with pytest.raises(AssertionError):
            truncator._truncate_json_primitives(object(), 100)  # type: ignore[arg-type]

    def test_truncate_should_apply_json_string_fallback_for_large_non_string_segment(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        truncator = VariableTruncator(max_size_bytes=10)
        forced_segment = ObjectSegment(value={"k": "v"})
        forced_result = truncator_module._PartResult(value=forced_segment, value_size=100, truncated=True)
        monkeypatch.setattr(truncator, "_truncate_segment", lambda *_args, **_kwargs: forced_result)

        result = truncator.truncate(ObjectSegment(value={"a": "b"}))

        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
@@ -559,771 +559,3 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
|
||||
# === Merged from test_webhook_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger import webhook_service as service_module
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def __init__(self, result: Any) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._result
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _workflow(**kwargs: Any) -> Workflow:
|
||||
return cast(Workflow, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _app(**kwargs: Any) -> App:
|
||||
return cast(App, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
"webhook-1", is_debug=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert got_trigger is webhook_trigger
|
||||
assert got_workflow is workflow
|
||||
assert got_node_config == {"data": {"mode": "debug"}}
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/vnd.custom"},
|
||||
data="plain content",
|
||||
):
|
||||
result = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result["body"] == {"raw": "plain content"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_webhook_data_should_raise_for_request_too_large(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)
|
||||
|
||||
# Act / Assert
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="ab"):
|
||||
with pytest.raises(RequestEntityTooLarge):
|
||||
WebhookService.extract_webhook_data(MagicMock())
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_empty_payload(flask_app: Flask) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b""):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_octet_stream_body_should_return_none_when_processing_raises(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = MagicMock()
|
||||
monkeypatch.setattr(WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream"))
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": None}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_extract_text_body_should_return_empty_string_when_request_read_fails(
|
||||
flask_app: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))
|
||||
|
||||
# Act
|
||||
with flask_app.test_request_context("/webhook", method="POST", data="abc"):
|
||||
body, files = WebhookService._extract_text_body()
|
||||
|
||||
# Assert
|
||||
assert body == {"raw": ""}
|
||||
assert files == {}
|
||||
|
||||
|
||||
def test_detect_binary_mimetype_should_fallback_when_magic_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_magic = MagicMock()
|
||||
fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
|
||||
monkeypatch.setattr(service_module, "magic", fake_magic)
|
||||
|
||||
# Act
|
||||
result = WebhookService._detect_binary_mimetype(b"binary")
|
||||
|
||||
# Assert
|
||||
assert result == "application/octet-stream"
|
||||
|
||||
|
||||
def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
file_obj = MagicMock()
|
||||
file_obj.to_dict.return_value = {"id": "f-1"}
|
||||
monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
|
||||
monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))
|
||||
|
||||
uploaded = MagicMock()
|
||||
uploaded.filename = "file.unknown"
|
||||
uploaded.content_type = None
|
||||
uploaded.read.return_value = b"content"
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result == {"f": {"id": "f-1"}}
|
||||
|
||||
|
||||
def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
|
||||
manager = MagicMock()
|
||||
manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
|
||||
monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
|
||||
expected_file = MagicMock()
|
||||
monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))
|
||||
|
||||
# Act
|
||||
result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)
|
||||
|
||||
# Assert
|
||||
assert result is expected_file
|
||||
manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "param_type", "expected"),
|
||||
[
|
||||
("42", SegmentType.NUMBER, 42),
|
||||
("3.14", SegmentType.NUMBER, 3.14),
|
||||
("yes", SegmentType.BOOLEAN, True),
|
||||
("no", SegmentType.BOOLEAN, False),
|
||||
],
|
||||
)
|
||||
def test_convert_form_value_should_convert_supported_types(
|
||||
raw_value: str,
|
||||
param_type: str,
|
||||
expected: Any,
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
||||
# Act
|
||||
result = WebhookService._convert_form_value("param", raw_value, param_type)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_form_value_should_raise_for_unsupported_type() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Unsupported type"):
|
||||
WebhookService._convert_form_value("p", "x", SegmentType.FILE)
|
||||
|
||||
|
||||
def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Arrange
|
||||
warning_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.logger, "warning", warning_mock)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")
|
||||
|
||||
# Assert
|
||||
assert result == {"x": 1}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_and_convert_value_should_wrap_conversion_errors() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_raise_when_required_parameter_missing() -> None:
|
||||
# Arrange
|
||||
raw_params = {"optional": "x"}
|
||||
config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required parameter missing"):
|
||||
WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
|
||||
def test_process_parameters_should_include_unconfigured_parameters() -> None:
|
||||
# Arrange
|
||||
raw_params = {"known": "1", "unknown": "x"}
|
||||
config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_parameters(raw_params, config, is_form_data=True)
|
||||
|
||||
# Assert
|
||||
assert result == {"known": 1, "unknown": "x"}
|
||||
|
||||
|
||||
def test_process_body_parameters_should_raise_when_required_text_raw_is_missing() -> None:
|
||||
# Arrange
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required body content missing"):
|
||||
WebhookService._process_body_parameters(
|
||||
raw_body={"raw": ""},
|
||||
body_configs=[WebhookBodyParameter(name="raw", required=True)],
|
||||
content_type=ContentType.TEXT,
|
||||
)
|
||||
|
||||
|
||||
def test_process_body_parameters_should_skip_file_config_for_multipart_form_data() -> None:
|
||||
# Arrange
|
||||
raw_body = {"message": "hello", "extra": "x"}
|
||||
body_configs = [
|
||||
WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
|
||||
WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)
|
||||
|
||||
# Assert
|
||||
assert result == {"message": "hello", "extra": "x"}
|
||||
|
||||
|
||||
def test_validate_required_headers_should_accept_sanitized_header_names() -> None:
|
||||
# Arrange
|
||||
headers = {"x_api_key": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
# Assert
|
||||
assert True
|
||||
|
||||
|
||||
def test_validate_required_headers_should_raise_when_required_header_missing() -> None:
|
||||
# Arrange
|
||||
headers = {"x-other": "123"}
|
||||
configs = [WebhookParameter(name="x-api-key", required=True)]
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Required header missing"):
|
||||
WebhookService._validate_required_headers(headers, configs)
|
||||
|
||||
|
||||
def test_validate_http_metadata_should_return_content_type_mismatch_error() -> None:
|
||||
# Arrange
|
||||
webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
|
||||
node_data = WebhookData(method="post", content_type=ContentType.TEXT)
|
||||
|
||||
# Act
|
||||
result = WebhookService._validate_http_metadata(webhook_data, node_data)
|
||||
|
||||
# Assert
|
||||
assert result["valid"] is False
|
||||
assert "Content-type mismatch" in result["error"]
|
||||
|
||||
|
||||
def test_extract_content_type_should_fallback_to_lowercase_header_key() -> None:
|
||||
# Arrange
|
||||
headers = {"content-type": "application/json; charset=utf-8"}
|
||||
|
||||
# Act
|
||||
result = WebhookService._extract_content_type(headers)
|
||||
|
||||
# Assert
|
||||
assert result == "application/json"
|
||||
|
||||
|
||||
def test_build_workflow_inputs_should_include_expected_keys() -> None:
    # Arrange
    webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}

    # Act
    result = WebhookService.build_workflow_inputs(webhook_data)

    # Assert
    assert result["webhook_data"] == webhook_data
    assert result["webhook_headers"] == {"h": "v"}
    assert result["webhook_query_params"] == {"q": 1}
    assert result["webhook_body"] == {"b": 2}


def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")
    webhook_data = {"body": {"x": 1}}

    session = MagicMock()
    _patch_session(monkeypatch, session)

    end_user = SimpleNamespace(id="end-user-1")
    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(return_value=end_user)
    )
    quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    trigger_async_mock = MagicMock()
    monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)

    # Act
    WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

    # Assert
    trigger_async_mock.assert_called_once()


def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")

    session = MagicMock()
    _patch_session(monkeypatch, session)

    monkeypatch.setattr(
        service_module.EndUserService,
        "get_or_create_end_user_by_type",
        MagicMock(return_value=SimpleNamespace(id="end-user-1")),
    )
    quota_type = SimpleNamespace(
        TRIGGER=SimpleNamespace(
            consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
        )
    )
    monkeypatch.setattr(service_module, "QuotaType", quota_type)
    mark_rate_limited_mock = MagicMock()
    monkeypatch.setattr(service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock)

    # Act / Assert
    with pytest.raises(QuotaExceededError):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    mark_rate_limited_mock.assert_called_once_with("tenant-1")


def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    webhook_trigger = _workflow_trigger(
        app_id="app-1",
        node_id="node-1",
        tenant_id="tenant-1",
        webhook_id="webhook-1",
    )
    workflow = _workflow(id="wf-1")

    session = MagicMock()
    _patch_session(monkeypatch, session)

    monkeypatch.setattr(
        service_module.EndUserService, "get_or_create_end_user_by_type", MagicMock(side_effect=RuntimeError("boom"))
    )
    logger_exception_mock = MagicMock()
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)

    # Act / Assert
    with pytest.raises(RuntimeError, match="boom"):
        WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)
    logger_exception_mock.assert_called_once()


def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit() -> None:
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(
        walk_nodes=lambda _node_type: [
            (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
        ]
    )

    # Act / Assert
    with pytest.raises(ValueError, match="maximum webhook node limit"):
        WebhookService.sync_webhook_relationships(app, workflow)


def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])

    lock = MagicMock()
    lock.acquire.return_value = False
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))

    # Act / Assert
    with pytest.raises(RuntimeError, match="Failed to acquire lock"):
        WebhookService.sync_webhook_relationships(app, workflow)


def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])

    class _WorkflowWebhookTrigger:
        app_id = "app_id"
        tenant_id = "tenant_id"
        webhook_id = "webhook_id"
        node_id = "node_id"

        def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
            self.id = None
            self.app_id = app_id
            self.tenant_id = tenant_id
            self.node_id = node_id
            self.webhook_id = webhook_id
            self.created_by = created_by

    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        def __init__(self) -> None:
            self.added: list[Any] = []
            self.deleted: list[Any] = []
            self.commit_count = 0
            self.existing_records = [SimpleNamespace(node_id="node-stale")]

        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: self.existing_records)

        def add(self, obj: Any) -> None:
            self.added.append(obj)

        def flush(self) -> None:
            for idx, obj in enumerate(self.added, start=1):
                if obj.id is None:
                    obj.id = f"rec-{idx}"

        def commit(self) -> None:
            self.commit_count += 1

        def delete(self, obj: Any) -> None:
            self.deleted.append(obj)

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.return_value = None

    fake_session = _Session()

    monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    redis_set_mock = MagicMock()
    redis_delete_mock = MagicMock()
    monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
    monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
    monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
    _patch_session(monkeypatch, fake_session)

    # Act
    WebhookService.sync_webhook_relationships(app, workflow)

    # Assert
    assert len(fake_session.added) == 1
    assert len(fake_session.deleted) == 1
    redis_set_mock.assert_called_once()
    redis_delete_mock.assert_called_once()
    lock.release.assert_called_once()


def test_sync_webhook_relationships_should_log_when_lock_release_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
    workflow = _workflow(walk_nodes=lambda _node_type: [])

    class _Select:
        def where(self, *args: Any, **kwargs: Any) -> "_Select":
            return self

    class _Session:
        def scalars(self, _stmt: Any) -> Any:
            return SimpleNamespace(all=lambda: [])

        def commit(self) -> None:
            return None

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = RuntimeError("release failed")

    logger_exception_mock = MagicMock()

    monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
    monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
    _patch_session(monkeypatch, _Session())

    # Act
    WebhookService.sync_webhook_relationships(app, workflow)

    # Assert
    assert logger_exception_mock.call_count == 1


def test_generate_webhook_response_should_fallback_when_response_body_is_not_json() -> None:
    # Arrange
    node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}

    # Act
    body, status = WebhookService.generate_webhook_response(node_config)

    # Assert
    assert status == 200
    assert "message" in body


def test_generate_webhook_id_should_return_24_character_identifier() -> None:
    # Act
    webhook_id = WebhookService.generate_webhook_id()

    # Assert
    assert isinstance(webhook_id, str)
    assert len(webhook_id) == 24


def test_sanitize_key_should_return_original_value_for_non_string_input() -> None:
    # Act
    result = WebhookService._sanitize_key(123)  # type: ignore[arg-type]

    # Assert
    assert result == 123
671
api/tests/unit_tests/services/test_webhook_service_additional.py
Normal file
@@ -0,0 +1,671 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from flask import Flask
from werkzeug.exceptions import RequestEntityTooLarge

from core.variables.types import SegmentType
from core.workflow.nodes.trigger_webhook.entities import (
    ContentType,
    WebhookBodyParameter,
    WebhookData,
    WebhookParameter,
)
from models.enums import AppTriggerStatus
from models.model import App
from models.trigger import WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService


class _FakeQuery:
    def __init__(self, result: Any) -> None:
        self._result = result

    def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def first(self) -> Any:
        return self._result


class _SessionContext:
    def __init__(self, session: Any) -> None:
        self._session = session

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


class _SessionmakerContext:
    def __init__(self, session: Any) -> None:
        self._session = session

    def begin(self) -> "_SessionmakerContext":
        return self

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


@pytest.fixture
def flask_app() -> Flask:
    return Flask(__name__)


def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
    monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
    monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))


def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
    return cast(WorkflowWebhookTrigger, SimpleNamespace(**kwargs))


def _workflow(**kwargs: Any) -> Workflow:
    return cast(Workflow, SimpleNamespace(**kwargs))


def _app(**kwargs: Any) -> App:
    return cast(App, SimpleNamespace(**kwargs))


class TestWebhookServiceLookup:
    def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        fake_session = MagicMock()
        fake_session.scalar.return_value = None
        _patch_session(monkeypatch, fake_session)

        with pytest.raises(ValueError, match="Webhook not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, None]
        _patch_session(monkeypatch, fake_session)

        with pytest.raises(ValueError, match="App trigger not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
        _patch_session(monkeypatch, fake_session)

        with pytest.raises(ValueError, match="rate limited"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
        _patch_session(monkeypatch, fake_session)

        with pytest.raises(ValueError, match="disabled"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
        _patch_session(monkeypatch, fake_session)

        with pytest.raises(ValueError, match="Workflow not found"):
            WebhookService.get_webhook_trigger_and_workflow("webhook-1")

    def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
        workflow = MagicMock()
        workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}

        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
        _patch_session(monkeypatch, fake_session)

        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow("webhook-1")

        assert got_trigger is webhook_trigger
        assert got_workflow is workflow
        assert got_node_config == {"data": {"key": "value"}}

    def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
        workflow = MagicMock()
        workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}

        fake_session = MagicMock()
        fake_session.scalar.side_effect = [webhook_trigger, workflow]
        _patch_session(monkeypatch, fake_session)

        got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
            "webhook-1",
            is_debug=True,
        )

        assert got_trigger is webhook_trigger
        assert got_workflow is workflow
        assert got_node_config == {"data": {"mode": "debug"}}


class TestWebhookServiceExtractionFallbacks:
    def test_extract_webhook_data_should_use_text_fallback_for_unknown_content_type(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        warning_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "warning", warning_mock)
        webhook_trigger = MagicMock()

        with flask_app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "application/vnd.custom"},
            data="plain content",
        ):
            result = WebhookService.extract_webhook_data(webhook_trigger)

        assert result["body"] == {"raw": "plain content"}
        warning_mock.assert_called_once()

    def test_extract_webhook_data_should_raise_for_request_too_large(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        monkeypatch.setattr(service_module.dify_config, "WEBHOOK_REQUEST_BODY_MAX_SIZE", 1)

        with flask_app.test_request_context("/webhook", method="POST", data="ab"):
            with pytest.raises(RequestEntityTooLarge):
                WebhookService.extract_webhook_data(MagicMock())

    def test_extract_octet_stream_body_should_return_none_when_empty_payload(self, flask_app: Flask) -> None:
        webhook_trigger = MagicMock()

        with flask_app.test_request_context("/webhook", method="POST", data=b""):
            body, files = WebhookService._extract_octet_stream_body(webhook_trigger)

        assert body == {"raw": None}
        assert files == {}

    def test_extract_octet_stream_body_should_return_none_when_processing_raises(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = MagicMock()
        monkeypatch.setattr(
            WebhookService, "_detect_binary_mimetype", MagicMock(return_value="application/octet-stream")
        )
        monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(side_effect=RuntimeError("boom")))

        with flask_app.test_request_context("/webhook", method="POST", data=b"abc"):
            body, files = WebhookService._extract_octet_stream_body(webhook_trigger)

        assert body == {"raw": None}
        assert files == {}

    def test_extract_text_body_should_return_empty_string_when_request_read_fails(
        self,
        flask_app: Flask,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        monkeypatch.setattr("flask.wrappers.Request.get_data", MagicMock(side_effect=RuntimeError("read error")))

        with flask_app.test_request_context("/webhook", method="POST", data="abc"):
            body, files = WebhookService._extract_text_body()

        assert body == {"raw": ""}
        assert files == {}

    def test_detect_binary_mimetype_should_fallback_when_magic_raises(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        fake_magic = MagicMock()
        fake_magic.from_buffer.side_effect = RuntimeError("magic failed")
        monkeypatch.setattr(service_module, "magic", fake_magic)

        result = WebhookService._detect_binary_mimetype(b"binary")

        assert result == "application/octet-stream"

    def test_process_file_uploads_should_use_octet_stream_fallback_when_mimetype_unknown(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
        file_obj = MagicMock()
        file_obj.to_dict.return_value = {"id": "f-1"}
        monkeypatch.setattr(WebhookService, "_create_file_from_binary", MagicMock(return_value=file_obj))
        monkeypatch.setattr(service_module.mimetypes, "guess_type", MagicMock(return_value=(None, None)))

        uploaded = MagicMock()
        uploaded.filename = "file.unknown"
        uploaded.content_type = None
        uploaded.read.return_value = b"content"

        result = WebhookService._process_file_uploads({"f": uploaded}, webhook_trigger)

        assert result == {"f": {"id": "f-1"}}

    def test_create_file_from_binary_should_call_tool_file_manager_and_file_factory(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = _workflow_trigger(created_by="user-1", tenant_id="tenant-1")
        manager = MagicMock()
        manager.create_file_by_raw.return_value = SimpleNamespace(id="tool-file-1")
        monkeypatch.setattr(service_module, "ToolFileManager", MagicMock(return_value=manager))
        expected_file = MagicMock()
        monkeypatch.setattr(service_module.file_factory, "build_from_mapping", MagicMock(return_value=expected_file))

        result = WebhookService._create_file_from_binary(b"abc", "text/plain", webhook_trigger)

        assert result is expected_file
        manager.create_file_by_raw.assert_called_once()


class TestWebhookServiceValidationAndConversion:
    @pytest.mark.parametrize(
        ("raw_value", "param_type", "expected"),
        [
            ("42", SegmentType.NUMBER, 42),
            ("3.14", SegmentType.NUMBER, 3.14),
            ("yes", SegmentType.BOOLEAN, True),
            ("no", SegmentType.BOOLEAN, False),
        ],
    )
    def test_convert_form_value_should_convert_supported_types(
        self,
        raw_value: str,
        param_type: str,
        expected: Any,
    ) -> None:
        result = WebhookService._convert_form_value("param", raw_value, param_type)
        assert result == expected

    def test_convert_form_value_should_raise_for_unsupported_type(self) -> None:
        with pytest.raises(ValueError, match="Unsupported type"):
            WebhookService._convert_form_value("p", "x", SegmentType.FILE)

    def test_validate_json_value_should_return_original_for_unmapped_supported_segment_type(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        warning_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "warning", warning_mock)

        result = WebhookService._validate_json_value("param", {"x": 1}, "unsupported-type")

        assert result == {"x": 1}
        warning_mock.assert_called_once()

    def test_validate_and_convert_value_should_wrap_conversion_errors(self) -> None:
        with pytest.raises(ValueError, match="validation failed"):
            WebhookService._validate_and_convert_value("param", "bad", SegmentType.NUMBER, is_form_data=True)

    def test_process_parameters_should_raise_when_required_parameter_missing(self) -> None:
        raw_params = {"optional": "x"}
        config = [WebhookParameter(name="required_param", type=SegmentType.STRING, required=True)]

        with pytest.raises(ValueError, match="Required parameter missing"):
            WebhookService._process_parameters(raw_params, config, is_form_data=True)

    def test_process_parameters_should_include_unconfigured_parameters(self) -> None:
        raw_params = {"known": "1", "unknown": "x"}
        config = [WebhookParameter(name="known", type=SegmentType.NUMBER, required=False)]

        result = WebhookService._process_parameters(raw_params, config, is_form_data=True)

        assert result == {"known": 1, "unknown": "x"}

    def test_process_body_parameters_should_raise_when_required_text_raw_is_missing(self) -> None:
        with pytest.raises(ValueError, match="Required body content missing"):
            WebhookService._process_body_parameters(
                raw_body={"raw": ""},
                body_configs=[WebhookBodyParameter(name="raw", required=True)],
                content_type=ContentType.TEXT,
            )

    def test_process_body_parameters_should_skip_file_config_for_multipart_form_data(self) -> None:
        raw_body = {"message": "hello", "extra": "x"}
        body_configs = [
            WebhookBodyParameter(name="upload", type=SegmentType.FILE, required=True),
            WebhookBodyParameter(name="message", type=SegmentType.STRING, required=True),
        ]

        result = WebhookService._process_body_parameters(raw_body, body_configs, ContentType.FORM_DATA)

        assert result == {"message": "hello", "extra": "x"}

    def test_validate_required_headers_should_accept_sanitized_header_names(self) -> None:
        headers = {"x_api_key": "123"}
        configs = [WebhookParameter(name="x-api-key", required=True)]

        WebhookService._validate_required_headers(headers, configs)

    def test_validate_required_headers_should_raise_when_required_header_missing(self) -> None:
        headers = {"x-other": "123"}
        configs = [WebhookParameter(name="x-api-key", required=True)]

        with pytest.raises(ValueError, match="Required header missing"):
            WebhookService._validate_required_headers(headers, configs)

    def test_validate_http_metadata_should_return_content_type_mismatch_error(self) -> None:
        webhook_data = {"method": "POST", "headers": {"Content-Type": "application/json"}}
        node_data = WebhookData(method="post", content_type=ContentType.TEXT)

        result = WebhookService._validate_http_metadata(webhook_data, node_data)

        assert result["valid"] is False
        assert "Content-type mismatch" in result["error"]

    def test_extract_content_type_should_fallback_to_lowercase_header_key(self) -> None:
        headers = {"content-type": "application/json; charset=utf-8"}
        assert WebhookService._extract_content_type(headers) == "application/json"

    def test_build_workflow_inputs_should_include_expected_keys(self) -> None:
        webhook_data = {"headers": {"h": "v"}, "query_params": {"q": 1}, "body": {"b": 2}}

        result = WebhookService.build_workflow_inputs(webhook_data)

        assert result["webhook_data"] == webhook_data
        assert result["webhook_headers"] == {"h": "v"}
        assert result["webhook_query_params"] == {"q": 1}
        assert result["webhook_body"] == {"b": 2}


class TestWebhookServiceExecutionAndSync:
    def test_trigger_workflow_execution_should_trigger_async_workflow_successfully(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")
        webhook_data = {"body": {"x": 1}}

        session = MagicMock()
        _patch_session(monkeypatch, session)

        end_user = SimpleNamespace(id="end-user-1")
        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(return_value=end_user),
        )
        quota_type = SimpleNamespace(TRIGGER=SimpleNamespace(consume=MagicMock()))
        monkeypatch.setattr(service_module, "QuotaType", quota_type)
        trigger_async_mock = MagicMock()
        monkeypatch.setattr(service_module.AsyncWorkflowService, "trigger_workflow_async", trigger_async_mock)

        WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

        trigger_async_mock.assert_called_once()

    def test_trigger_workflow_execution_should_mark_tenant_rate_limited_when_quota_exceeded(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")

        session = MagicMock()
        _patch_session(monkeypatch, session)

        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(return_value=SimpleNamespace(id="end-user-1")),
        )
        quota_type = SimpleNamespace(
            TRIGGER=SimpleNamespace(
                consume=MagicMock(side_effect=QuotaExceededError(feature="trigger", tenant_id="tenant-1", required=1))
            )
        )
        monkeypatch.setattr(service_module, "QuotaType", quota_type)
        mark_rate_limited_mock = MagicMock()
        monkeypatch.setattr(
            service_module.AppTriggerService, "mark_tenant_triggers_rate_limited", mark_rate_limited_mock
        )

        with pytest.raises(QuotaExceededError):
            WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)

        mark_rate_limited_mock.assert_called_once_with("tenant-1")

    def test_trigger_workflow_execution_should_log_and_reraise_unexpected_errors(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        webhook_trigger = _workflow_trigger(
            app_id="app-1",
            node_id="node-1",
            tenant_id="tenant-1",
            webhook_id="webhook-1",
        )
        workflow = _workflow(id="wf-1")

        session = MagicMock()
        _patch_session(monkeypatch, session)

        monkeypatch.setattr(
            service_module.EndUserService,
            "get_or_create_end_user_by_type",
            MagicMock(side_effect=RuntimeError("boom")),
        )
        logger_exception_mock = MagicMock()
        monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)

        with pytest.raises(RuntimeError, match="boom"):
            WebhookService.trigger_workflow_execution(webhook_trigger, {"body": {}}, workflow)

        logger_exception_mock.assert_called_once()

    def test_sync_webhook_relationships_should_raise_when_workflow_exceeds_node_limit(self) -> None:
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(
            walk_nodes=lambda _node_type: [
                (f"node-{i}", {}) for i in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)
            ]
        )

        with pytest.raises(ValueError, match="maximum webhook node limit"):
            WebhookService.sync_webhook_relationships(app, workflow)

    def test_sync_webhook_relationships_should_raise_when_lock_not_acquired(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(walk_nodes=lambda _node_type: [("node-1", {})])

        lock = MagicMock()
        lock.acquire.return_value = False
        monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
        monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))

        with pytest.raises(RuntimeError, match="Failed to acquire lock"):
            WebhookService.sync_webhook_relationships(app, workflow)

    def test_sync_webhook_relationships_should_create_missing_records_and_delete_stale_records(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
        workflow = _workflow(walk_nodes=lambda _node_type: [("node-new", {})])

        class _WorkflowWebhookTrigger:
            app_id = "app_id"
            tenant_id = "tenant_id"
            webhook_id = "webhook_id"
            node_id = "node_id"

            def __init__(self, app_id: str, tenant_id: str, node_id: str, webhook_id: str, created_by: str) -> None:
                self.id = None
                self.app_id = app_id
                self.tenant_id = tenant_id
                self.node_id = node_id
                self.webhook_id = webhook_id
                self.created_by = created_by

        class _Select:
            def where(self, *args: Any, **kwargs: Any) -> "_Select":
                return self

        class _Session:
            def __init__(self) -> None:
                self.added: list[Any] = []
                self.deleted: list[Any] = []
                self.commit_count = 0
                self.existing_records = [SimpleNamespace(node_id="node-stale")]

            def scalars(self, _stmt: Any) -> Any:
                return SimpleNamespace(all=lambda: self.existing_records)

            def add(self, obj: Any) -> None:
                self.added.append(obj)

            def flush(self) -> None:
                for idx, obj in enumerate(self.added, start=1):
                    if obj.id is None:
                        obj.id = f"rec-{idx}"

            def commit(self) -> None:
                self.commit_count += 1

            def delete(self, obj: Any) -> None:
                self.deleted.append(obj)

        lock = MagicMock()
        lock.acquire.return_value = True
        lock.release.return_value = None

        fake_session = _Session()

        monkeypatch.setattr(service_module, "WorkflowWebhookTrigger", _WorkflowWebhookTrigger)
        monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
        monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
redis_set_mock = MagicMock()
|
||||
redis_delete_mock = MagicMock()
|
||||
monkeypatch.setattr(service_module.redis_client, "set", redis_set_mock)
|
||||
monkeypatch.setattr(service_module.redis_client, "delete", redis_delete_mock)
|
||||
monkeypatch.setattr(WebhookService, "generate_webhook_id", MagicMock(return_value="generated-webhook-id"))
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_sync_webhook_relationships_should_log_when_lock_release_fails(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
app = _app(id="app-1", tenant_id="tenant-1", created_by="user-1")
|
||||
workflow = _workflow(walk_nodes=lambda _node_type: [])
|
||||
|
||||
class _Select:
|
||||
def where(self, *args: Any, **kwargs: Any) -> "_Select":
|
||||
return self
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt: Any) -> Any:
|
||||
return SimpleNamespace(all=lambda: [])
|
||||
|
||||
def commit(self) -> None:
|
||||
return None
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
logger_exception_mock = MagicMock()
|
||||
|
||||
monkeypatch.setattr(service_module, "select", MagicMock(return_value=_Select()))
|
||||
monkeypatch.setattr(service_module.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module.redis_client, "lock", MagicMock(return_value=lock))
|
||||
monkeypatch.setattr(service_module.logger, "exception", logger_exception_mock)
|
||||
_patch_session(monkeypatch, _Session())
|
||||
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
assert logger_exception_mock.call_count == 1
|
||||
|
||||
|
||||
class TestWebhookServiceUtilities:
|
||||
def test_generate_webhook_response_should_fallback_when_response_body_is_not_json(self) -> None:
|
||||
node_config = {"data": {"status_code": 200, "response_body": "{bad-json"}}
|
||||
|
||||
body, status = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status == 200
|
||||
assert "message" in body
|
||||
|
||||
def test_generate_webhook_id_should_return_24_character_identifier(self) -> None:
|
||||
webhook_id = WebhookService.generate_webhook_id()
|
||||
|
||||
assert isinstance(webhook_id, str)
|
||||
assert len(webhook_id) == 24
|
||||
|
||||
def test_sanitize_key_should_return_original_value_for_non_string_input(self) -> None:
|
||||
result = WebhookService._sanitize_key(123) # type: ignore[arg-type]
|
||||
assert result == 123
|
||||
262
api/tests/unit_tests/services/test_workflow_run_service.py
Normal file
@@ -0,0 +1,262 @@
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from sqlalchemy import Engine

from models import Account, App, EndUser, WorkflowRunTriggeredFrom
from services import workflow_run_service as service_module
from services.workflow_run_service import WorkflowRunService


@pytest.fixture
def repository_factory_mocks(monkeypatch: pytest.MonkeyPatch) -> tuple[MagicMock, MagicMock, Any]:
    node_repo = MagicMock()
    workflow_run_repo = MagicMock()
    factory = SimpleNamespace(
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    return node_repo, workflow_run_repo, factory


def _app_model(**kwargs: Any) -> App:
    return cast(App, SimpleNamespace(**kwargs))


def _account(**kwargs: Any) -> Account:
    return cast(Account, SimpleNamespace(**kwargs))


def _end_user(**kwargs: Any) -> EndUser:
    return cast(EndUser, SimpleNamespace(**kwargs))


class TestWorkflowRunServiceInitialization:
    def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
        self,
        monkeypatch: pytest.MonkeyPatch,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        session_factory = MagicMock(name="session_factory")
        sessionmaker_mock = MagicMock(return_value=session_factory)
        monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
        monkeypatch.setattr(service_module, "db", SimpleNamespace(engine="db-engine"))

        service = WorkflowRunService()

        sessionmaker_mock.assert_called_once_with(bind="db-engine", expire_on_commit=False)
        assert service._session_factory is session_factory

    def test___init___should_create_sessionmaker_when_engine_is_provided(
        self,
        monkeypatch: pytest.MonkeyPatch,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        class FakeEngine:
            pass

        session_factory = MagicMock(name="session_factory")
        sessionmaker_mock = MagicMock(return_value=session_factory)
        monkeypatch.setattr(service_module, "Engine", FakeEngine)
        monkeypatch.setattr(service_module, "sessionmaker", sessionmaker_mock)
        engine = cast(Engine, FakeEngine())

        service = WorkflowRunService(session_factory=engine)

        sessionmaker_mock.assert_called_once_with(bind=engine, expire_on_commit=False)
        assert service._session_factory is session_factory

    def test___init___should_keep_provided_sessionmaker_and_create_repositories(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        node_repo, workflow_run_repo, factory = repository_factory_mocks
        session_factory = MagicMock(name="session_factory")

        service = WorkflowRunService(session_factory=session_factory)

        assert service._session_factory is session_factory
        assert service._node_execution_service_repo is node_repo
        assert service._workflow_run_repo is workflow_run_repo
        factory.create_api_workflow_node_execution_repository.assert_called_once_with(session_factory)
        factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)


class TestWorkflowRunServiceQueries:
    def test_get_paginate_workflow_runs_should_forward_filters_and_parse_limit(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = MagicMock(name="pagination")
        workflow_run_repo.get_paginated_workflow_runs.return_value = expected
        args = {"limit": "7", "last_id": "last-1", "status": "succeeded"}

        result = service.get_paginate_workflow_runs(
            app_model=app_model,
            args=args,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert result is expected
        workflow_run_repo.get_paginated_workflow_runs.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            limit=7,
            last_id="last-1",
            status="succeeded",
        )

    def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_when_message_exists(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        run_with_message = SimpleNamespace(
            id="run-1",
            status="running",
            message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
        )
        run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
        pagination = SimpleNamespace(data=[run_with_message, run_without_message])
        monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))

        result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})

        assert result is pagination
        assert len(result.data) == 2
        assert result.data[0].message_id == "msg-1"
        assert result.data[0].conversation_id == "conv-1"
        assert result.data[0].status == "running"
        assert not hasattr(result.data[1], "message_id")
        assert result.data[1].id == "run-2"

    def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = MagicMock(name="workflow_run")
        workflow_run_repo.get_workflow_run_by_id.return_value = expected

        result = service.get_workflow_run(app_model=app_model, run_id="run-1")

        assert result is expected
        workflow_run_repo.get_workflow_run_by_id.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            run_id="run-1",
        )

    def test_get_workflow_runs_count_should_forward_optional_filters(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
    ) -> None:
        _, workflow_run_repo, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        app_model = _app_model(tenant_id="tenant-1", id="app-1")
        expected = {"total": 3, "succeeded": 2}
        workflow_run_repo.get_workflow_runs_count.return_value = expected

        result = service.get_workflow_runs_count(
            app_model=app_model,
            status="succeeded",
            time_range="7d",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert result == expected
        workflow_run_repo.get_workflow_runs_count.assert_called_once_with(
            tenant_id="tenant-1",
            app_id="app-1",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            status="succeeded",
            time_range="7d",
        )

    def test_get_workflow_run_node_executions_should_return_empty_list_when_run_not_found(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=None))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id="tenant-1")

        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

        assert result == []

    def test_get_workflow_run_node_executions_should_use_end_user_tenant_id(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        node_repo, _, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))

        class FakeEndUser:
            def __init__(self, tenant_id: str) -> None:
                self.tenant_id = tenant_id

        monkeypatch.setattr(service_module, "EndUser", FakeEndUser)
        user = cast(EndUser, FakeEndUser(tenant_id="tenant-end-user"))
        app_model = _app_model(id="app-1")
        expected = [SimpleNamespace(id="exec-1")]
        node_repo.get_executions_by_workflow_run.return_value = expected

        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

        assert result == expected
        node_repo.get_executions_by_workflow_run.assert_called_once_with(
            tenant_id="tenant-end-user",
            app_id="app-1",
            workflow_run_id="run-1",
        )

    def test_get_workflow_run_node_executions_should_use_account_current_tenant_id(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        node_repo, _, _ = repository_factory_mocks
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id="tenant-account")
        expected = [SimpleNamespace(id="exec-1"), SimpleNamespace(id="exec-2")]
        node_repo.get_executions_by_workflow_run.return_value = expected

        result = service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)

        assert result == expected
        node_repo.get_executions_by_workflow_run.assert_called_once_with(
            tenant_id="tenant-account",
            app_id="app-1",
            workflow_run_id="run-1",
        )

    def test_get_workflow_run_node_executions_should_raise_when_resolved_tenant_id_is_none(
        self,
        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
        monkeypatch.setattr(service, "get_workflow_run", MagicMock(return_value=SimpleNamespace(id="run-1")))
        app_model = _app_model(id="app-1")
        user = _account(current_tenant_id=None)

        with pytest.raises(ValueError, match="tenant_id cannot be None"):
            service.get_workflow_run_node_executions(app_model=app_model, run_id="run-1", user=user)
@@ -176,300 +176,3 @@ class TestWorkflowRunService:
        service = WorkflowRunService(session_factory)

        assert service._session_factory == session_factory
@@ -3,7 +3,6 @@ import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event

import pytest
@@ -223,577 +222,3 @@ def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -
    buffer_state.task_id_ready.set()
    task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
    assert task_id == expected


# === Merged from test_workflow_event_snapshot_service_additional.py ===


import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker

from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream


def _build_workflow_run_additional(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    return WorkflowRun(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"query": "hello"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=1.2,
        total_tokens=5,
        total_steps=2,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
    )


def _build_resumption_context_additional(task_id: str) -> WorkflowResumptionContext:
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=app_config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    runtime_state.outputs = {"answer": "ok"}
    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
    return WorkflowResumptionContext(
        generate_entity=wrapper,
        serialized_graph_runtime_state=runtime_state.dumps(),
    )


class _SessionContext:
    def __init__(self, session: Any) -> None:
        self._session = session

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


class _SessionMaker:
    def __init__(self, session: Any) -> None:
        self._session = session

    def __call__(self) -> _SessionContext:
        return _SessionContext(self._session)


class _SubscriptionContext:
    def __init__(self, subscription: Any) -> None:
        self._subscription = subscription

    def __enter__(self) -> Any:
        return self._subscription

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


class _Topic:
    def __init__(self, subscription: Any) -> None:
        self._subscription = subscription

    def subscribe(self) -> _SubscriptionContext:
        return _SubscriptionContext(self._subscription)


class _StaticSubscription:
    def receive(self, timeout: int = 1) -> None:
        return None


@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state
|
||||
|
||||
def get_pause_reasons(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def test_get_message_context_should_return_none_when_no_message() -> None:
    # Arrange
    session = SimpleNamespace(scalar=MagicMock(return_value=None))
    session_maker = _SessionMaker(session)

    # Act
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

    # Assert
    assert result is None


def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp() -> None:
    # Arrange
    message = SimpleNamespace(
        id="msg-1",
        conversation_id="conv-1",
        created_at=None,
        answer="answer",
    )
    session = SimpleNamespace(scalar=MagicMock(return_value=message))
    session_maker = _SessionMaker(session)

    # Act
    result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

    # Assert
    assert result is not None
    assert result.created_at == 0
    assert result.message_id == "msg-1"
    assert result.conversation_id == "conv-1"
    assert result.answer == "answer"


def test_load_resumption_context_should_return_none_when_pause_entity_missing() -> None:
    # Act
    result = service_module._load_resumption_context(None)

    # Assert
    assert result is None


def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid() -> None:
    # Arrange
    pause_entity = _PauseEntity(state=b"not-a-valid-state")

    # Act
    result = service_module._load_resumption_context(pause_entity)

    # Assert
    assert result is None


def test_load_resumption_context_should_parse_valid_state_into_context() -> None:
    # Arrange
    context = _build_resumption_context_additional(task_id="task-ctx")
    pause_entity = _PauseEntity(state=context.dumps().encode())

    # Act
    result = service_module._load_resumption_context(pause_entity)

    # Assert
    assert result is not None
    assert result.get_generate_entity().task_id == "task-ctx"


def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing() -> None:
    # Act
    result = service_module._resolve_task_id(
        resumption_context=None,
        buffer_state=None,
        workflow_run_id="run-1",
    )

    # Assert
    assert result == "run-1"


@pytest.mark.parametrize(
    ("payload", "expected"),
    [
        (b'{"event":"node_started"}', {"event": "node_started"}),
        (b"invalid-json", None),
        (b"[]", None),
    ],
)
def test_parse_event_message_should_parse_only_json_object(
    payload: bytes,
    expected: dict[str, Any] | None,
) -> None:
    # Act
    result = service_module._parse_event_message(payload)

    # Assert
    assert result == expected


def test_is_terminal_event_should_recognize_finished_and_optional_paused_events() -> None:
    # Arrange
    finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
    paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}

    # Act
    is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
    paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
    paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)

    # Assert
    assert is_finished is True
    assert paused_without_flag is False
    assert paused_with_flag is True
    assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False


def test_apply_message_context_should_update_payload_when_context_exists() -> None:
    # Arrange
    payload: dict[str, Any] = {"event": "workflow_started"}
    context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)

    # Act
    service_module._apply_message_context(payload, context)

    # Assert
    assert payload["conversation_id"] == "conv-1"
    assert payload["message_id"] == "msg-1"
    assert payload["created_at"] == 1700000000


def test_start_buffering_should_capture_task_id_and_enqueue_event() -> None:
    # Arrange
    class Subscription:
        def __init__(self) -> None:
            self._calls = 0

        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-1"}'
            return None

    subscription = Subscription()

    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    event = buffer_state.queue.get(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)

    # Assert
    assert ready is True
    assert finished is True
    assert buffer_state.task_id_hint == "task-1"
    assert event["event"] == "node_started"


def test_start_buffering_should_drop_old_event_when_queue_is_full(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    class QueueWithSingleFull:
        def __init__(self) -> None:
            self._first_put = True
            self.items: list[dict[str, Any]] = [{"event": "old"}]

        def put_nowait(self, item: dict[str, Any]) -> None:
            if self._first_put:
                self._first_put = False
                raise queue.Full
            self.items.append(item)

        def get_nowait(self) -> dict[str, Any]:
            if not self.items:
                raise queue.Empty
            return self.items.pop(0)

        def empty(self) -> bool:
            return len(self.items) == 0

    fake_queue = QueueWithSingleFull()
    monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)

    class Subscription:
        def __init__(self) -> None:
            self._calls = 0

        def receive(self, timeout: int = 1) -> bytes | None:
            self._calls += 1
            if self._calls == 1:
                return b'{"event":"node_started","task_id":"task-2"}'
            return None

    subscription = Subscription()

    # Act
    buffer_state = service_module._start_buffering(subscription)
    ready = buffer_state.task_id_ready.wait(timeout=1)
    buffer_state.stop_event.set()
    finished = buffer_state.done_event.wait(timeout=1)

    # Assert
    assert ready is True
    assert finished is True
    assert fake_queue.items[-1]["task_id"] == "task-2"


def test_start_buffering_should_set_done_event_when_subscription_raises() -> None:
    # Arrange
    class Subscription:
        def receive(self, timeout: int = 1) -> bytes | None:
            raise RuntimeError("subscription failure")

    subscription = Subscription()

    # Act
    buffer_state = service_module._start_buffering(subscription)
    finished = buffer_state.done_event.wait(timeout=1)

    # Assert
    assert finished is True


def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(
        service_module,
        "_get_message_context",
        MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
    )
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    monkeypatch.setattr(
        service_module,
        "_build_snapshot_events",
        MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
    )

    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.ADVANCED_CHAT,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )

    # Assert
    assert events[0] == StreamEvent.PING.value
    finished_event = cast(Mapping[str, Any], events[1])
    assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    assert buffer_state.stop_event.is_set() is True
    node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
    called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
    assert called_kwargs["workflow_run_id"] == "run-1"


def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))

    class AlwaysEmptyQueue:
        def empty(self) -> bool:
            return False

        def get(self, timeout: int = 1) -> None:
            raise queue.Empty

    buffer_state = BufferState(
        queue=AlwaysEmptyQueue(),  # type: ignore[arg-type]
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
    time_values = cycle([0.0, 6.0, 21.0, 26.0])
    monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))

    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
            idle_timeout=20.0,
            ping_interval=5.0,
        )
    )

    # Assert
    assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True


def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.RUNNING)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    buffer_state.done_event.set()
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))

    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )

    # Assert
    assert events == [StreamEvent.PING.value]
    assert buffer_state.stop_event.is_set() is True


def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Arrange
    workflow_run = _build_workflow_run_additional(status=WorkflowExecutionStatus.PAUSED)
    topic = _Topic(_StaticSubscription())
    workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
    node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
    factory = SimpleNamespace(
        create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
        create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
    )
    monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
    monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
    monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
    monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
    snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
    monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint="task-1",
    )
    monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))

    # Act
    events = list(
        build_workflow_event_stream(
            app_mode=AppMode.WORKFLOW,
            workflow_run=workflow_run,
            tenant_id="tenant-1",
            app_id="app-1",
            session_maker=MagicMock(),
        )
    )

    # Assert
    assert events[0] == StreamEvent.PING.value
    assert snapshot_builder.call_args.kwargs["pause_entity"] is None

@@ -0,0 +1,505 @@
import json
import queue
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from itertools import cycle
from threading import Event
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock

import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy.orm import Session, sessionmaker

from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import StreamEvent
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services import workflow_event_snapshot_service as service_module
from services.workflow_event_snapshot_service import BufferState, MessageContext, build_workflow_event_stream


def _build_workflow_run(status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING) -> WorkflowRun:
    return WorkflowRun(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"query": "hello"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=1.2,
        total_tokens=5,
        total_steps=2,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
    )


def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=app_config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    runtime_state.outputs = {"answer": "ok"}
    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
    return WorkflowResumptionContext(
        generate_entity=wrapper,
        serialized_graph_runtime_state=runtime_state.dumps(),
    )


class _SessionContext:
    def __init__(self, session: Any) -> None:
        self._session = session

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


class _SessionMaker:
    def __init__(self, session: Any) -> None:
        self._session = session

    def __call__(self) -> _SessionContext:
        return _SessionContext(self._session)


class _SubscriptionContext:
    def __init__(self, subscription: Any) -> None:
        self._subscription = subscription

    def __enter__(self) -> Any:
        return self._subscription

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        return False


class _Topic:
    def __init__(self, subscription: Any) -> None:
        self._subscription = subscription

    def subscribe(self) -> _SubscriptionContext:
        return _SubscriptionContext(self._subscription)


class _StaticSubscription:
    def receive(self, timeout: int = 1) -> None:
        return None


@dataclass(frozen=True)
class _PauseEntity(WorkflowPauseEntity):
    state: bytes

    @property
    def id(self) -> str:
        return "pause-1"

    @property
    def workflow_execution_id(self) -> str:
        return "run-1"

    @property
    def resumed_at(self) -> datetime | None:
        return None

    @property
    def paused_at(self) -> datetime:
        return datetime(2024, 1, 1, tzinfo=UTC)

    def get_state(self) -> bytes:
        return self.state

    def get_pause_reasons(self) -> list[Any]:
        return []

class TestWorkflowEventSnapshotHelpers:
    def test_get_message_context_should_return_none_when_no_message(self) -> None:
        session = SimpleNamespace(scalar=MagicMock(return_value=None))
        session_maker = _SessionMaker(session)

        result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

        assert result is None

    def test_get_message_context_should_default_created_at_to_zero_when_message_has_no_timestamp(self) -> None:
        message = SimpleNamespace(
            id="msg-1",
            conversation_id="conv-1",
            created_at=None,
            answer="answer",
        )
        session = SimpleNamespace(scalar=MagicMock(return_value=message))
        session_maker = _SessionMaker(session)

        result = service_module._get_message_context(cast(sessionmaker[Session], session_maker), "run-1")

        assert result is not None
        assert result.created_at == 0
        assert result.message_id == "msg-1"
        assert result.conversation_id == "conv-1"
        assert result.answer == "answer"

    def test_load_resumption_context_should_return_none_when_pause_entity_missing(self) -> None:
        assert service_module._load_resumption_context(None) is None

    def test_load_resumption_context_should_return_none_when_pause_entity_state_is_invalid(self) -> None:
        pause_entity = _PauseEntity(state=b"not-a-valid-state")
        assert service_module._load_resumption_context(pause_entity) is None

    def test_load_resumption_context_should_parse_valid_state_into_context(self) -> None:
        context = _build_resumption_context(task_id="task-ctx")
        pause_entity = _PauseEntity(state=context.dumps().encode())

        result = service_module._load_resumption_context(pause_entity)

        assert result is not None
        assert result.get_generate_entity().task_id == "task-ctx"

    def test_resolve_task_id_should_return_workflow_run_id_when_buffer_state_is_missing(self) -> None:
        result = service_module._resolve_task_id(
            resumption_context=None,
            buffer_state=None,
            workflow_run_id="run-1",
        )

        assert result == "run-1"

    @pytest.mark.parametrize(
        ("payload", "expected"),
        [
            (b'{"event":"node_started"}', {"event": "node_started"}),
            (b"invalid-json", None),
            (b"[]", None),
        ],
    )
    def test_parse_event_message_should_parse_only_json_object(
        self,
        payload: bytes,
        expected: dict[str, Any] | None,
    ) -> None:
        result = service_module._parse_event_message(payload)
        assert result == expected

    def test_is_terminal_event_should_recognize_finished_and_optional_paused_events(self) -> None:
        finished_event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
        paused_event = {"event": StreamEvent.WORKFLOW_PAUSED.value}

        is_finished = service_module._is_terminal_event(finished_event, include_paused=False)
        paused_without_flag = service_module._is_terminal_event(paused_event, include_paused=False)
        paused_with_flag = service_module._is_terminal_event(paused_event, include_paused=True)

        assert is_finished is True
        assert paused_without_flag is False
        assert paused_with_flag is True
        assert service_module._is_terminal_event(StreamEvent.PING.value, include_paused=True) is False

    def test_apply_message_context_should_update_payload_when_context_exists(self) -> None:
        payload: dict[str, Any] = {"event": "workflow_started"}
        context = MessageContext(conversation_id="conv-1", message_id="msg-1", created_at=1700000000)

        service_module._apply_message_context(payload, context)

        assert payload["conversation_id"] == "conv-1"
        assert payload["message_id"] == "msg-1"
        assert payload["created_at"] == 1700000000

    def test_start_buffering_should_capture_task_id_and_enqueue_event(self) -> None:
        class Subscription:
            def __init__(self) -> None:
                self._calls = 0

            def receive(self, timeout: int = 1) -> bytes | None:
                self._calls += 1
                if self._calls == 1:
                    return b'{"event":"node_started","task_id":"task-1"}'
                return None

        subscription = Subscription()

        buffer_state = service_module._start_buffering(subscription)
        ready = buffer_state.task_id_ready.wait(timeout=1)
        event = buffer_state.queue.get(timeout=1)
        buffer_state.stop_event.set()
        finished = buffer_state.done_event.wait(timeout=1)

        assert ready is True
        assert finished is True
        assert buffer_state.task_id_hint == "task-1"
        assert event["event"] == "node_started"

    def test_start_buffering_should_drop_old_event_when_queue_is_full(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        class QueueWithSingleFull:
            def __init__(self) -> None:
                self._first_put = True
                self.items: list[dict[str, Any]] = [{"event": "old"}]

            def put_nowait(self, item: dict[str, Any]) -> None:
                if self._first_put:
                    self._first_put = False
                    raise queue.Full
                self.items.append(item)

            def get_nowait(self) -> dict[str, Any]:
                if not self.items:
                    raise queue.Empty
                return self.items.pop(0)

            def empty(self) -> bool:
                return len(self.items) == 0

        fake_queue = QueueWithSingleFull()
        monkeypatch.setattr(service_module.queue, "Queue", lambda maxsize=2048: fake_queue)

        class Subscription:
            def __init__(self) -> None:
                self._calls = 0

            def receive(self, timeout: int = 1) -> bytes | None:
                self._calls += 1
                if self._calls == 1:
                    return b'{"event":"node_started","task_id":"task-2"}'
                return None

        subscription = Subscription()

        buffer_state = service_module._start_buffering(subscription)
        ready = buffer_state.task_id_ready.wait(timeout=1)
        buffer_state.stop_event.set()
        finished = buffer_state.done_event.wait(timeout=1)

        assert ready is True
        assert finished is True
        assert fake_queue.items[-1]["task_id"] == "task-2"

    def test_start_buffering_should_set_done_event_when_subscription_raises(self) -> None:
        class Subscription:
            def receive(self, timeout: int = 1) -> bytes | None:
                raise RuntimeError("subscription failure")

        subscription = Subscription()
        buffer_state = service_module._start_buffering(subscription)

        assert buffer_state.done_event.wait(timeout=1) is True

class TestBuildWorkflowEventStream:
    def test_build_workflow_event_stream_should_emit_ping_and_terminal_snapshot_event(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        topic = _Topic(_StaticSubscription())
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(
            service_module,
            "_get_message_context",
            MagicMock(return_value=MessageContext("conv-1", "msg-1", 1700000000)),
        )
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        buffer_state = BufferState(
            queue=queue.Queue(),
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
        monkeypatch.setattr(
            service_module,
            "_build_snapshot_events",
            MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value, "task_id": "task-1"}]),
        )

        events = list(
            build_workflow_event_stream(
                app_mode=AppMode.ADVANCED_CHAT,
                workflow_run=workflow_run,
                tenant_id="tenant-1",
                app_id="app-1",
                session_maker=MagicMock(),
            )
        )

        assert events[0] == StreamEvent.PING.value
        finished_event = cast(Mapping[str, Any], events[1])
        assert finished_event["event"] == StreamEvent.WORKFLOW_FINISHED.value
        assert buffer_state.stop_event.is_set() is True
        node_repo.get_execution_snapshots_by_workflow_run.assert_called_once()
        called_kwargs = node_repo.get_execution_snapshots_by_workflow_run.call_args.kwargs
        assert called_kwargs["workflow_run_id"] == "run-1"

    def test_build_workflow_event_stream_should_emit_periodic_ping_and_stop_after_idle_timeout(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
        topic = _Topic(_StaticSubscription())
        workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
        node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
        factory = SimpleNamespace(
            create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
            create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
        )
        monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
        monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
        monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
        monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
        monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))

        class AlwaysEmptyQueue:
            def empty(self) -> bool:
                return False

            def get(self, timeout: int = 1) -> None:
                raise queue.Empty

        buffer_state = BufferState(
            queue=AlwaysEmptyQueue(),  # type: ignore[arg-type]
            stop_event=Event(),
            done_event=Event(),
            task_id_ready=Event(),
            task_id_hint="task-1",
        )
        monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
        time_values = cycle([0.0, 6.0, 21.0, 26.0])
|
||||
monkeypatch.setattr(service_module.time, "time", lambda: next(time_values))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
idle_timeout=20.0,
|
||||
ping_interval=5.0,
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [StreamEvent.PING.value, StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
def test_build_workflow_event_stream_should_exit_when_buffer_done_and_empty(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock())
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", MagicMock(return_value=[]))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
buffer_state.done_event.set()
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [StreamEvent.PING.value]
|
||||
assert buffer_state.stop_event.is_set() is True
|
||||
|
||||
def test_build_workflow_event_stream_should_continue_when_pause_loading_fails(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
workflow_run = _build_workflow_run(status=WorkflowExecutionStatus.PAUSED)
|
||||
topic = _Topic(_StaticSubscription())
|
||||
workflow_run_repo = SimpleNamespace(get_workflow_pause=MagicMock(side_effect=RuntimeError("boom")))
|
||||
node_repo = SimpleNamespace(get_execution_snapshots_by_workflow_run=MagicMock(return_value=[]))
|
||||
factory = SimpleNamespace(
|
||||
create_api_workflow_run_repository=MagicMock(return_value=workflow_run_repo),
|
||||
create_api_workflow_node_execution_repository=MagicMock(return_value=node_repo),
|
||||
)
|
||||
monkeypatch.setattr(service_module, "DifyAPIRepositoryFactory", factory)
|
||||
monkeypatch.setattr(service_module.MessageGenerator, "get_response_topic", MagicMock(return_value=topic))
|
||||
monkeypatch.setattr(service_module, "_load_resumption_context", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(service_module, "_resolve_task_id", MagicMock(return_value="task-1"))
|
||||
snapshot_builder = MagicMock(return_value=[{"event": StreamEvent.WORKFLOW_FINISHED.value}])
|
||||
monkeypatch.setattr(service_module, "_build_snapshot_events", snapshot_builder)
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint="task-1",
|
||||
)
|
||||
monkeypatch.setattr(service_module, "_start_buffering", MagicMock(return_value=buffer_state))
|
||||
|
||||
events = list(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=MagicMock(),
|
||||
)
|
||||
)
|
||||
|
||||
assert events[0] == StreamEvent.PING.value
|
||||
assert snapshot_builder.call_args.kwargs["pause_entity"] is None
|
||||
26
api/uv.lock
generated
@@ -4704,21 +4704,21 @@ wheels = [
 
 [[package]]
 name = "pillow"
-version = "12.1.1"
+version = "12.2.0"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/1f/42/5c74462b4fd957fcd7b13b04fb3205ff8349236ea74c7c375766d6c82288/pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4", size = 46980264, upload-time = "2026-02-11T04:23:07.146Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/07/d3/8df65da0d4df36b094351dce696f2989bec731d4f10e743b1c5f4da4d3bf/pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052", size = 5262803, upload-time = "2026-02-11T04:20:47.653Z" },
-    { url = "https://files.pythonhosted.org/packages/d6/71/5026395b290ff404b836e636f51d7297e6c83beceaa87c592718747e670f/pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984", size = 4657601, upload-time = "2026-02-11T04:20:49.328Z" },
-    { url = "https://files.pythonhosted.org/packages/b1/2e/1001613d941c67442f745aff0f7cc66dd8df9a9c084eb497e6a543ee6f7e/pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79", size = 6234995, upload-time = "2026-02-11T04:20:51.032Z" },
-    { url = "https://files.pythonhosted.org/packages/07/26/246ab11455b2549b9233dbd44d358d033a2f780fa9007b61a913c5b2d24e/pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293", size = 8045012, upload-time = "2026-02-11T04:20:52.882Z" },
-    { url = "https://files.pythonhosted.org/packages/b2/8b/07587069c27be7535ac1fe33874e32de118fbd34e2a73b7f83436a88368c/pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397", size = 6349638, upload-time = "2026-02-11T04:20:54.444Z" },
-    { url = "https://files.pythonhosted.org/packages/ff/79/6df7b2ee763d619cda2fb4fea498e5f79d984dae304d45a8999b80d6cf5c/pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0", size = 7041540, upload-time = "2026-02-11T04:20:55.97Z" },
-    { url = "https://files.pythonhosted.org/packages/2c/5e/2ba19e7e7236d7529f4d873bdaf317a318896bac289abebd4bb00ef247f0/pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3", size = 6462613, upload-time = "2026-02-11T04:20:57.542Z" },
-    { url = "https://files.pythonhosted.org/packages/03/03/31216ec124bb5c3dacd74ce8efff4cc7f52643653bad4825f8f08c697743/pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35", size = 7166745, upload-time = "2026-02-11T04:20:59.196Z" },
-    { url = "https://files.pythonhosted.org/packages/1f/e7/7c4552d80052337eb28653b617eafdef39adfb137c49dd7e831b8dc13bc5/pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a", size = 6328823, upload-time = "2026-02-11T04:21:01.385Z" },
-    { url = "https://files.pythonhosted.org/packages/3d/17/688626d192d7261bbbf98846fc98995726bddc2c945344b65bec3a29d731/pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6", size = 7033367, upload-time = "2026-02-11T04:21:03.536Z" },
-    { url = "https://files.pythonhosted.org/packages/ed/fe/a0ef1f73f939b0eca03ee2c108d0043a87468664770612602c63266a43c4/pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523", size = 2453811, upload-time = "2026-02-11T04:21:05.116Z" },
+    { url = "https://files.pythonhosted.org/packages/58/be/7482c8a5ebebbc6470b3eb791812fff7d5e0216c2be3827b30b8bb6603ed/pillow-12.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2d192a155bbcec180f8564f693e6fd9bccff5a7af9b32e2e4bf8c9c69dbad6b5", size = 5308279, upload-time = "2026-04-01T14:43:13.246Z" },
+    { url = "https://files.pythonhosted.org/packages/d8/95/0a351b9289c2b5cbde0bacd4a83ebc44023e835490a727b2a3bd60ddc0f4/pillow-12.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3f40b3c5a968281fd507d519e444c35f0ff171237f4fdde090dd60699458421", size = 4695490, upload-time = "2026-04-01T14:43:15.584Z" },
+    { url = "https://files.pythonhosted.org/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987", size = 6284462, upload-time = "2026-04-01T14:43:18.268Z" },
+    { url = "https://files.pythonhosted.org/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76", size = 8094744, upload-time = "2026-04-01T14:43:20.716Z" },
+    { url = "https://files.pythonhosted.org/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005", size = 6398371, upload-time = "2026-04-01T14:43:23.443Z" },
+    { url = "https://files.pythonhosted.org/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780", size = 7087215, upload-time = "2026-04-01T14:43:26.758Z" },
+    { url = "https://files.pythonhosted.org/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5", size = 6509783, upload-time = "2026-04-01T14:43:29.56Z" },
+    { url = "https://files.pythonhosted.org/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5", size = 7212112, upload-time = "2026-04-01T14:43:32.091Z" },
+    { url = "https://files.pythonhosted.org/packages/be/42/025cfe05d1be22dbfdb4f264fe9de1ccda83f66e4fc3aac94748e784af04/pillow-12.2.0-cp312-cp312-win32.whl", hash = "sha256:58f62cc0f00fd29e64b29f4fd923ffdb3859c9f9e6105bfc37ba1d08994e8940", size = 6378489, upload-time = "2026-04-01T14:43:34.601Z" },
+    { url = "https://files.pythonhosted.org/packages/5d/7b/25a221d2c761c6a8ae21bfa3874988ff2583e19cf8a27bf2fee358df7942/pillow-12.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f84204dee22a783350679a0333981df803dac21a0190d706a50475e361c93f5", size = 7084129, upload-time = "2026-04-01T14:43:37.213Z" },
+    { url = "https://files.pythonhosted.org/packages/10/e1/542a474affab20fd4a0f1836cb234e8493519da6b76899e30bcc5d990b8b/pillow-12.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:af73337013e0b3b46f175e79492d96845b16126ddf79c438d7ea7ff27783a414", size = 2463612, upload-time = "2026-04-01T14:43:39.421Z" },
 ]
 
 [[package]]
129
e2e/AGENTS.md
@@ -165,3 +165,132 @@ Open the HTML report locally with:

```bash
open cucumber-report/report.html
```

## Writing new scenarios

### Workflow

1. Create a `.feature` file under `features/<capability>/`
2. Add step definitions under `features/step-definitions/<capability>/`
3. Reuse existing steps from `common/` and other definition files before writing new ones
4. Run with `pnpm -C e2e e2e -- --tags @your-tag` to verify
5. Run `pnpm -C e2e check` before committing

### Feature file conventions

Tag every feature or scenario with a capability tag. Add auth tags only when they clarify intent or change the browser session behavior:

```gherkin
@datasets @authenticated
Feature: Create dataset
  Scenario: Create a new empty dataset
    Given I am signed in as the default E2E admin
    When I open the datasets page
    ...
```

- Capability tags (`@apps`, `@auth`, `@datasets`, …) group related scenarios for selective runs
- Auth/session tags:
  - default behavior — scenarios run with the shared authenticated storageState unless marked otherwise
  - `@unauthenticated` — uses a clean BrowserContext with no cookies or storage
  - `@authenticated` — optional intent tag for readability or selective runs; it does not currently change hook behavior on its own
  - `@fresh` — only runs in `e2e:full` mode (requires an uninitialized instance)
  - `@skip` — excluded from all runs
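These tags can also be combined into a Cucumber tag expression for selective runs. For example, using the runner command shown under Workflow (the `and`/`not` operators follow cucumber-js tag-expression syntax):

```shell
# Run all dataset scenarios except any marked @skip
pnpm -C e2e e2e -- --tags "@datasets and not @skip"
```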
Keep scenarios short and declarative. Each step should describe **what** the user does, not **how** the UI works.

### Step definition conventions

```typescript
import { When, Then } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../../support/world'

When('I open the datasets page', async function (this: DifyWorld) {
  await this.getPage().goto('/datasets')
})
```

Rules:

- Always type `this` as `DifyWorld` for proper context access
- Use `async function` (not arrow functions — Cucumber binds `this`)
- One step = one user-visible action or one assertion
- Keep steps stateless across scenarios; use `DifyWorld` properties for in-scenario state

### Locator priority

Follow the Playwright recommended locator strategy, in order of preference:

| Priority | Locator            | Example                                   | When to use                               |
| -------- | ------------------ | ----------------------------------------- | ----------------------------------------- |
| 1        | `getByRole`        | `getByRole('button', { name: 'Create' })` | Default choice — accessible and resilient |
| 2        | `getByLabel`       | `getByLabel('App name')`                  | Form inputs with visible labels           |
| 3        | `getByPlaceholder` | `getByPlaceholder('Enter name')`          | Inputs without visible labels             |
| 4        | `getByText`        | `getByText('Welcome')`                    | Static text content                       |
| 5        | `getByTestId`      | `getByTestId('workflow-canvas')`          | Only when no semantic locator works       |

Avoid raw CSS/XPath selectors. They break when the DOM structure changes.

### Assertions

Use `@playwright/test` `expect` — it auto-waits and retries until the condition is met or the timeout expires:

```typescript
// URL assertion
await expect(page).toHaveURL(/\/datasets\/[a-f0-9-]+\/documents/)

// Element visibility
await expect(page.getByRole('button', { name: 'Save' })).toBeVisible()

// Element state
await expect(page.getByRole('button', { name: 'Submit' })).toBeEnabled()

// Negation
await expect(page.getByText('Loading')).not.toBeVisible()
```

Do not use manual `waitForTimeout` or polling loops. If you need a longer wait for a specific assertion, pass `{ timeout: 30_000 }` to the assertion.

### Cucumber expressions

Use Cucumber expression parameter types to extract values from Gherkin steps:

| Type       | Pattern       | Example step                       |
| ---------- | ------------- | ---------------------------------- |
| `{string}` | Quoted string | `I select the "Workflow" app type` |
| `{int}`    | Integer       | `I should see {int} items`         |
| `{float}`  | Decimal       | `the progress is {float} percent`  |
| `{word}`   | Single word   | `I click the {word} tab`           |

Prefer `{string}` for UI labels, names, and text content — it maps naturally to Gherkin's quoted values.
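To build intuition for how these parameters match, here is a simplified sketch. It is not Cucumber's real matcher (the `@cucumber/cucumber-expressions` package also handles escaping, optionals, and custom parameter types); it only illustrates that `{string}` and `{int}` behave like regex capture groups:

```typescript
// Simplified illustration of Cucumber expression matching.
// NOT the real implementation - only {string} and {int} are handled here.
function toRegex(expression: string): RegExp {
  const pattern = expression
    .replace(/\{string\}/g, '"([^"]*)"') // {string} captures a quoted value
    .replace(/\{int\}/g, '(-?\\d+)') // {int} captures an integer
  return new RegExp(`^${pattern}$`)
}

const match = 'I select the "Workflow" app type'.match(
  toRegex('I select the {string} app type'),
)
console.log(match?.[1]) // → Workflow
```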
### Scoping locators

When the page has multiple similar elements, scope locators to a container:

```typescript
When('I fill in the app name in the dialog', async function (this: DifyWorld) {
  const dialog = this.getPage().getByRole('dialog')
  await dialog.getByPlaceholder('Give your app a name').fill('My App')
})
```

### Failure diagnostics

The `After` hook automatically captures the following on failure:

- Full-page screenshot (PNG)
- Page HTML dump
- Console errors and page errors

Artifacts are saved to `cucumber-report/artifacts/` and attached to the HTML report. No extra code is needed in step definitions.

## Reusing existing steps

Before writing a new step definition, inspect the existing step definition files first. Reuse a matching step when the wording and behavior already fit, and add a new step only when the scenario needs a genuinely new user action or assertion. Steps in `common/` are designed for broad reuse across all features.

Key locations to browse:

- `features/step-definitions/common/` — auth guards and navigation assertions shared by all features
- `features/step-definitions/<capability>/` — domain-specific steps scoped to a single feature area
Some files were not shown because too many files have changed in this diff