Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox

# Conflicts:
#	api/core/app/apps/workflow/app_generator.py
This commit is contained in:
yyh
2026-01-19 16:32:43 +08:00
81 changed files with 8164 additions and 1193 deletions

View File

@@ -71,6 +71,8 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
# Initialize Flask context capture for workflow execution
from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@@ -100,6 +102,8 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
)
init_flask_context()
extensions = [
ext_timezone,
ext_logging,

View File

@@ -862,8 +862,27 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option(
"--before-days",
"--days",
default=30,
show_default=True,
type=click.IntRange(min=0),
help="Delete workflow runs created before N days ago.",
)
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
"--from-days-ago",
default=None,
type=click.IntRange(min=0),
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
)
@click.option(
"--to-days-ago",
default=None,
type=click.IntRange(min=0),
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
)
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
@@ -882,8 +901,10 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
days: int,
before_days: int,
batch_size: int,
from_days_ago: int | None,
to_days_ago: int | None,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
@@ -894,11 +915,24 @@ def clean_workflow_runs(
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
if (from_days_ago is None) ^ (to_days_ago is None):
raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
if from_days_ago is not None and to_days_ago is not None:
if start_from or end_before:
raise click.UsageError("Choose either day offsets or explicit dates, not both.")
if from_days_ago <= to_days_ago:
raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
now = datetime.datetime.now()
start_from = now - datetime.timedelta(days=from_days_ago)
end_before = now - datetime.timedelta(days=to_days_ago)
before_days = 0
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=days,
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,

74
api/context/__init__.py Normal file
View File

@@ -0,0 +1,74 @@
"""
Core Context - Framework-agnostic context management.
This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.
This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""
import contextvars
from collections.abc import Callable
from core.workflow.context.execution_context import (
ExecutionContext,
IExecutionContext,
NullAppContext,
)
# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""
Register a context capture function.
This should be called by framework-specific modules (e.g., Flask)
during application initialization.
Args:
capturer: Function that captures current context and returns IExecutionContext
"""
global _capturer
_capturer = capturer
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context.
This function uses the registered context capturer. If no capturer
is registered, it returns a minimal context with only contextvars
(suitable for non-framework environments like tests or standalone scripts).
Returns:
IExecutionContext with captured context
"""
if _capturer is None:
# No framework registered - return minimal context
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
def reset_context_provider() -> None:
"""
Reset the context capturer.
This is primarily useful for testing to ensure a clean state.
"""
global _capturer
_capturer = None
__all__ = [
"capture_current_context",
"register_context_capturer",
"reset_context_provider",
]

View File

@@ -0,0 +1,198 @@
"""
Flask App Context - Flask implementation of AppContext interface.
"""
import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final
from flask import Flask, current_app, g
from context import register_context_capturer
from core.workflow.context.execution_context import (
AppContext,
IExecutionContext,
)
@final
class FlaskAppContext(AppContext):
"""
Flask implementation of AppContext.
This adapts Flask's app context to the AppContext interface.
"""
def __init__(self, flask_app: Flask) -> None:
"""
Initialize Flask app context.
Args:
flask_app: The Flask application instance
"""
self._flask_app = flask_app
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():
yield
@property
def flask_app(self) -> Flask:
"""Get the underlying Flask app instance."""
return self._flask_app
def capture_flask_context(user: Any = None) -> IExecutionContext:
"""
Capture current Flask execution context.
This function captures the Flask app context and contextvars from the
current environment. It should be called from within a Flask request or
app context.
Args:
user: Optional user object to include in context
Returns:
IExecutionContext with captured Flask context
Raises:
RuntimeError: If called outside Flask context
"""
# Get Flask app instance
flask_app = current_app._get_current_object() # type: ignore
# Save current user if available
saved_user = user
if saved_user is None:
# Check for user in g (flask-login)
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Capture contextvars
context_vars = contextvars.copy_context()
return FlaskExecutionContext(
flask_app=flask_app,
context_vars=context_vars,
user=saved_user,
)
@final
class FlaskExecutionContext:
"""
Flask-specific execution context.
This is a specialized version of ExecutionContext that includes Flask app
context. It provides the same interface as ExecutionContext but with
Flask-specific implementation.
"""
def __init__(
self,
flask_app: Flask,
context_vars: contextvars.Context,
user: Any = None,
) -> None:
"""
Initialize Flask execution context.
Args:
flask_app: Flask application instance
context_vars: Python contextvars
user: Optional user object
"""
self._app_context = FlaskAppContext(flask_app)
self._context_vars = context_vars
self._user = user
self._flask_app = flask_app
@property
def app_context(self) -> FlaskAppContext:
"""Get Flask app context."""
return self._app_context
@property
def context_vars(self) -> contextvars.Context:
"""Get context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
return self._user
def __enter__(self) -> "FlaskExecutionContext":
"""Enter the Flask execution context."""
# Restore context variables
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
self._cm = self._app_context.enter()
self._cm.__enter__()
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
return self
def __exit__(self, *args: Any) -> None:
"""Exit the Flask execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter Flask execution context as context manager."""
# Restore context variables
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with self._flask_app.app_context():
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
yield
def init_flask_context() -> None:
"""
Initialize Flask context capture by registering the capturer.
This function should be called during Flask application initialization
to register the Flask-specific context capturer with the core context module.
Example:
app = Flask(__name__)
init_flask_context() # Register Flask context capturer
Note:
This function does not need the app instance as it uses Flask's
`current_app` to get the app when capturing context.
"""
register_context_capturer(capture_flask_context)

View File

@@ -1,4 +1,3 @@
import re
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -68,48 +67,6 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
r"<script[^>]*>.*?</script>", # Script tags
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
r"javascript:", # JavaScript protocol
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
r"<embed[^>]*>", # Embed tags (self-closing)
r"<link[^>]*>", # Link tags with javascript
]
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
"""
Validate that a string value doesn't contain potential XSS payloads.
Args:
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The original value if safe
Raises:
ValueError: If the value contains XSS patterns
"""
if value is None:
return None
value_lower = value.lower()
for pattern in _XSS_PATTERNS:
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
raise ValueError(
f"{field_name} contains invalid characters or patterns. "
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
)
return value
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@@ -118,11 +75,6 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@@ -133,11 +85,6 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@@ -146,11 +93,6 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")

View File

@@ -69,6 +69,13 @@ class ActivateCheckApi(Resource):
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
# Check workspace permission
if tenant:
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None

View File

@@ -107,6 +107,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(inviter.current_tenant.id)
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL

View File

@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -288,3 +290,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/current/permission")
class WorkspacePermissionApi(Resource):
"""Get workspace permissions for the current workspace."""
@setup_required
@login_required
@account_initialization_required
@only_edition_enterprise
def get(self):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")
# Get workspace permissions from enterprise service
permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
return {
"workspace_id": permission.workspace_id,
"allow_member_invite": permission.allow_member_invite,
"allow_owner_transfer": permission.allow_owner_transfer,
}, 200

View File

@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
from libs.workspace_permission import check_workspace_owner_transfer_permission
# otherwise, return 403
abort(403)
_, current_tenant_id = current_account_with_tenant()
# Check both billing/plan level and workspace policy level permissions
check_workspace_owner_transfer_permission(current_tenant_id)
return view(*args, **kwargs)
return decorated

View File

@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -24,6 +24,7 @@ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTas
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.sandbox_layer import SandboxLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -479,7 +480,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,

View File

@@ -320,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
raise InvokeRateLimitError(description=args.get("description"))
raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
raise InvokeAuthorizationError(description=args.get("description"))
raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
raise InvokeBadRequestError(description=args.get("description"))
raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
raise InvokeConnectionError(description=args.get("description"))
raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
raise InvokeServerUnavailableError(description=args.get("description"))
raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -339,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
raise TriggerPluginInvokeError(description=error_object.get("description"))
raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("description"))
raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:

View File

@@ -5,7 +5,6 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from core.db.session_factory import session_factory
@@ -29,6 +28,21 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
def _try_resolve_user_from_request() -> Account | EndUser | None:
"""
Try to resolve user from Flask request context.
Returns None if not in a request context or if user is not available.
"""
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
# Use _get_current_object() to dereference the proxy
user = getattr(current_user, "_get_current_object", lambda: current_user)()
# Check if we got a valid user object
if user is not None and hasattr(user, "id"):
return user
return None
class WorkflowTool(Tool):
"""
Workflow tool.
@@ -209,21 +223,13 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)
# Try to resolve user from request context first
user = _try_resolve_user_from_request()
if user is not None:
return user
def _resolve_user_from_request(self) -> Account | EndUser | None:
"""
Resolve user from Flask request context.
"""
try:
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
return getattr(current_user, "_get_current_object", lambda: current_user)()
except Exception as e:
logger.warning("Failed to resolve user from request context: %s", e)
return None
# Fall back to database resolution
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""

View File

@@ -0,0 +1,22 @@
"""
Execution Context - Context management for workflow execution.
This package provides Flask-independent context management for workflow
execution in multi-threaded environments.
"""
from core.workflow.context.execution_context import (
AppContext,
ExecutionContext,
IExecutionContext,
NullAppContext,
capture_current_context,
)
__all__ = [
"AppContext",
"ExecutionContext",
"IExecutionContext",
"NullAppContext",
"capture_current_context",
]

View File

@@ -0,0 +1,216 @@
"""
Execution Context - Abstracted context management for workflow execution.
"""
import contextvars
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
class AppContext(ABC):
"""
Abstract application context interface.
This abstraction allows workflow execution to work with or without Flask
by providing a common interface for application context management.
"""
@abstractmethod
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
pass
@abstractmethod
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name (e.g., 'db', 'cache')."""
pass
@abstractmethod
def enter(self) -> AbstractContextManager[None]:
"""Enter the application context."""
pass
@runtime_checkable
class IExecutionContext(Protocol):
"""
Protocol for execution context.
This protocol defines the interface that all execution contexts must implement,
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
"""
def __enter__(self) -> "IExecutionContext":
"""Enter the execution context."""
...
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
...
@property
def user(self) -> Any:
"""Get user object."""
...
@final
class ExecutionContext:
"""
Execution context for workflow execution in worker threads.
This class encapsulates all context needed for workflow execution:
- Application context (Flask app or standalone)
- Context variables for Python contextvars
- User information (optional)
It is designed to be serializable and passable to worker threads.
"""
def __init__(
self,
app_context: AppContext | None = None,
context_vars: contextvars.Context | None = None,
user: Any = None,
) -> None:
"""
Initialize execution context.
Args:
app_context: Application context (Flask or standalone)
context_vars: Python contextvars to preserve
user: User object (optional)
"""
self._app_context = app_context
self._context_vars = context_vars
self._user = user
@property
def app_context(self) -> AppContext | None:
"""Get application context."""
return self._app_context
@property
def context_vars(self) -> contextvars.Context | None:
"""Get context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
return self._user
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""
Enter this execution context.
This is a convenience method that creates a context manager.
"""
# Restore context variables if provided
if self._context_vars:
for var, val in self._context_vars.items():
var.set(val)
# Enter app context if available
if self._app_context is not None:
with self._app_context.enter():
yield
else:
yield
def __enter__(self) -> "ExecutionContext":
"""Enter the execution context."""
self._cm = self.enter()
self._cm.__enter__()
return self
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
class NullAppContext(AppContext):
"""
Null implementation of AppContext for non-Flask environments.
This is used when running without Flask (e.g., in tests or standalone mode).
"""
def __init__(self, config: dict[str, Any] | None = None) -> None:
"""
Initialize null app context.
Args:
config: Optional configuration dictionary
"""
self._config = config or {}
self._extensions: dict[str, Any] = {}
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
def set_extension(self, name: str, extension: Any) -> None:
"""Set extension by name."""
self._extensions[name] = extension
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield
class ExecutionContextBuilder:
"""
Builder for creating ExecutionContext instances.
This provides a fluent API for building execution contexts.
"""
def __init__(self) -> None:
self._app_context: AppContext | None = None
self._context_vars: contextvars.Context | None = None
self._user: Any = None
def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
"""Set application context."""
self._app_context = app_context
return self
def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
"""Set context variables."""
self._context_vars = context_vars
return self
def with_user(self, user: Any) -> "ExecutionContextBuilder":
"""Set user."""
self._user = user
return self
def build(self) -> ExecutionContext:
"""Build the execution context."""
return ExecutionContext(
app_context=self._app_context,
context_vars=self._context_vars,
user=self._user,
)
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context from the calling environment.
Returns:
IExecutionContext with captured context
"""
from context import capture_current_context
return capture_current_context()

View File

@@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.
from __future__ import annotations
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.context import capture_current_context
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -159,17 +157,8 @@ class GraphEngine:
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Capture execution context for worker threads
execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@@ -177,8 +166,7 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,

View File

@@ -5,26 +5,27 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from typing import TYPE_CHECKING, final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
if TYPE_CHECKING:
pass
@final
class Worker(threading.Thread):
@@ -44,8 +45,7 @@ class Worker(threading.Thread):
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -56,19 +56,17 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._last_task_time = time.time()
self._execution_context = execution_context
self._stop_event = stop_event
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
@@ -135,11 +133,9 @@ class Worker(threading.Thread):
error: Exception | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node with preserved context if execution context is provided
if self._execution_context is not None:
with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()

View File

@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
from typing import final
from configs import dify_config
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
@@ -42,8 +38,7 @@ class WorkerPool:
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@@ -57,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@@ -67,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@@ -152,8 +145,7 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
execution_context=self._execution_context,
stop_event=self._stop_event,
)

View File

@@ -1,11 +1,9 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
flask_app=current_app._get_current_object(), # type: ignore
context_vars=contextvars.copy_context(),
execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
def _capture_execution_context(self) -> "IExecutionContext":
"""Capture current execution context for parallel iterations."""
from core.workflow.context import capture_current_context
return capture_current_context()
def _handle_iteration_success(
self,
started_at: datetime,

View File

@@ -0,0 +1,74 @@
"""
Workspace permission helper functions.
These helpers check both billing/plan level and workspace-specific policy level permissions.
Checks are performed at two levels:
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
"""
import logging
from werkzeug.exceptions import Forbidden
from configs import dify_config
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
def check_workspace_member_invite_permission(workspace_id: str) -> None:
"""
Check if workspace allows member invitations at both billing and policy levels.
Checks performed:
1. Billing/plan level - For future expansion (currently no plan-level restriction)
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits member invitations
"""
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_member_invite:
raise Forbidden("Workspace policy prohibits member invitations")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace invite permission for %s", workspace_id)
def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
"""
Check if workspace allows owner transfer at both billing and policy levels.
Checks performed:
1. Billing/plan level - SANDBOX plan blocks owner transfer
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits ownership transfer
"""
features = FeatureService.get_features(workspace_id)
if not features.is_allow_transfer_workspace:
raise Forbidden("Your current plan does not allow workspace ownership transfer")
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_owner_transfer:
raise Forbidden("Workspace policy prohibits ownership transfer")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace transfer permission for %s", workspace_id)

View File

@@ -0,0 +1,35 @@
"""change workflow node execution workflow_run index
Revision ID: 288345cd01d1
Revises: 3334862ee907
Create Date: 2026-01-16 17:15:00.000000
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "288345cd01d1"
down_revision = "3334862ee907"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
batch_op.drop_index("workflow_node_execution_workflow_run_idx")
batch_op.create_index(
"workflow_node_execution_workflow_run_id_idx",
["workflow_run_id"],
unique=False,
)
def downgrade():
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
batch_op.create_index(
"workflow_node_execution_workflow_run_idx",
["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
unique=False,
)

View File

@@ -820,11 +820,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(

View File

@@ -13,6 +13,8 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
from sqlalchemy.orm import Session
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel
@@ -130,6 +132,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Count node executions and offloads for the given workflow run ids.
"""
...
def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Delete node executions and offloads for the given workflow run ids.
"""
...
def delete_executions_by_app(
self,
tenant_id: str,

View File

@@ -7,17 +7,15 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
from collections.abc import Sequence
from datetime import datetime
from typing import TypedDict, cast
from typing import cast
from sqlalchemy import asc, delete, desc, func, select, tuple_
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import (
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -49,26 +47,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
"""
self._session_maker = session_maker
@staticmethod
def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str:
"""
Map workflow run triggered_from values to workflow node execution triggered_from values.
"""
if triggered_from in {
WorkflowRunTriggeredFrom.APP_RUN.value,
WorkflowRunTriggeredFrom.DEBUGGING.value,
WorkflowRunTriggeredFrom.SCHEDULE.value,
WorkflowRunTriggeredFrom.PLUGIN.value,
WorkflowRunTriggeredFrom.WEBHOOK.value,
}:
return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
if triggered_from in {
WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
}:
return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value
return ""
def get_node_last_execution(
self,
tenant_id: str,
@@ -316,51 +294,16 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
session.commit()
return result.rowcount
class RunContext(TypedDict):
run_id: str
tenant_id: str
app_id: str
workflow_id: str
triggered_from: str
@staticmethod
def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Delete node executions (and offloads) for the given workflow runs using indexed columns.
Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
by filtering on those columns with tuple IN.
Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
"""
if not runs:
if not run_ids:
return 0, 0
tuple_values = [
(
run["tenant_id"],
run["app_id"],
run["workflow_id"],
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
run["triggered_from"]
),
run["run_id"],
)
for run in runs
]
node_execution_ids = session.scalars(
select(WorkflowNodeExecutionModel.id).where(
tuple_(
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.workflow_id,
WorkflowNodeExecutionModel.triggered_from,
WorkflowNodeExecutionModel.workflow_run_id,
).in_(tuple_values)
)
).all()
if not node_execution_ids:
return 0, 0
run_ids = list(run_ids)
run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
offloads_deleted = (
cast(
@@ -377,55 +320,32 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
node_executions_deleted = (
cast(
CursorResult,
session.execute(
delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
),
session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
).rowcount
or 0
)
return node_executions_deleted, offloads_deleted
@staticmethod
def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Count node executions (and offloads) for the given workflow runs using indexed columns.
Count node executions (and offloads) for the given workflow runs using workflow_run_id.
"""
if not runs:
if not run_ids:
return 0, 0
tuple_values = [
(
run["tenant_id"],
run["app_id"],
run["workflow_id"],
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
run["triggered_from"]
),
run["run_id"],
)
for run in runs
]
tuple_filter = tuple_(
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.workflow_id,
WorkflowNodeExecutionModel.triggered_from,
WorkflowNodeExecutionModel.workflow_run_id,
).in_(tuple_values)
run_ids = list(run_ids)
run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
node_executions_count = (
session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0
session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
)
node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
offloads_count = (
session.scalar(
select(func.count())
.select_from(WorkflowNodeExecutionOffload)
.join(
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id,
)
.where(tuple_filter)
.where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
)
or 0
)

View File

@@ -1381,6 +1381,11 @@ class RegisterService:
normalized_email = email.lower()
"""Invite new member"""
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

View File

@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
class WorkspacePermission(BaseModel):
workspace_id: str = Field(
description="The ID of the workspace.",
alias="workspaceId",
)
allow_member_invite: bool = Field(
description="Whether to allow members to invite new members to the workspace.",
default=False,
alias="allowMemberInvite",
)
allow_owner_transfer: bool = Field(
description="Whether to allow owners to transfer ownership of the workspace.",
default=False,
alias="allowOwnerTransfer",
)
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WorkspacePermissionService:
@classmethod
def get_permission(cls, workspace_id: str):
if not workspace_id:
raise ValueError("workspace_id must be provided.")
data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
if not data or "permission" not in data:
raise ValueError("No data found.")
return WorkspacePermission.model_validate(data["permission"])
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):

View File

@@ -10,9 +10,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService, SubscriptionPlan
@@ -92,9 +90,12 @@ class WorkflowRunCleanup:
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
)
click.echo(
click.style(
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
skipped_message,
fg="yellow",
)
)
@@ -255,21 +256,6 @@ class WorkflowRunCleanup:
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
return trigger_repo.count_by_run_ids(run_ids)
@staticmethod
def _build_run_contexts(
runs: Sequence[WorkflowRun],
) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]:
return [
{
"run_id": run.id,
"tenant_id": run.tenant_id,
"app_id": run.app_id,
"workflow_id": run.workflow_id,
"triggered_from": run.triggered_from,
}
for run in runs
]
@staticmethod
def _empty_related_counts() -> dict[str, int]:
return {
@@ -293,9 +279,15 @@ class WorkflowRunCleanup:
)
def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_contexts = self._build_run_contexts(runs)
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts)
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
)
return repo.count_by_runs(session, run_ids)
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_contexts = self._build_run_contexts(runs)
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts)
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
)
return repo.delete_by_runs(session, run_ids)

View File

@@ -83,7 +83,30 @@
<p class="content1">Dear {{ to }},</p>
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
<p class="content2">Click the button below to log in to Dify and join the workspace.</p>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Login Here</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p class="content2">Best regards,</p>
<p class="content2">Dify Team</p>
</div>

View File

@@ -83,7 +83,30 @@
<p class="content1">尊敬的 {{ to }}</p>
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p class="content2">点击下方按钮即可登录 Dify 并且加入空间。</p>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">在此登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p class="content2">此致,</p>
<p class="content2">Dify 团队</p>
</div>

View File

@@ -115,7 +115,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here: </p>
<a href="{{ login_url }}" class="button">Log In</a>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Log In</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<p class="description">
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
class="reset-btn">Reset Password</a>

View File

@@ -115,7 +115,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录: </p>
<a href="{{ login_url }}" class="button">登录</a>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<p class="description">
如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
</p>

View File

@@ -92,12 +92,34 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.</p>
<p class="content2">Click the button below to log in to {{application_title}} and join the workspace.</p>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none"
class="button" href="{{ url }}">Login Here</a></p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Login Here</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p class="content2">Best regards,</p>
<p class="content2">{{application_title}} Team</p>
</div>
</div>
</body>
</html>
</html>

View File

@@ -81,7 +81,30 @@
<p class="content1">尊敬的 {{ to }}</p>
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p class="content2">点击下方按钮即可登录 {{application_title}} 并且加入空间。</p>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">在此登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p class="content2">此致,</p>
<p class="content2">{{application_title}} 团队</p>
</div>

View File

@@ -111,7 +111,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here: </p>
<a href="{{ login_url }}" class="button">Log In</a>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Log In</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<p class="description">
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
class="reset-btn">Reset Password</a>

View File

@@ -111,7 +111,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录: </p>
<a href="{{ login_url }}" class="button">登录</a>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<p class="description">
如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
</p>

View File

@@ -1,254 +0,0 @@
"""
Unit tests for XSS prevention in App payloads.
This test module validates that HTML tags, JavaScript, and other potentially
dangerous content are rejected in App names and descriptions.
"""
import pytest
from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload
class TestXSSPreventionUnit:
"""Unit tests for XSS prevention in App payloads."""
def test_create_app_valid_names(self):
"""Test CreateAppPayload with valid app names."""
# Normal app names should be valid
valid_names = [
"My App",
"Test App 123",
"App with - dash",
"App with _ underscore",
"App with + plus",
"App with () parentheses",
"App with [] brackets",
"App with {} braces",
"App with ! exclamation",
"App with @ at",
"App with # hash",
"App with $ dollar",
"App with % percent",
"App with ^ caret",
"App with & ampersand",
"App with * asterisk",
"Unicode: 测试应用",
"Emoji: 🤖",
"Mixed: Test 测试 123",
]
for name in valid_names:
payload = CreateAppPayload(
name=name,
mode="chat",
)
assert payload.name == name
def test_create_app_xss_script_tags(self):
"""Test CreateAppPayload rejects script tags."""
xss_payloads = [
"<script>alert(document.cookie)</script>",
"<Script>alert(1)</Script>",
"<SCRIPT>alert('XSS')</SCRIPT>",
"<script>alert(String.fromCharCode(88,83,83))</script>",
"<script src='evil.js'></script>",
"<script>document.location='http://evil.com'</script>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_iframe_tags(self):
"""Test CreateAppPayload rejects iframe tags."""
xss_payloads = [
"<iframe src='evil.com'></iframe>",
"<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
"<IFRAME src='javascript:alert(1)'></iframe>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_javascript_protocol(self):
"""Test CreateAppPayload rejects javascript: protocol."""
xss_payloads = [
"javascript:alert(1)",
"JAVASCRIPT:alert(1)",
"JavaScript:alert(document.cookie)",
"javascript:void(0)",
"javascript://comment%0Aalert(1)",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_svg_onload(self):
"""Test CreateAppPayload rejects SVG with onload."""
xss_payloads = [
"<svg onload=alert(1)>",
"<SVG ONLOAD=alert(1)>",
"<svg/x/onload=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_event_handlers(self):
"""Test CreateAppPayload rejects HTML event handlers."""
xss_payloads = [
"<div onclick=alert(1)>",
"<img onerror=alert(1)>",
"<body onload=alert(1)>",
"<input onfocus=alert(1)>",
"<a onmouseover=alert(1)>",
"<DIV ONCLICK=alert(1)>",
"<img src=x onerror=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_object_embed(self):
"""Test CreateAppPayload rejects object and embed tags."""
xss_payloads = [
"<object data='evil.swf'></object>",
"<embed src='evil.swf'>",
"<OBJECT data='javascript:alert(1)'></OBJECT>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_link_javascript(self):
"""Test CreateAppPayload rejects link tags with javascript."""
xss_payloads = [
"<link href='javascript:alert(1)'>",
"<LINK HREF='javascript:alert(1)'>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_in_description(self):
"""Test CreateAppPayload rejects XSS in description."""
xss_descriptions = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for description in xss_descriptions:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(
name="Valid Name",
mode="chat",
description=description,
)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_valid_descriptions(self):
"""Test CreateAppPayload with valid descriptions."""
valid_descriptions = [
"A simple description",
"Description with < and > symbols",
"Description with & ampersand",
"Description with 'quotes' and \"double quotes\"",
"Description with / slashes",
"Description with \\ backslashes",
"Description with ; semicolons",
"Unicode: 这是一个描述",
"Emoji: 🎉🚀",
]
for description in valid_descriptions:
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=description,
)
assert payload.description == description
def test_create_app_none_description(self):
"""Test CreateAppPayload with None description."""
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=None,
)
assert payload.description is None
def test_update_app_xss_prevention(self):
"""Test UpdateAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
UpdateAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_update_app_valid_names(self):
"""Test UpdateAppPayload with valid names."""
payload = UpdateAppPayload(name="Valid Updated Name")
assert payload.name == "Valid Updated Name"
def test_copy_app_xss_prevention(self):
"""Test CopyAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
CopyAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_copy_app_valid_names(self):
"""Test CopyAppPayload with valid names."""
payload = CopyAppPayload(name="Valid Copy Name")
assert payload.name == "Valid Copy Name"
def test_copy_app_none_name(self):
"""Test CopyAppPayload with None name (should be allowed)."""
payload = CopyAppPayload(name=None)
assert payload.name is None
def test_edge_case_angle_brackets_content(self):
"""Test that angle brackets with actual content are rejected."""
# Angle brackets without valid HTML-like patterns should be checked
# The regex pattern <.*?on\w+\s*= should catch event handlers
# But let's verify other patterns too
# Valid: angle brackets used as symbols (not matched by our patterns)
# Our patterns specifically look for dangerous constructs
# Invalid: actual HTML tags with event handlers
invalid_names = [
"<div onclick=xss>",
"<img src=x onerror=alert(1)>",
]
for name in invalid_names:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()

View File

@@ -346,6 +346,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeRateLimitError",
"message": "Rate limit exceeded",
"args": {"description": "Rate limit exceeded"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -364,6 +365,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeAuthorizationError",
"message": "Invalid credentials",
"args": {"description": "Invalid credentials"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -382,6 +384,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeBadRequestError",
"message": "Invalid parameters",
"args": {"description": "Invalid parameters"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -400,6 +403,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeConnectionError",
"message": "Connection to external service failed",
"args": {"description": "Connection to external service failed"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -418,6 +422,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeServerUnavailableError",
"message": "Service temporarily unavailable",
"args": {"description": "Service temporarily unavailable"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})

View File

@@ -0,0 +1 @@
"""Tests for workflow context management."""

View File

@@ -0,0 +1,258 @@
"""Tests for execution context module."""
import contextvars
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.workflow.context.execution_context import (
AppContext,
ExecutionContext,
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
)
class TestAppContext:
"""Test AppContext abstract base class."""
def test_app_context_is_abstract(self):
"""Test that AppContext cannot be instantiated directly."""
with pytest.raises(TypeError):
AppContext() # type: ignore
class TestNullAppContext:
"""Test NullAppContext implementation."""
def test_null_app_context_get_config(self):
"""Test get_config returns value from config dict."""
config = {"key1": "value1", "key2": "value2"}
ctx = NullAppContext(config=config)
assert ctx.get_config("key1") == "value1"
assert ctx.get_config("key2") == "value2"
def test_null_app_context_get_config_default(self):
"""Test get_config returns default when key not found."""
ctx = NullAppContext()
assert ctx.get_config("nonexistent", "default") == "default"
assert ctx.get_config("nonexistent") is None
def test_null_app_context_get_extension(self):
"""Test get_extension returns stored extension."""
ctx = NullAppContext()
extension = MagicMock()
ctx.set_extension("db", extension)
assert ctx.get_extension("db") == extension
def test_null_app_context_get_extension_not_found(self):
"""Test get_extension returns None when extension not found."""
ctx = NullAppContext()
assert ctx.get_extension("nonexistent") is None
def test_null_app_context_enter_yield(self):
"""Test enter method yields without any side effects."""
ctx = NullAppContext()
with ctx.enter():
# Should not raise any exception
pass
class TestExecutionContext:
"""Test ExecutionContext class."""
def test_initialization_with_all_params(self):
"""Test ExecutionContext initialization with all parameters."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = ExecutionContext(
app_context=app_ctx,
context_vars=context_vars,
user=user,
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_initialization_with_minimal_params(self):
"""Test ExecutionContext initialization with minimal parameters."""
ctx = ExecutionContext()
assert ctx.app_context is None
assert ctx.context_vars is None
assert ctx.user is None
def test_enter_with_context_vars(self):
"""Test enter restores context variables."""
test_var = contextvars.ContextVar("test_var")
test_var.set("original_value")
# Copy context with the variable
context_vars = contextvars.copy_context()
# Change the variable
test_var.set("new_value")
# Create execution context and enter it
ctx = ExecutionContext(context_vars=context_vars)
with ctx.enter():
# Variable should be restored to original value
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_with_app_context(self):
"""Test enter enters app context if available."""
app_ctx = NullAppContext()
ctx = ExecutionContext(app_context=app_ctx)
# Should not raise any exception
with ctx.enter():
pass
def test_enter_without_app_context(self):
"""Test enter works without app context."""
ctx = ExecutionContext(app_context=None)
# Should not raise any exception
with ctx.enter():
pass
def test_context_manager_protocol(self):
"""Test ExecutionContext supports context manager protocol."""
ctx = ExecutionContext()
with ctx:
# Should not raise any exception
pass
def test_user_property(self):
"""Test user property returns set user."""
user = MagicMock()
ctx = ExecutionContext(user=user)
assert ctx.user == user
class TestIExecutionContextProtocol:
"""Test IExecutionContext protocol."""
def test_execution_context_implements_protocol(self):
"""Test that ExecutionContext implements IExecutionContext protocol."""
ctx = ExecutionContext()
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
def test_protocol_compatibility(self):
"""Test that ExecutionContext can be used where IExecutionContext is expected."""
def accept_context(context: IExecutionContext) -> Any:
"""Function that accepts IExecutionContext protocol."""
# Just verify it has the required protocol attributes
assert hasattr(context, "__enter__")
assert hasattr(context, "__exit__")
assert hasattr(context, "user")
return context.user
ctx = ExecutionContext(user="test_user")
result = accept_context(ctx)
assert result == "test_user"
def test_protocol_with_flask_execution_context(self):
"""Test that IExecutionContext protocol is compatible with different implementations."""
# Verify the protocol works with ExecutionContext
ctx = ExecutionContext(user="test_user")
# Should have the required protocol attributes
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
assert ctx.user == "test_user"
# Should work as context manager
with ctx:
assert ctx.user == "test_user"
class TestExecutionContextBuilder:
"""Test ExecutionContextBuilder class."""
def test_builder_with_all_params(self):
"""Test builder with all parameters set."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = (
ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build()
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_builder_with_partial_params(self):
"""Test builder with only some parameters set."""
app_ctx = NullAppContext()
ctx = ExecutionContextBuilder().with_app_context(app_ctx).build()
assert ctx.app_context == app_ctx
assert ctx.context_vars is None
assert ctx.user is None
def test_builder_fluent_interface(self):
"""Test builder provides fluent interface."""
builder = ExecutionContextBuilder()
# Each method should return the builder
assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder)
assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder)
assert isinstance(builder.with_user(None), ExecutionContextBuilder)
class TestCaptureCurrentContext:
"""Test capture_current_context function."""
def test_capture_current_context_returns_context(self):
"""Test that capture_current_context returns a valid context."""
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Should return an object that implements IExecutionContext
assert hasattr(result, "__enter__")
assert hasattr(result, "__exit__")
assert hasattr(result, "user")
def test_capture_current_context_captures_contextvars(self):
"""Test that capture_current_context captures context variables."""
# Set a context variable before capturing
import contextvars
test_var = contextvars.ContextVar("capture_test_var")
test_var.set("test_value_123")
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Context variables should be captured
assert result.context_vars is not None

View File

@@ -0,0 +1,316 @@
"""Tests for Flask app context module."""
import contextvars
from unittest.mock import MagicMock, patch
import pytest
class TestFlaskAppContext:
"""Test FlaskAppContext implementation."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {"TEST_KEY": "test_value"}
app.extensions = {"db": MagicMock(), "cache": MagicMock()}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_flask_app_context_initialization(self, mock_flask_app):
"""Test FlaskAppContext initialization."""
# Import here to avoid Flask dependency in test environment
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.flask_app == mock_flask_app
def test_flask_app_context_get_config(self, mock_flask_app):
"""Test get_config returns Flask app config value."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("TEST_KEY") == "test_value"
def test_flask_app_context_get_config_default(self, mock_flask_app):
"""Test get_config returns default when key not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("NONEXISTENT", "default") == "default"
def test_flask_app_context_get_extension(self, mock_flask_app):
"""Test get_extension returns Flask extension."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
db_ext = mock_flask_app.extensions["db"]
assert ctx.get_extension("db") == db_ext
def test_flask_app_context_get_extension_not_found(self, mock_flask_app):
"""Test get_extension returns None when extension not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_extension("nonexistent") is None
def test_flask_app_context_enter(self, mock_flask_app):
"""Test enter method enters Flask app context."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
with ctx.enter():
# Should not raise any exception
pass
# Verify app_context was called
mock_flask_app.app_context.assert_called_once()
class TestFlaskExecutionContext:
"""Test FlaskExecutionContext class."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_initialization(self, mock_flask_app):
"""Test FlaskExecutionContext initialization."""
from context.flask_app_context import FlaskExecutionContext
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
user=user,
)
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_app_context_property(self, mock_flask_app):
"""Test app_context property returns FlaskAppContext."""
from context.flask_app_context import FlaskAppContext, FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
assert isinstance(ctx.app_context, FlaskAppContext)
assert ctx.app_context.flask_app == mock_flask_app
def test_context_manager_protocol(self, mock_flask_app):
"""Test FlaskExecutionContext supports context manager protocol."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
# Should work as context manager
with ctx:
pass
class TestCaptureFlaskContext:
"""Test capture_flask_context function."""
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
"""Test capture_flask_context captures Flask app."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx._flask_app == mock_app
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
"""Test capture_flask_context captures user from Flask g object."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
mock_user = MagicMock()
mock_user.id = "user_123"
mock_g._login_user = mock_user
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx.user == mock_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_with_explicit_user(self, mock_current_app):
"""Test capture_flask_context uses explicit user parameter."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
explicit_user = MagicMock()
explicit_user.id = "user_456"
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context(user=explicit_user)
assert ctx.user == explicit_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_captures_contextvars(self, mock_current_app):
"""Test capture_flask_context captures context variables."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
# Set a context variable
test_var = contextvars.ContextVar("test_var")
test_var.set("test_value")
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
# Context variables should be captured
assert ctx.context_vars is not None
# Verify the variable is in the captured context
captured_value = ctx.context_vars[test_var]
assert captured_value == "test_value"
class TestFlaskExecutionContextIntegration:
"""Integration tests for FlaskExecutionContext."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app with proper app context."""
app = MagicMock()
app.config = {"TEST": "value"}
app.extensions = {"db": MagicMock()}
# Mock app context
mock_app_context = MagicMock()
mock_app_context.__enter__ = MagicMock(return_value=None)
mock_app_context.__exit__ = MagicMock(return_value=None)
app.app_context.return_value = mock_app_context
return app
def test_enter_restores_context_vars(self, mock_flask_app):
"""Test that enter restores captured context variables."""
# Create a context variable and set a value
test_var = contextvars.ContextVar("integration_test_var")
test_var.set("original_value")
# Capture the context
context_vars = contextvars.copy_context()
# Change the value
test_var.set("new_value")
# Create FlaskExecutionContext and enter it
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
)
with ctx:
# Value should be restored to original
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_enters_flask_app_context(self, mock_flask_app):
"""Test that enter enters Flask app context."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
with ctx:
# Verify app context was entered
assert mock_flask_app.app_context.called
@patch("context.flask_app_context.g")
def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
"""Test that enter restores user in Flask g object."""
mock_user = MagicMock()
mock_user.id = "test_user"
# Note: FlaskExecutionContext saves user from g before entering context,
# then restores it after entering the app context.
# The user passed to constructor is NOT restored to g.
# So we need to test the actual behavior.
# Create FlaskExecutionContext with user in constructor
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
user=mock_user,
)
# Set user in g before entering (simulating existing user in g)
mock_g._login_user = mock_user
with ctx:
# After entering, the user from g before entry should be restored
assert mock_g._login_user == mock_user
# The user in constructor is stored but not automatically restored to g
# (it's available via ctx.user property)
assert ctx.user == mock_user
def test_enter_method_as_context_manager(self, mock_flask_app):
"""Test enter method returns a proper context manager."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# enter() should return a generator/context manager
with ctx.enter():
# Should work without issues
pass
# Verify app context was called
assert mock_flask_app.app_context.called

View File

@@ -0,0 +1,142 @@
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.workspace_permission import (
check_workspace_member_invite_permission,
check_workspace_owner_transfer_permission,
)
class TestWorkspacePermissionHelper:
"""Test workspace permission helper functions."""
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.EnterpriseService")
def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
"""Community edition should always allow invitations without calling any service."""
mock_config.ENTERPRISE_ENABLED = False
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
# EnterpriseService should NOT be called in community edition
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
"""Community edition should check billing plan but not call enterprise service."""
mock_config.ENTERPRISE_ENABLED = False
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True
mock_feature_service.get_features.return_value = mock_features
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_feature_service.get_features.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should block invitations when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = False
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should allow invitations when workspace policy is True."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = True
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
"""SANDBOX billing plan should block owner transfer before checking enterprise policy."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = False # SANDBOX plan
mock_feature_service.get_features.return_value = mock_features
with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
# Enterprise service should NOT be called since billing plan already blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
"""Enterprise edition should block transfer when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = False # Workspace policy blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_allows_transfer_when_both_enabled(
self, mock_feature_service, mock_config, mock_enterprise_service
):
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = True # Workspace policy allows
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.logger")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
"""On enterprise service error, should fail-open (allow) and log error."""
mock_config.ENTERPRISE_ENABLED = True
# Simulate enterprise service error
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
# Should not raise (fail-open)
check_workspace_member_invite_permission("test-workspace-id")
# Should log the error
mock_logger.exception.assert_called_once()
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)

View File

@@ -0,0 +1,101 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { RETRIEVE_METHOD } from '@/types/app'
import EconomicalRetrievalMethodConfig from './index'
// Mock dependencies
vi.mock('../../settings/option-card', () => ({
default: ({ children, title, description, disabled, id }: {
children?: React.ReactNode
title?: string
description?: React.ReactNode
disabled?: boolean
id?: string
}) => (
<div data-testid="option-card" data-title={title} data-id={id} data-disabled={disabled}>
<div>{description}</div>
{children}
</div>
),
}))
vi.mock('../retrieval-param-config', () => ({
default: ({ value, onChange, type }: {
value: Record<string, unknown>
onChange: (value: Record<string, unknown>) => void
type?: string
}) => (
<div data-testid="retrieval-param-config" data-type={type}>
<button onClick={() => onChange({ ...value, newProp: 'changed' })}>
Change Value
</button>
</div>
),
}))
vi.mock('@/app/components/base/icons/src/vender/knowledge', () => ({
VectorSearch: () => <svg data-testid="vector-search-icon" />,
}))
describe('EconomicalRetrievalMethodConfig', () => {
const mockOnChange = vi.fn()
const defaultProps = {
value: {
search_method: RETRIEVE_METHOD.keywordSearch,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 2,
score_threshold_enabled: false,
score_threshold: 0.5,
},
onChange: mockOnChange,
}
beforeEach(() => {
vi.clearAllMocks()
})
it('should render correctly', () => {
render(<EconomicalRetrievalMethodConfig {...defaultProps} />)
expect(screen.getByTestId('option-card')).toBeInTheDocument()
expect(screen.getByTestId('retrieval-param-config')).toBeInTheDocument()
// Check if title and description are rendered (mocked i18n returns key)
expect(screen.getByText('dataset.retrieval.keyword_search.description')).toBeInTheDocument()
})
it('should pass correct props to OptionCard', () => {
render(<EconomicalRetrievalMethodConfig {...defaultProps} disabled={true} />)
const card = screen.getByTestId('option-card')
expect(card).toHaveAttribute('data-disabled', 'true')
expect(card).toHaveAttribute('data-id', RETRIEVE_METHOD.keywordSearch)
})
it('should pass correct props to RetrievalParamConfig', () => {
render(<EconomicalRetrievalMethodConfig {...defaultProps} />)
const config = screen.getByTestId('retrieval-param-config')
expect(config).toHaveAttribute('data-type', RETRIEVE_METHOD.keywordSearch)
})
it('should handle onChange events', () => {
render(<EconomicalRetrievalMethodConfig {...defaultProps} />)
fireEvent.click(screen.getByText('Change Value'))
expect(mockOnChange).toHaveBeenCalledTimes(1)
expect(mockOnChange).toHaveBeenCalledWith({
...defaultProps.value,
newProp: 'changed',
})
})
it('should default disabled prop to false', () => {
render(<EconomicalRetrievalMethodConfig {...defaultProps} />)
const card = screen.getByTestId('option-card')
expect(card).toHaveAttribute('data-disabled', 'false')
})
})

View File

@@ -0,0 +1,148 @@
import type { ReactNode } from 'react'
import { render, screen } from '@testing-library/react'
import { RETRIEVE_METHOD } from '@/types/app'
import { retrievalIcon } from '../../create/icons'
import RetrievalMethodInfo, { getIcon } from './index'
// Mock next/image
vi.mock('next/image', () => ({
default: ({ src, alt, className }: { src: string, alt: string, className?: string }) => (
<img src={src} alt={alt || ''} className={className} data-testid="method-icon" />
),
}))
// Mock RadioCard
vi.mock('@/app/components/base/radio-card', () => ({
default: ({ title, description, chosenConfig, icon }: { title: string, description: string, chosenConfig: ReactNode, icon: ReactNode }) => (
<div data-testid="radio-card">
<div data-testid="card-title">{title}</div>
<div data-testid="card-description">{description}</div>
<div data-testid="card-icon">{icon}</div>
<div data-testid="chosen-config">{chosenConfig}</div>
</div>
),
}))
// Mock icons
vi.mock('../../create/icons', () => ({
retrievalIcon: {
vector: 'vector-icon.png',
fullText: 'fulltext-icon.png',
hybrid: 'hybrid-icon.png',
},
}))
describe('RetrievalMethodInfo', () => {
const defaultConfig = {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: 'test-provider',
reranking_model_name: 'test-model',
},
top_k: 5,
score_threshold_enabled: true,
score_threshold: 0.8,
}
beforeEach(() => {
vi.clearAllMocks()
})
it('should render correctly with full config', () => {
render(<RetrievalMethodInfo value={defaultConfig} />)
expect(screen.getByTestId('radio-card')).toBeInTheDocument()
// Check Title & Description (mocked i18n returns key prefixed with ns)
expect(screen.getByTestId('card-title')).toHaveTextContent('dataset.retrieval.semantic_search.title')
expect(screen.getByTestId('card-description')).toHaveTextContent('dataset.retrieval.semantic_search.description')
// Check Icon
const icon = screen.getByTestId('method-icon')
expect(icon).toHaveAttribute('src', 'vector-icon.png')
// Check Config Details
expect(screen.getByText('test-model')).toBeInTheDocument() // Rerank model
expect(screen.getByText('5')).toBeInTheDocument() // Top K
expect(screen.getByText('0.8')).toBeInTheDocument() // Score threshold
})
it('should not render reranking model if missing', () => {
const configWithoutRerank = {
...defaultConfig,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
}
render(<RetrievalMethodInfo value={configWithoutRerank} />)
expect(screen.queryByText('test-model')).not.toBeInTheDocument()
// Other fields should still be there
expect(screen.getByText('5')).toBeInTheDocument()
})
it('should handle different retrieval methods', () => {
// Test Hybrid
const hybridConfig = { ...defaultConfig, search_method: RETRIEVE_METHOD.hybrid }
const { unmount } = render(<RetrievalMethodInfo value={hybridConfig} />)
expect(screen.getByTestId('card-title')).toHaveTextContent('dataset.retrieval.hybrid_search.title')
expect(screen.getByTestId('method-icon')).toHaveAttribute('src', 'hybrid-icon.png')
unmount()
// Test FullText
const fullTextConfig = { ...defaultConfig, search_method: RETRIEVE_METHOD.fullText }
render(<RetrievalMethodInfo value={fullTextConfig} />)
expect(screen.getByTestId('card-title')).toHaveTextContent('dataset.retrieval.full_text_search.title')
expect(screen.getByTestId('method-icon')).toHaveAttribute('src', 'fulltext-icon.png')
})
describe('getIcon utility', () => {
it('should return correct icon for each type', () => {
expect(getIcon(RETRIEVE_METHOD.semantic)).toBe(retrievalIcon.vector)
expect(getIcon(RETRIEVE_METHOD.fullText)).toBe(retrievalIcon.fullText)
expect(getIcon(RETRIEVE_METHOD.hybrid)).toBe(retrievalIcon.hybrid)
expect(getIcon(RETRIEVE_METHOD.invertedIndex)).toBe(retrievalIcon.vector)
expect(getIcon(RETRIEVE_METHOD.keywordSearch)).toBe(retrievalIcon.vector)
})
it('should return default vector icon for unknown type', () => {
// Test fallback branch when type is not in the mapping
const unknownType = 'unknown_method' as RETRIEVE_METHOD
expect(getIcon(unknownType)).toBe(retrievalIcon.vector)
})
})
it('should not render score threshold if disabled', () => {
const configWithoutScoreThreshold = {
...defaultConfig,
score_threshold_enabled: false,
score_threshold: 0,
}
render(<RetrievalMethodInfo value={configWithoutScoreThreshold} />)
// score_threshold is still rendered but may be undefined
expect(screen.queryByText('0.8')).not.toBeInTheDocument()
})
it('should render correctly with invertedIndex search method', () => {
const invertedIndexConfig = { ...defaultConfig, search_method: RETRIEVE_METHOD.invertedIndex }
render(<RetrievalMethodInfo value={invertedIndexConfig} />)
// invertedIndex uses vector icon
expect(screen.getByTestId('method-icon')).toHaveAttribute('src', 'vector-icon.png')
})
it('should render correctly with keywordSearch search method', () => {
const keywordSearchConfig = { ...defaultConfig, search_method: RETRIEVE_METHOD.keywordSearch }
render(<RetrievalMethodInfo value={keywordSearchConfig} />)
// keywordSearch uses vector icon
expect(screen.getByTestId('method-icon')).toHaveAttribute('src', 'vector-icon.png')
})
})

View File

@@ -0,0 +1,46 @@
import { render, screen } from '@testing-library/react'
import EmbeddingSkeleton from './index'
// Mock Skeleton components
vi.mock('@/app/components/base/skeleton', () => ({
SkeletonContainer: ({ children }: { children?: React.ReactNode }) => <div data-testid="skeleton-container">{children}</div>,
SkeletonPoint: () => <div data-testid="skeleton-point" />,
SkeletonRectangle: () => <div data-testid="skeleton-rectangle" />,
SkeletonRow: ({ children }: { children?: React.ReactNode }) => <div data-testid="skeleton-row">{children}</div>,
}))
// Mock Divider
vi.mock('@/app/components/base/divider', () => ({
default: () => <div data-testid="divider" />,
}))
describe('EmbeddingSkeleton', () => {
it('should render correct number of skeletons', () => {
render(<EmbeddingSkeleton />)
// It renders 5 CardSkeletons. Each CardSkelton has multiple SkeletonContainers.
// Let's count the number of main wrapper divs (loop is 5)
// Each iteration renders a CardSkeleton and potentially a Divider.
// The component structure is:
// div.relative...
// div.absolute... (mask)
// map(5) -> div.w-full.px-11 -> CardSkelton + Divider (except last?)
// Actually the code says `index !== 9`, but the loop is length 5.
// So `index` goes 0..4. All are !== 9. So 5 dividers should be rendered.
expect(screen.getAllByTestId('divider')).toHaveLength(5)
// Just ensure it renders without crashing and contains skeleton elements
expect(screen.getAllByTestId('skeleton-container').length).toBeGreaterThan(0)
expect(screen.getAllByTestId('skeleton-rectangle').length).toBeGreaterThan(0)
})
it('should render the mask overlay', () => {
const { container } = render(<EmbeddingSkeleton />)
// Check for the absolute positioned mask
const mask = container.querySelector('.bg-dataset-chunk-list-mask-bg')
expect(mask).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,92 @@
import { RiArrowRightUpLine, RiBookOpenLine } from '@remixicon/react'
import Link from 'next/link'
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import Switch from '@/app/components/base/switch'
import Indicator from '@/app/components/header/indicator'
import { useSelector as useAppContextSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
import { useDisableDatasetServiceApi, useEnableDatasetServiceApi } from '@/service/knowledge/use-dataset'
import { cn } from '@/utils/classnames'
type CardProps = {
apiEnabled: boolean
}
const Card = ({
apiEnabled,
}: CardProps) => {
const { t } = useTranslation()
const datasetId = useDatasetDetailContextWithSelector(state => state.dataset?.id)
const mutateDatasetRes = useDatasetDetailContextWithSelector(state => state.mutateDatasetRes)
const { mutateAsync: enableDatasetServiceApi } = useEnableDatasetServiceApi()
const { mutateAsync: disableDatasetServiceApi } = useDisableDatasetServiceApi()
const isCurrentWorkspaceManager = useAppContextSelector(state => state.isCurrentWorkspaceManager)
const apiReferenceUrl = useDatasetApiAccessUrl()
const onToggle = useCallback(async (state: boolean) => {
let result: 'success' | 'fail'
if (state)
result = (await enableDatasetServiceApi(datasetId ?? '')).result
else
result = (await disableDatasetServiceApi(datasetId ?? '')).result
if (result === 'success')
mutateDatasetRes?.()
}, [datasetId, enableDatasetServiceApi, mutateDatasetRes, disableDatasetServiceApi])
return (
<div className="w-[208px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg">
<div className="p-1">
<div className="p-2">
<div className="mb-1.5 flex justify-between">
<div className="flex items-center gap-1">
<Indicator
className="shrink-0"
color={apiEnabled ? 'green' : 'yellow'}
/>
<div
className={cn(
'system-xs-semibold-uppercase',
apiEnabled ? 'text-text-success' : 'text-text-warning',
)}
>
{apiEnabled
? t('serviceApi.enabled', { ns: 'dataset' })
: t('serviceApi.disabled', { ns: 'dataset' })}
</div>
</div>
<Switch
defaultValue={apiEnabled}
onChange={onToggle}
disabled={!isCurrentWorkspaceManager}
/>
</div>
<div className="system-xs-regular text-text-tertiary">
{t('appMenus.apiAccessTip', { ns: 'common' })}
</div>
</div>
</div>
<div className="h-px bg-divider-subtle"></div>
<div className="p-1">
<Link
href={apiReferenceUrl}
target="_blank"
rel="noopener noreferrer"
className="flex h-8 items-center space-x-[7px] rounded-lg px-2 text-text-tertiary hover:bg-state-base-hover"
>
<RiBookOpenLine className="size-3.5 shrink-0" />
<div className="system-sm-regular grow truncate">
{t('overview.apiInfo.doc', { ns: 'appOverview' })}
</div>
<RiArrowRightUpLine className="size-3.5 shrink-0" />
</Link>
</div>
</div>
)
}
export default React.memo(Card)

View File

@@ -0,0 +1,65 @@
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ApiAggregate } from '@/app/components/base/icons/src/vender/knowledge'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import Indicator from '@/app/components/header/indicator'
import { cn } from '@/utils/classnames'
import Card from './card'
type ApiAccessProps = {
expand: boolean
apiEnabled: boolean
}
const ApiAccess = ({
expand,
apiEnabled,
}: ApiAccessProps) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const handleToggle = () => {
setOpen(!open)
}
return (
<div className="p-3 pt-2">
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="top-start"
offset={{
mainAxis: 4,
crossAxis: -4,
}}
>
<PortalToFollowElemTrigger
className="w-full"
onClick={handleToggle}
>
<div className={cn(
'relative flex h-8 cursor-pointer items-center gap-2 rounded-lg border border-components-panel-border px-3',
!expand && 'w-8 justify-center',
open ? 'bg-state-base-hover' : 'hover:bg-state-base-hover',
)}
>
<ApiAggregate className="size-4 shrink-0 text-text-secondary" />
{expand && <div className="system-sm-medium grow text-text-secondary">{t('appMenus.apiAccess', { ns: 'common' })}</div>}
<Indicator
className={cn('shrink-0', !expand && 'absolute -right-px -top-px')}
color={apiEnabled ? 'green' : 'yellow'}
/>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-[10]">
<Card
apiEnabled={apiEnabled}
/>
</PortalToFollowElemContent>
</PortalToFollowElem>
</div>
)
}
export default React.memo(ApiAccess)

View File

@@ -1,8 +1,7 @@
import type { RelatedAppResponse } from '@/models/datasets'
import * as React from 'react'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDatasetApiBaseUrl } from '@/service/knowledge/use-dataset'
import ServiceApi from './service-api'
import ApiAccess from './api-access'
import Statistics from './statistics'
type IExtraInfoProps = {
@@ -17,7 +16,6 @@ const ExtraInfo = ({
expand,
}: IExtraInfoProps) => {
const apiEnabled = useDatasetDetailContextWithSelector(state => state.dataset?.enable_api)
const { data: apiBaseInfo } = useDatasetApiBaseUrl()
return (
<>
@@ -28,9 +26,8 @@ const ExtraInfo = ({
relatedApps={relatedApps}
/>
)}
<ServiceApi
<ApiAccess
expand={expand}
apiBaseUrl={apiBaseInfo?.api_base_url ?? ''}
apiEnabled={apiEnabled ?? false}
/>
</>

View File

@@ -6,45 +6,22 @@ import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import CopyFeedback from '@/app/components/base/copy-feedback'
import { ApiAggregate } from '@/app/components/base/icons/src/vender/knowledge'
import Switch from '@/app/components/base/switch'
import SecretKeyModal from '@/app/components/develop/secret-key/secret-key-modal'
import Indicator from '@/app/components/header/indicator'
import { useSelector as useAppContextSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
import { useDisableDatasetServiceApi, useEnableDatasetServiceApi } from '@/service/knowledge/use-dataset'
import { cn } from '@/utils/classnames'
type CardProps = {
apiEnabled: boolean
apiBaseUrl: string
}
const Card = ({
apiEnabled,
apiBaseUrl,
}: CardProps) => {
const { t } = useTranslation()
const datasetId = useDatasetDetailContextWithSelector(state => state.dataset?.id)
const mutateDatasetRes = useDatasetDetailContextWithSelector(state => state.mutateDatasetRes)
const { mutateAsync: enableDatasetServiceApi } = useEnableDatasetServiceApi()
const { mutateAsync: disableDatasetServiceApi } = useDisableDatasetServiceApi()
const [isSecretKeyModalVisible, setIsSecretKeyModalVisible] = useState(false)
const isCurrentWorkspaceManager = useAppContextSelector(state => state.isCurrentWorkspaceManager)
const apiReferenceUrl = useDatasetApiAccessUrl()
const onToggle = useCallback(async (state: boolean) => {
let result: 'success' | 'fail'
if (state)
result = (await enableDatasetServiceApi(datasetId ?? '')).result
else
result = (await disableDatasetServiceApi(datasetId ?? '')).result
if (result === 'success')
mutateDatasetRes?.()
}, [datasetId, enableDatasetServiceApi, disableDatasetServiceApi])
const handleOpenSecretKeyModal = useCallback(() => {
setIsSecretKeyModalVisible(true)
}, [])
@@ -68,24 +45,16 @@ const Card = ({
<div className="flex items-center gap-x-1">
<Indicator
className="shrink-0"
color={apiEnabled ? 'green' : 'yellow'}
color={
apiBaseUrl ? 'green' : 'yellow'
}
/>
<div
className={cn(
'system-xs-semibold-uppercase',
apiEnabled ? 'text-text-success' : 'text-text-warning',
)}
className="system-xs-semibold-uppercase text-text-success"
>
{apiEnabled
? t('serviceApi.enabled', { ns: 'dataset' })
: t('serviceApi.disabled', { ns: 'dataset' })}
{t('serviceApi.enabled', { ns: 'dataset' })}
</div>
</div>
<Switch
defaultValue={apiEnabled}
onChange={onToggle}
disabled={!isCurrentWorkspaceManager}
/>
</div>
<div className="flex flex-col">
<div className="system-xs-regular leading-6 text-text-tertiary">

View File

@@ -1,22 +1,17 @@
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ApiAggregate } from '@/app/components/base/icons/src/vender/knowledge'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import Indicator from '@/app/components/header/indicator'
import { cn } from '@/utils/classnames'
import Card from './card'
type ServiceApiProps = {
expand: boolean
apiBaseUrl: string
apiEnabled: boolean
}
const ServiceApi = ({
expand,
apiBaseUrl,
apiEnabled,
}: ServiceApiProps) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
@@ -26,7 +21,7 @@ const ServiceApi = ({
}
return (
<div className="p-3 pt-2">
<div>
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
@@ -41,22 +36,21 @@ const ServiceApi = ({
onClick={handleToggle}
>
<div className={cn(
'relative flex h-8 cursor-pointer items-center gap-2 rounded-lg border border-components-panel-border px-3',
!expand && 'w-8 justify-center',
open ? 'bg-state-base-hover' : 'hover:bg-state-base-hover',
'relative flex h-8 cursor-pointer items-center gap-2 rounded-lg border-[0.5px] border-components-button-secondary-border-hover bg-components-button-secondary-bg px-3',
open ? 'bg-components-button-secondary-bg-hover' : 'hover:bg-components-button-secondary-bg-hover',
)}
>
<ApiAggregate className="size-4 shrink-0 text-text-secondary" />
{expand && <div className="system-sm-medium grow text-text-secondary">{t('serviceApi.title', { ns: 'dataset' })}</div>}
<Indicator
className={cn('shrink-0', !expand && 'absolute -right-px -top-px')}
color={apiEnabled ? 'green' : 'yellow'}
className={cn('shrink-0')}
color={
apiBaseUrl ? 'green' : 'yellow'
}
/>
<div className="system-sm-medium grow text-text-secondary">{t('serviceApi.title', { ns: 'dataset' })}</div>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-[10]">
<Card
apiEnabled={apiEnabled}
apiBaseUrl={apiBaseUrl}
/>
</PortalToFollowElemContent>

View File

@@ -0,0 +1,30 @@
import { render, screen } from '@testing-library/react'
import DatasetFooter from './index'
describe('DatasetFooter', () => {
it('should render correctly', () => {
render(<DatasetFooter />)
// Check main title (mocked i18n returns ns:key or key)
// The code uses t('didYouKnow', { ns: 'dataset' })
// With default mock it likely returns 'dataset.didYouKnow'
expect(screen.getByText('dataset.didYouKnow')).toBeInTheDocument()
// Check paragraph content
expect(screen.getByText(/dataset.intro1/)).toBeInTheDocument()
expect(screen.getByText(/dataset.intro2/)).toBeInTheDocument()
expect(screen.getByText(/dataset.intro3/)).toBeInTheDocument()
expect(screen.getByText(/dataset.intro4/)).toBeInTheDocument()
expect(screen.getByText(/dataset.intro5/)).toBeInTheDocument()
expect(screen.getByText(/dataset.intro6/)).toBeInTheDocument()
})
it('should have correct styling', () => {
const { container } = render(<DatasetFooter />)
const footer = container.querySelector('footer')
expect(footer).toHaveClass('shrink-0', 'px-12', 'py-6')
const h3 = container.querySelector('h3')
expect(h3).toHaveClass('text-gradient')
})
})

View File

@@ -14,13 +14,14 @@ import TagFilter from '@/app/components/base/tag-management/filter'
// Hooks
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
import { useAppContext } from '@/context/app-context'
import { useAppContext, useSelector as useAppContextSelector } from '@/context/app-context'
import { useExternalApiPanel } from '@/context/external-api-panel-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { useDatasetApiBaseUrl } from '@/service/knowledge/use-dataset'
// Components
import ExternalAPIPanel from '../external-api/external-api-panel'
import ServiceApi from '../extra-info/service-api'
import DatasetFooter from './dataset-footer'
import Datasets from './datasets'
@@ -58,6 +59,9 @@ const List = () => {
return router.replace('/apps')
}, [currentWorkspace, router])
const isCurrentWorkspaceManager = useAppContextSelector(state => state.isCurrentWorkspaceManager)
const { data: apiBaseInfo } = useDatasetApiBaseUrl()
return (
<div className="scroll-container relative flex grow flex-col overflow-y-auto bg-background-body">
<div className="sticky top-0 z-10 flex items-center justify-end gap-x-1 bg-background-body px-12 pb-2 pt-4">
@@ -81,6 +85,11 @@ const List = () => {
onChange={e => handleKeywordsChange(e.target.value)}
onClear={() => handleKeywordsChange('')}
/>
{
isCurrentWorkspaceManager && (
<ServiceApi apiBaseUrl={apiBaseInfo?.api_base_url ?? ''} />
)
}
<div className="h-4 w-[1px] bg-divider-regular" />
<Button
className="shadows-shadow-xs gap-0.5"
@@ -96,7 +105,6 @@ const List = () => {
{showTagManagementModal && (
<TagManagementModal type="knowledge" show={showTagManagementModal} />
)}
{showExternalApiPanel && <ExternalAPIPanel onClose={() => setShowExternalApiPanel(false)} />}
</div>
)

View File

@@ -0,0 +1,49 @@
import { render, screen } from '@testing-library/react'
import NewDatasetCard from './index'
type MockOptionProps = {
text: string
href: string
}
// Mock dependencies
vi.mock('./option', () => ({
default: ({ text, href }: MockOptionProps) => (
<a data-testid="option-link" href={href}>
{text}
</a>
),
}))
vi.mock('@remixicon/react', () => ({
RiAddLine: () => <svg data-testid="icon-add" />,
RiFunctionAddLine: () => <svg data-testid="icon-function" />,
}))
vi.mock('@/app/components/base/icons/src/vender/solid/development', () => ({
ApiConnectionMod: () => <svg data-testid="icon-api" />,
}))
describe('NewDatasetCard', () => {
it('should render all options', () => {
render(<NewDatasetCard />)
const options = screen.getAllByTestId('option-link')
expect(options).toHaveLength(3)
// Check first option (Create Dataset)
const createDataset = options[0]
expect(createDataset).toHaveAttribute('href', '/datasets/create')
expect(createDataset).toHaveTextContent('dataset.createDataset')
// Check second option (Create from Pipeline)
const createFromPipeline = options[1]
expect(createFromPipeline).toHaveAttribute('href', '/datasets/create-from-pipeline')
expect(createFromPipeline).toHaveTextContent('dataset.createFromPipeline')
// Check third option (Connect Dataset)
const connectDataset = options[2]
expect(connectDataset).toHaveAttribute('href', '/datasets/connect')
expect(connectDataset).toHaveTextContent('dataset.connectDataset')
})
})

View File

@@ -0,0 +1,85 @@
import { render, screen } from '@testing-library/react'
import { ChunkingMode } from '@/models/datasets'
import ChunkStructure from './index'
type MockOptionCardProps = {
id: string
title: string
isActive?: boolean
disabled?: boolean
}
// Mock dependencies
vi.mock('../option-card', () => ({
default: ({ id, title, isActive, disabled }: MockOptionCardProps) => (
<div
data-testid="option-card"
data-id={id}
data-active={isActive}
data-disabled={disabled}
>
{title}
</div>
),
}))
// Mock hook
vi.mock('./hooks', () => ({
useChunkStructure: () => ({
options: [
{
id: ChunkingMode.text,
title: 'General',
description: 'General description',
icon: <svg />,
effectColor: 'indigo',
iconActiveColor: 'indigo',
},
{
id: ChunkingMode.parentChild,
title: 'Parent-Child',
description: 'PC description',
icon: <svg />,
effectColor: 'blue',
iconActiveColor: 'blue',
},
],
}),
}))
describe('ChunkStructure', () => {
it('should render all options', () => {
render(<ChunkStructure chunkStructure={ChunkingMode.text} />)
const options = screen.getAllByTestId('option-card')
expect(options).toHaveLength(2)
expect(options[0]).toHaveTextContent('General')
expect(options[1]).toHaveTextContent('Parent-Child')
})
it('should set active state correctly', () => {
// Render with 'text' active
const { unmount } = render(<ChunkStructure chunkStructure={ChunkingMode.text} />)
const options = screen.getAllByTestId('option-card')
expect(options[0]).toHaveAttribute('data-active', 'true')
expect(options[1]).toHaveAttribute('data-active', 'false')
unmount()
// Render with 'parentChild' active
render(<ChunkStructure chunkStructure={ChunkingMode.parentChild} />)
const newOptions = screen.getAllByTestId('option-card')
expect(newOptions[0]).toHaveAttribute('data-active', 'false')
expect(newOptions[1]).toHaveAttribute('data-active', 'true')
})
it('should be always disabled', () => {
render(<ChunkStructure chunkStructure={ChunkingMode.text} />)
const options = screen.getAllByTestId('option-card')
options.forEach((option) => {
expect(option).toHaveAttribute('data-disabled', 'true')
})
})
})

View File

@@ -1,10 +1,9 @@
'use client'
import type { InvitationResult } from '@/models/common'
import { RiPencilLine, RiUserAddLine } from '@remixicon/react'
import { RiPencilLine } from '@remixicon/react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Avatar from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
import { NUM_INFINITE } from '@/app/components/billing/config'
import { Plan } from '@/app/components/billing/type'
@@ -16,8 +15,8 @@ import { useProviderContext } from '@/context/provider-context'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
import { LanguagesSupported } from '@/i18n-config/language'
import { useMembers } from '@/service/use-common'
import { cn } from '@/utils/classnames'
import EditWorkspaceModal from './edit-workspace-modal'
import InviteButton from './invite-button'
import InviteModal from './invite-modal'
import InvitedModal from './invited-modal'
import Operation from './operation'
@@ -37,7 +36,7 @@ const MembersPage = () => {
const { userProfile, currentWorkspace, isCurrentWorkspaceOwner, isCurrentWorkspaceManager } = useAppContext()
const { data, refetch } = useMembers()
const { systemFeatures } = useGlobalPublicStore()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const { formatTimeFromNow } = useFormatTimeFromNow()
const [inviteModalVisible, setInviteModalVisible] = useState(false)
const [invitationResults, setInvitationResults] = useState<InvitationResult[]>([])
@@ -104,10 +103,9 @@ const MembersPage = () => {
{isMemberFull && (
<UpgradeBtn className="mr-2" loc="member-invite" />
)}
<Button variant="primary" className={cn('shrink-0')} disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)}>
<RiUserAddLine className="mr-1 h-4 w-4" />
{t('members.invite', { ns: 'common' })}
</Button>
<div className="shrink-0">
<InviteButton disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)} />
</div>
</div>
<div className="overflow-visible lg:overflow-visible">
<div className="flex min-w-[480px] items-center border-b border-divider-regular py-[7px]">

View File

@@ -0,0 +1,34 @@
import { RiUserAddLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useWorkspacePermissions } from '@/service/use-workspace'
type InviteButtonProps = {
disabled?: boolean
onClick?: () => void
}
const InviteButton = (props: InviteButtonProps) => {
const { t } = useTranslation()
const { currentWorkspace } = useAppContext()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const { data: workspacePermissions, isFetching: isFetchingWorkspacePermissions } = useWorkspacePermissions(currentWorkspace!.id, systemFeatures.branding.enabled)
if (systemFeatures.branding.enabled) {
if (isFetchingWorkspacePermissions) {
return <Loading />
}
if (!workspacePermissions || workspacePermissions.allow_member_invite !== true) {
return null
}
}
return (
<Button variant="primary" {...props}>
<RiUserAddLine className="mr-1 h-4 w-4" />
{t('members.invite', { ns: 'common' })}
</Button>
)
}
export default InviteButton

View File

@@ -5,6 +5,10 @@ import {
} from '@remixicon/react'
import { Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useWorkspacePermissions } from '@/service/use-workspace'
import { cn } from '@/utils/classnames'
type Props = {
@@ -13,6 +17,17 @@ type Props = {
const TransferOwnership = ({ onOperate }: Props) => {
const { t } = useTranslation()
const { currentWorkspace } = useAppContext()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const { data: workspacePermissions, isFetching: isFetchingWorkspacePermissions } = useWorkspacePermissions(currentWorkspace!.id, systemFeatures.branding.enabled)
if (systemFeatures.branding.enabled) {
if (isFetchingWorkspacePermissions) {
return <Loading />
}
if (!workspacePermissions || workspacePermissions.allow_owner_transfer !== true) {
return <span className="system-sm-regular px-3 text-text-secondary">{t('members.owner', { ns: 'common' })}</span>
}
}
return (
<Menu as="div" className="relative h-full w-full">

View File

@@ -0,0 +1,130 @@
import type { PluginDetail } from '@/app/components/plugins/types'
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ActionList from './action-list'
// Mock dependencies
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.num !== undefined)
return `${options.num} ${options.action || 'actions'}`
return key
},
}),
}))
const mockToolData = [
{ name: 'tool-1', label: { en_US: 'Tool 1' } },
{ name: 'tool-2', label: { en_US: 'Tool 2' } },
]
const mockProvider = {
name: 'test-plugin/test-tool',
type: 'builtin',
}
vi.mock('@/service/use-tools', () => ({
useAllToolProviders: () => ({ data: [mockProvider] }),
useBuiltinTools: (key: string) => ({
data: key ? mockToolData : undefined,
}),
}))
vi.mock('@/app/components/tools/provider/tool-item', () => ({
default: ({ tool }: { tool: { name: string } }) => (
<div data-testid="tool-item">{tool.name}</div>
),
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
tool: {
identity: {
author: 'test-author',
name: 'test-tool',
description: { en_US: 'Test' },
icon: 'icon.png',
label: { en_US: 'Test Tool' },
tags: [],
},
credentials_schema: [],
},
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
...overrides,
})
describe('ActionList', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render tool items when data is available', () => {
const detail = createPluginDetail()
render(<ActionList detail={detail} />)
expect(screen.getByText('2 actions')).toBeInTheDocument()
expect(screen.getAllByTestId('tool-item')).toHaveLength(2)
})
it('should render tool names', () => {
const detail = createPluginDetail()
render(<ActionList detail={detail} />)
expect(screen.getByText('tool-1')).toBeInTheDocument()
expect(screen.getByText('tool-2')).toBeInTheDocument()
})
it('should return null when no tool declaration', () => {
const detail = createPluginDetail({
declaration: {} as PluginDetail['declaration'],
})
const { container } = render(<ActionList detail={detail} />)
expect(container).toBeEmptyDOMElement()
})
it('should return null when providerKey is empty', () => {
const detail = createPluginDetail({
declaration: {
tool: {
identity: undefined,
},
} as unknown as PluginDetail['declaration'],
})
const { container } = render(<ActionList detail={detail} />)
expect(container).toBeEmptyDOMElement()
})
})
describe('Props', () => {
it('should use plugin_id in provider key construction', () => {
const detail = createPluginDetail()
render(<ActionList detail={detail} />)
// The provider key is constructed from plugin_id and tool identity name
// When they match the mock, it renders
expect(screen.getByText('2 actions')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,131 @@
import type { PluginDetail, StrategyDetail } from '@/app/components/plugins/types'
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import AgentStrategyList from './agent-strategy-list'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.num !== undefined)
return `${options.num} ${options.strategy || 'strategies'}`
return key
},
}),
}))
const mockStrategies = [
{
identity: {
author: 'author-1',
name: 'strategy-1',
icon: 'icon.png',
label: { en_US: 'Strategy 1' },
provider: 'provider-1',
},
parameters: [],
description: { en_US: 'Strategy 1 desc' },
output_schema: {},
features: [],
},
] as unknown as StrategyDetail[]
let mockStrategyProviderDetail: { declaration: { identity: unknown, strategies: StrategyDetail[] } } | undefined
vi.mock('@/service/use-strategy', () => ({
useStrategyProviderDetail: () => ({
data: mockStrategyProviderDetail,
}),
}))
vi.mock('@/app/components/plugins/plugin-detail-panel/strategy-item', () => ({
default: ({ detail }: { detail: StrategyDetail }) => (
<div data-testid="strategy-item">{detail.identity.name}</div>
),
}))
const createPluginDetail = (): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
agent_strategy: {
identity: {
author: 'test-author',
name: 'test-strategy',
label: { en_US: 'Test Strategy' },
description: { en_US: 'Test' },
icon: 'icon.png',
tags: [],
},
},
} as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
})
describe('AgentStrategyList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockStrategyProviderDetail = {
declaration: {
identity: { author: 'test', name: 'test' },
strategies: mockStrategies,
},
}
})
describe('Rendering', () => {
it('should render strategy items when data is available', () => {
render(<AgentStrategyList detail={createPluginDetail()} />)
expect(screen.getByText('1 strategy')).toBeInTheDocument()
expect(screen.getByTestId('strategy-item')).toBeInTheDocument()
})
it('should return null when no strategy provider detail', () => {
mockStrategyProviderDetail = undefined
const { container } = render(<AgentStrategyList detail={createPluginDetail()} />)
expect(container).toBeEmptyDOMElement()
})
it('should render multiple strategies', () => {
mockStrategyProviderDetail = {
declaration: {
identity: { author: 'test', name: 'test' },
strategies: [
...mockStrategies,
{ ...mockStrategies[0], identity: { ...mockStrategies[0].identity, name: 'strategy-2' } },
],
},
}
render(<AgentStrategyList detail={createPluginDetail()} />)
expect(screen.getByText('2 strategies')).toBeInTheDocument()
expect(screen.getAllByTestId('strategy-item')).toHaveLength(2)
})
})
describe('Props', () => {
it('should pass tenant_id to provider detail', () => {
const detail = createPluginDetail()
detail.tenant_id = 'custom-tenant'
render(<AgentStrategyList detail={detail} />)
expect(screen.getByTestId('strategy-item')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,104 @@
import type { PluginDetail } from '@/app/components/plugins/types'
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import DatasourceActionList from './datasource-action-list'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.num !== undefined)
return `${options.num} ${options.action || 'actions'}`
return key
},
}),
}))
const mockDataSourceList = [
{ plugin_id: 'test-plugin', name: 'Data Source 1' },
]
let mockDataSourceListData: typeof mockDataSourceList | undefined
vi.mock('@/service/use-pipeline', () => ({
useDataSourceList: () => ({ data: mockDataSourceListData }),
}))
vi.mock('@/app/components/workflow/block-selector/utils', () => ({
transformDataSourceToTool: (ds: unknown) => ds,
}))
const createPluginDetail = (): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
datasource: {
identity: {
author: 'test-author',
name: 'test-datasource',
description: { en_US: 'Test' },
icon: 'icon.png',
label: { en_US: 'Test Datasource' },
tags: [],
},
credentials_schema: [],
},
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
})
describe('DatasourceActionList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockDataSourceListData = mockDataSourceList
})
describe('Rendering', () => {
it('should render action count when data and provider exist', () => {
render(<DatasourceActionList detail={createPluginDetail()} />)
// The component always shows "0 action" because data is hardcoded as empty array
expect(screen.getByText('0 action')).toBeInTheDocument()
})
it('should return null when no provider found', () => {
mockDataSourceListData = []
const { container } = render(<DatasourceActionList detail={createPluginDetail()} />)
expect(container).toBeEmptyDOMElement()
})
it('should return null when dataSourceList is undefined', () => {
mockDataSourceListData = undefined
const { container } = render(<DatasourceActionList detail={createPluginDetail()} />)
expect(container).toBeEmptyDOMElement()
})
})
describe('Props', () => {
it('should use plugin_id to find matching datasource', () => {
const detail = createPluginDetail()
detail.plugin_id = 'different-plugin'
mockDataSourceListData = [{ plugin_id: 'different-plugin', name: 'Different DS' }]
render(<DatasourceActionList detail={detail} />)
expect(screen.getByText('0 action')).toBeInTheDocument()
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,386 @@
import type { EndpointListItem, PluginDetail } from '../types'
import { act, fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import EndpointCard from './endpoint-card'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('copy-to-clipboard', () => ({
default: vi.fn(),
}))
const mockHandleChange = vi.fn()
const mockEnableEndpoint = vi.fn()
const mockDisableEndpoint = vi.fn()
const mockDeleteEndpoint = vi.fn()
const mockUpdateEndpoint = vi.fn()
// Flags to control whether operations should fail
const failureFlags = {
enable: false,
disable: false,
delete: false,
update: false,
}
vi.mock('@/service/use-endpoints', () => ({
useEnableEndpoint: ({ onSuccess, onError }: { onSuccess: () => void, onError: () => void }) => ({
mutate: (id: string) => {
mockEnableEndpoint(id)
if (failureFlags.enable)
onError()
else
onSuccess()
},
}),
useDisableEndpoint: ({ onSuccess, onError }: { onSuccess: () => void, onError: () => void }) => ({
mutate: (id: string) => {
mockDisableEndpoint(id)
if (failureFlags.disable)
onError()
else
onSuccess()
},
}),
useDeleteEndpoint: ({ onSuccess, onError }: { onSuccess: () => void, onError: () => void }) => ({
mutate: (id: string) => {
mockDeleteEndpoint(id)
if (failureFlags.delete)
onError()
else
onSuccess()
},
}),
useUpdateEndpoint: ({ onSuccess, onError }: { onSuccess: () => void, onError: () => void }) => ({
mutate: (data: unknown) => {
mockUpdateEndpoint(data)
if (failureFlags.update)
onError()
else
onSuccess()
},
}),
}))
vi.mock('@/app/components/header/indicator', () => ({
default: ({ color }: { color: string }) => <span data-testid="indicator" data-color={color} />,
}))
vi.mock('@/app/components/tools/utils/to-form-schema', () => ({
toolCredentialToFormSchemas: (schemas: unknown[]) => schemas,
addDefaultValue: (value: unknown) => value,
}))
vi.mock('./endpoint-modal', () => ({
default: ({ onCancel, onSaved }: { onCancel: () => void, onSaved: (state: unknown) => void }) => (
<div data-testid="endpoint-modal">
<button data-testid="modal-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="modal-save" onClick={() => onSaved({ name: 'Updated' })}>Save</button>
</div>
),
}))
const mockEndpointData: EndpointListItem = {
id: 'ep-1',
name: 'Test Endpoint',
url: 'https://api.example.com',
enabled: true,
created_at: '2024-01-01',
updated_at: '2024-01-02',
settings: {},
tenant_id: 'tenant-1',
plugin_id: 'plugin-1',
expired_at: '',
hook_id: 'hook-1',
declaration: {
settings: [],
endpoints: [
{ path: '/api/test', method: 'GET' },
{ path: '/api/hidden', method: 'POST', hidden: true },
],
},
}
const mockPluginDetail: PluginDetail = {
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {} as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
}
describe('EndpointCard', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
// Reset failure flags
failureFlags.enable = false
failureFlags.disable = false
failureFlags.delete = false
failureFlags.update = false
// Mock Toast.notify to prevent toast elements from accumulating in DOM
vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
})
afterEach(() => {
vi.useRealTimers()
})
describe('Rendering', () => {
it('should render endpoint name', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
expect(screen.getByText('Test Endpoint')).toBeInTheDocument()
})
it('should render visible endpoints only', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
expect(screen.getByText('GET')).toBeInTheDocument()
expect(screen.getByText('https://api.example.com/api/test')).toBeInTheDocument()
expect(screen.queryByText('POST')).not.toBeInTheDocument()
})
it('should show active status when enabled', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
expect(screen.getByText('detailPanel.serviceOk')).toBeInTheDocument()
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'green')
})
it('should show disabled status when not enabled', () => {
const disabledData = { ...mockEndpointData, enabled: false }
render(<EndpointCard pluginDetail={mockPluginDetail} data={disabledData} handleChange={mockHandleChange} />)
expect(screen.getByText('detailPanel.disabled')).toBeInTheDocument()
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'gray')
})
})
describe('User Interactions', () => {
it('should show disable confirm when switching off', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
expect(screen.getByText('detailPanel.endpointDisableTip')).toBeInTheDocument()
})
it('should call disableEndpoint when confirm disable', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
// Click confirm button in the Confirm dialog
fireEvent.click(screen.getByRole('button', { name: 'operation.confirm' }))
expect(mockDisableEndpoint).toHaveBeenCalledWith('ep-1')
})
it('should show delete confirm when delete clicked', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
// Find delete button by its destructive class
const allButtons = screen.getAllByRole('button')
const deleteButton = allButtons.find(btn => btn.classList.contains('text-text-tertiary'))
expect(deleteButton).toBeDefined()
if (deleteButton)
fireEvent.click(deleteButton)
expect(screen.getByText('detailPanel.endpointDeleteTip')).toBeInTheDocument()
})
it('should call deleteEndpoint when confirm delete', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const allButtons = screen.getAllByRole('button')
const deleteButton = allButtons.find(btn => btn.classList.contains('text-text-tertiary'))
expect(deleteButton).toBeDefined()
if (deleteButton)
fireEvent.click(deleteButton)
fireEvent.click(screen.getByRole('button', { name: 'operation.confirm' }))
expect(mockDeleteEndpoint).toHaveBeenCalledWith('ep-1')
})
it('should show edit modal when edit clicked', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const actionButtons = screen.getAllByRole('button', { name: '' })
const editButton = actionButtons[0]
if (editButton)
fireEvent.click(editButton)
expect(screen.getByTestId('endpoint-modal')).toBeInTheDocument()
})
it('should call updateEndpoint when save in modal', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const actionButtons = screen.getAllByRole('button', { name: '' })
const editButton = actionButtons[0]
if (editButton)
fireEvent.click(editButton)
fireEvent.click(screen.getByTestId('modal-save'))
expect(mockUpdateEndpoint).toHaveBeenCalled()
})
})
describe('Copy Functionality', () => {
it('should reset copy state after timeout', async () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
// Find copy button by its class
const allButtons = screen.getAllByRole('button')
const copyButton = allButtons.find(btn => btn.classList.contains('ml-2'))
expect(copyButton).toBeDefined()
if (copyButton) {
fireEvent.click(copyButton)
act(() => {
vi.advanceTimersByTime(2000)
})
// After timeout, the component should still be rendered correctly
expect(screen.getByText('Test Endpoint')).toBeInTheDocument()
}
})
})
describe('Edge Cases', () => {
it('should handle empty endpoints', () => {
const dataWithNoEndpoints = {
...mockEndpointData,
declaration: { settings: [], endpoints: [] },
}
render(<EndpointCard pluginDetail={mockPluginDetail} data={dataWithNoEndpoints} handleChange={mockHandleChange} />)
expect(screen.getByText('Test Endpoint')).toBeInTheDocument()
})
it('should call handleChange after enable', () => {
const disabledData = { ...mockEndpointData, enabled: false }
render(<EndpointCard pluginDetail={mockPluginDetail} data={disabledData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
expect(mockHandleChange).toHaveBeenCalled()
})
it('should hide disable confirm and revert state when cancel clicked', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
expect(screen.getByText('detailPanel.endpointDisableTip')).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'operation.cancel' }))
// Confirm should be hidden
expect(screen.queryByText('detailPanel.endpointDisableTip')).not.toBeInTheDocument()
})
it('should hide delete confirm when cancel clicked', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const allButtons = screen.getAllByRole('button')
const deleteButton = allButtons.find(btn => btn.classList.contains('text-text-tertiary'))
expect(deleteButton).toBeDefined()
if (deleteButton)
fireEvent.click(deleteButton)
expect(screen.getByText('detailPanel.endpointDeleteTip')).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'operation.cancel' }))
expect(screen.queryByText('detailPanel.endpointDeleteTip')).not.toBeInTheDocument()
})
it('should hide edit modal when cancel clicked', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const actionButtons = screen.getAllByRole('button', { name: '' })
const editButton = actionButtons[0]
if (editButton)
fireEvent.click(editButton)
expect(screen.getByTestId('endpoint-modal')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('modal-cancel'))
expect(screen.queryByTestId('endpoint-modal')).not.toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error toast when enable fails', () => {
failureFlags.enable = true
const disabledData = { ...mockEndpointData, enabled: false }
render(<EndpointCard pluginDetail={mockPluginDetail} data={disabledData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
expect(mockEnableEndpoint).toHaveBeenCalled()
})
it('should show error toast when disable fails', () => {
failureFlags.disable = true
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
fireEvent.click(screen.getByRole('switch'))
fireEvent.click(screen.getByRole('button', { name: 'operation.confirm' }))
expect(mockDisableEndpoint).toHaveBeenCalled()
})
it('should show error toast when delete fails', () => {
failureFlags.delete = true
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const allButtons = screen.getAllByRole('button')
const deleteButton = allButtons.find(btn => btn.classList.contains('text-text-tertiary'))
if (deleteButton)
fireEvent.click(deleteButton)
fireEvent.click(screen.getByRole('button', { name: 'operation.confirm' }))
expect(mockDeleteEndpoint).toHaveBeenCalled()
})
it('should show error toast when update fails', () => {
render(<EndpointCard pluginDetail={mockPluginDetail} data={mockEndpointData} handleChange={mockHandleChange} />)
const actionButtons = screen.getAllByRole('button', { name: '' })
const editButton = actionButtons[0]
expect(editButton).toBeDefined()
if (editButton)
fireEvent.click(editButton)
// Verify modal is open
expect(screen.getByTestId('endpoint-modal')).toBeInTheDocument()
// Set failure flag before save is clicked
failureFlags.update = true
fireEvent.click(screen.getByTestId('modal-save'))
expect(mockUpdateEndpoint).toHaveBeenCalled()
// On error, handleChange is not called
expect(mockHandleChange).not.toHaveBeenCalled()
})
})
})

View File

@@ -0,0 +1,222 @@
import type { PluginDetail } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import EndpointList from './endpoint-list'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/context/i18n', () => ({
useDocLink: () => (path: string) => `https://docs.example.com${path}`,
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
const mockEndpoints = [
{ id: 'ep-1', name: 'Endpoint 1', url: 'https://api.example.com', declaration: { settings: [], endpoints: [] } },
]
let mockEndpointListData: { endpoints: typeof mockEndpoints } | undefined
const mockInvalidateEndpointList = vi.fn()
const mockCreateEndpoint = vi.fn()
vi.mock('@/service/use-endpoints', () => ({
useEndpointList: () => ({ data: mockEndpointListData }),
useInvalidateEndpointList: () => mockInvalidateEndpointList,
useCreateEndpoint: ({ onSuccess }: { onSuccess: () => void }) => ({
mutate: (data: unknown) => {
mockCreateEndpoint(data)
onSuccess()
},
}),
}))
vi.mock('@/app/components/tools/utils/to-form-schema', () => ({
toolCredentialToFormSchemas: (schemas: unknown[]) => schemas,
}))
vi.mock('./endpoint-card', () => ({
default: ({ data }: { data: { name: string } }) => (
<div data-testid="endpoint-card">{data.name}</div>
),
}))
vi.mock('./endpoint-modal', () => ({
default: ({ onCancel, onSaved }: { onCancel: () => void, onSaved: (state: unknown) => void }) => (
<div data-testid="endpoint-modal">
<button data-testid="modal-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="modal-save" onClick={() => onSaved({ name: 'New Endpoint' })}>Save</button>
</div>
),
}))
const createPluginDetail = (): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
endpoint: { settings: [], endpoints: [] },
tool: undefined,
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
})
describe('EndpointList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockEndpointListData = { endpoints: mockEndpoints }
})
describe('Rendering', () => {
it('should render endpoint list', () => {
render(<EndpointList detail={createPluginDetail()} />)
expect(screen.getByText('detailPanel.endpoints')).toBeInTheDocument()
})
it('should render endpoint cards', () => {
render(<EndpointList detail={createPluginDetail()} />)
expect(screen.getByTestId('endpoint-card')).toBeInTheDocument()
expect(screen.getByText('Endpoint 1')).toBeInTheDocument()
})
it('should return null when no data', () => {
mockEndpointListData = undefined
const { container } = render(<EndpointList detail={createPluginDetail()} />)
expect(container).toBeEmptyDOMElement()
})
it('should show empty message when no endpoints', () => {
mockEndpointListData = { endpoints: [] }
render(<EndpointList detail={createPluginDetail()} />)
expect(screen.getByText('detailPanel.endpointsEmpty')).toBeInTheDocument()
})
it('should render add button', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
expect(addButton).toBeDefined()
})
})
describe('User Interactions', () => {
it('should show modal when add button clicked', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (addButton)
fireEvent.click(addButton)
expect(screen.getByTestId('endpoint-modal')).toBeInTheDocument()
})
it('should hide modal when cancel clicked', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (addButton)
fireEvent.click(addButton)
expect(screen.getByTestId('endpoint-modal')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('modal-cancel'))
expect(screen.queryByTestId('endpoint-modal')).not.toBeInTheDocument()
})
it('should call createEndpoint when save clicked', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (addButton)
fireEvent.click(addButton)
fireEvent.click(screen.getByTestId('modal-save'))
expect(mockCreateEndpoint).toHaveBeenCalled()
})
})
describe('Border Style', () => {
it('should render with border style based on tool existence', () => {
const detail = createPluginDetail()
detail.declaration.tool = {} as PluginDetail['declaration']['tool']
render(<EndpointList detail={detail} />)
// Verify the component renders correctly
expect(screen.getByText('detailPanel.endpoints')).toBeInTheDocument()
})
})
describe('Multiple Endpoints', () => {
it('should render multiple endpoint cards', () => {
mockEndpointListData = {
endpoints: [
{ id: 'ep-1', name: 'Endpoint 1', url: 'https://api1.example.com', declaration: { settings: [], endpoints: [] } },
{ id: 'ep-2', name: 'Endpoint 2', url: 'https://api2.example.com', declaration: { settings: [], endpoints: [] } },
],
}
render(<EndpointList detail={createPluginDetail()} />)
expect(screen.getAllByTestId('endpoint-card')).toHaveLength(2)
})
})
describe('Tooltip', () => {
it('should render with tooltip content', () => {
render(<EndpointList detail={createPluginDetail()} />)
// Tooltip is rendered - the add button should be visible
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
expect(addButton).toBeDefined()
})
})
describe('Create Endpoint Flow', () => {
it('should invalidate endpoint list after successful create', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (addButton)
fireEvent.click(addButton)
fireEvent.click(screen.getByTestId('modal-save'))
expect(mockInvalidateEndpointList).toHaveBeenCalledWith('test-plugin')
})
it('should pass correct params to createEndpoint', () => {
render(<EndpointList detail={createPluginDetail()} />)
const addButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (addButton)
fireEvent.click(addButton)
fireEvent.click(screen.getByTestId('modal-save'))
expect(mockCreateEndpoint).toHaveBeenCalledWith({
pluginUniqueID: 'test-uid',
state: { name: 'New Endpoint' },
})
})
})
})

View File

@@ -0,0 +1,519 @@
import type { FormSchema } from '../../base/form/types'
import type { PluginDetail } from '../types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import EndpointModal from './endpoint-modal'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, opts?: Record<string, unknown>) => {
if (opts?.field)
return `${key}: ${opts.field}`
return key
},
}),
}))
vi.mock('@/hooks/use-i18n', () => ({
useRenderI18nObject: () => (obj: Record<string, string> | string) =>
typeof obj === 'string' ? obj : obj?.en_US || '',
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-modal/Form', () => ({
default: ({ value, onChange, fieldMoreInfo }: {
value: Record<string, unknown>
onChange: (v: Record<string, unknown>) => void
fieldMoreInfo?: (item: { url?: string }) => React.ReactNode
}) => {
return (
<div data-testid="form">
<input
data-testid="form-input"
value={value.name as string || ''}
onChange={e => onChange({ ...value, name: e.target.value })}
/>
{/* Render fieldMoreInfo to test url link */}
{fieldMoreInfo && (
<div data-testid="field-more-info">
{fieldMoreInfo({ url: 'https://example.com' })}
{fieldMoreInfo({})}
</div>
)}
</div>
)
},
}))
vi.mock('../readme-panel/entrance', () => ({
ReadmeEntrance: () => <div data-testid="readme-entrance" />,
}))
const mockFormSchemas = [
{ name: 'name', label: { en_US: 'Name' }, type: 'text-input', required: true, default: '' },
{ name: 'apiKey', label: { en_US: 'API Key' }, type: 'secret-input', required: false, default: '' },
] as unknown as FormSchema[]
const mockPluginDetail: PluginDetail = {
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {} as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
}
describe('EndpointModal', () => {
const mockOnCancel = vi.fn()
const mockOnSaved = vi.fn()
let mockToastNotify: ReturnType<typeof vi.spyOn>
beforeEach(() => {
vi.clearAllMocks()
mockToastNotify = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
})
describe('Rendering', () => {
it('should render drawer', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should render title and description', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByText('detailPanel.endpointModalTitle')).toBeInTheDocument()
expect(screen.getByText('detailPanel.endpointModalDesc')).toBeInTheDocument()
})
it('should render form with fieldMoreInfo url link', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByTestId('field-more-info')).toBeInTheDocument()
// Should render the "howToGet" link when url exists
expect(screen.getByText('howToGet')).toBeInTheDocument()
})
it('should render readme entrance', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByTestId('readme-entrance')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call onCancel when cancel clicked', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.cancel' }))
expect(mockOnCancel).toHaveBeenCalledTimes(1)
})
it('should call onCancel when close button clicked', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
// Find the close button (ActionButton with RiCloseLine icon)
const allButtons = screen.getAllByRole('button')
const closeButton = allButtons.find(btn => btn.classList.contains('action-btn'))
if (closeButton)
fireEvent.click(closeButton)
expect(mockOnCancel).toHaveBeenCalledTimes(1)
})
it('should update form value when input changes', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
const input = screen.getByTestId('form-input')
fireEvent.change(input, { target: { value: 'Test Name' } })
expect(input).toHaveValue('Test Name')
})
})
describe('Default Values', () => {
it('should use defaultValues when provided', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
defaultValues={{ name: 'Default Name' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByTestId('form-input')).toHaveValue('Default Name')
})
it('should extract default values from schemas when no defaultValues', () => {
const schemasWithDefaults = [
{ name: 'name', label: 'Name', type: 'text-input', required: true, default: 'Schema Default' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithDefaults}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByTestId('form-input')).toHaveValue('Schema Default')
})
it('should handle schemas without default values', () => {
const schemasNoDefault = [
{ name: 'name', label: 'Name', type: 'text-input', required: false },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasNoDefault}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
expect(screen.getByTestId('form')).toBeInTheDocument()
})
})
describe('Validation - handleSave', () => {
it('should show toast error when required field is empty', () => {
const schemasWithRequired = [
{ name: 'name', label: { en_US: 'Name Field' }, type: 'text-input', required: true, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithRequired}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
message: expect.stringContaining('errorMsg.fieldRequired'),
})
expect(mockOnSaved).not.toHaveBeenCalled()
})
it('should show toast error with string label when required field is empty', () => {
const schemasWithStringLabel = [
{ name: 'name', label: 'String Label', type: 'text-input', required: true, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithStringLabel}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
message: expect.stringContaining('String Label'),
})
})
it('should call onSaved when all required fields are filled', () => {
render(
<EndpointModal
formSchemas={mockFormSchemas}
defaultValues={{ name: 'Valid Name' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ name: 'Valid Name' })
})
it('should not validate non-required empty fields', () => {
const schemasOptional = [
{ name: 'optional', label: 'Optional', type: 'text-input', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasOptional}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockToastNotify).not.toHaveBeenCalled()
expect(mockOnSaved).toHaveBeenCalled()
})
})
describe('Boolean Field Processing', () => {
it('should convert string "true" to boolean true', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: 'true' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: true })
})
it('should convert string "1" to boolean true', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: '1' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: true })
})
it('should convert string "True" to boolean true', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: 'True' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: true })
})
it('should convert string "false" to boolean false', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: 'false' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: false })
})
it('should convert number 1 to boolean true', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: 1 }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: true })
})
it('should convert number 0 to boolean false', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: 0 }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: false })
})
it('should preserve boolean true value', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: true }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: true })
})
it('should preserve boolean false value', () => {
const schemasWithBoolean = [
{ name: 'enabled', label: 'Enabled', type: 'boolean', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithBoolean}
defaultValues={{ enabled: false }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ enabled: false })
})
it('should not process non-boolean fields', () => {
const schemasWithText = [
{ name: 'text', label: 'Text', type: 'text-input', required: false, default: '' },
] as unknown as FormSchema[]
render(
<EndpointModal
formSchemas={schemasWithText}
defaultValues={{ text: 'hello' }}
onCancel={mockOnCancel}
onSaved={mockOnSaved}
pluginDetail={mockPluginDetail}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'operation.save' }))
expect(mockOnSaved).toHaveBeenCalledWith({ text: 'hello' })
})
})
describe('Memoization', () => {
it('should be wrapped with React.memo', () => {
expect(EndpointModal).toBeDefined()
expect((EndpointModal as { $$typeof?: symbol }).$$typeof).toBeDefined()
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,103 @@
import type { PluginDetail } from '@/app/components/plugins/types'
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ModelList from './model-list'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.num !== undefined)
return `${options.num} models`
return key
},
}),
}))
const mockModels = [
{ model: 'gpt-4', provider: 'openai' },
{ model: 'gpt-3.5', provider: 'openai' },
]
let mockModelListResponse: { data: typeof mockModels } | undefined
vi.mock('@/service/use-models', () => ({
useModelProviderModelList: () => ({
data: mockModelListResponse,
}),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({
default: ({ modelName }: { modelName: string }) => (
<span data-testid="model-icon">{modelName}</span>
),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({
default: ({ modelItem }: { modelItem: { model: string } }) => (
<span data-testid="model-name">{modelItem.model}</span>
),
}))
const createPluginDetail = (): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
model: { provider: 'openai' },
} as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '1.0.0',
latest_unique_identifier: 'test-uid',
source: 'marketplace' as PluginDetail['source'],
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
})
describe('ModelList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockModelListResponse = { data: mockModels }
})
describe('Rendering', () => {
it('should render model list when data is available', () => {
render(<ModelList detail={createPluginDetail()} />)
expect(screen.getByText('2 models')).toBeInTheDocument()
})
it('should render model icons and names', () => {
render(<ModelList detail={createPluginDetail()} />)
expect(screen.getAllByTestId('model-icon')).toHaveLength(2)
expect(screen.getAllByTestId('model-name')).toHaveLength(2)
// Both icon and name show the model name, so use getAllByText
expect(screen.getAllByText('gpt-4')).toHaveLength(2)
expect(screen.getAllByText('gpt-3.5')).toHaveLength(2)
})
it('should return null when no data', () => {
mockModelListResponse = undefined
const { container } = render(<ModelList detail={createPluginDetail()} />)
expect(container).toBeEmptyDOMElement()
})
it('should handle empty model list', () => {
mockModelListResponse = { data: [] }
render(<ModelList detail={createPluginDetail()} />)
expect(screen.getByText('0 models')).toBeInTheDocument()
expect(screen.queryByTestId('model-icon')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,215 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../types'
import OperationDropdown from './operation-dropdown'
// Mock dependencies
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: <T,>(selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => T): T =>
selector({ systemFeatures: { enable_marketplace: true } }),
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
vi.mock('@/app/components/base/action-button', () => ({
default: ({ children, className, onClick }: { children: React.ReactNode, className?: string, onClick?: () => void }) => (
<button data-testid="action-button" className={className} onClick={onClick}>
{children}
</button>
),
}))
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<div data-testid="portal-elem" data-open={open}>{children}</div>
),
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
<div data-testid="portal-trigger" onClick={onClick}>{children}</div>
),
PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className?: string }) => (
<div data-testid="portal-content" className={className}>{children}</div>
),
}))
describe('OperationDropdown', () => {
const mockOnInfo = vi.fn()
const mockOnCheckVersion = vi.fn()
const mockOnRemove = vi.fn()
const defaultProps = {
source: PluginSource.github,
detailUrl: 'https://github.com/test/repo',
onInfo: mockOnInfo,
onCheckVersion: mockOnCheckVersion,
onRemove: mockOnRemove,
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render trigger button', () => {
render(<OperationDropdown {...defaultProps} />)
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByTestId('action-button')).toBeInTheDocument()
})
it('should render dropdown content', () => {
render(<OperationDropdown {...defaultProps} />)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
})
it('should render info option for github source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
expect(screen.getByText('detailPanel.operation.info')).toBeInTheDocument()
})
it('should render check update option for github source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
expect(screen.getByText('detailPanel.operation.checkUpdate')).toBeInTheDocument()
})
it('should render view detail option for github source with marketplace enabled', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
expect(screen.getByText('detailPanel.operation.viewDetail')).toBeInTheDocument()
})
it('should render view detail option for marketplace source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.marketplace} />)
expect(screen.getByText('detailPanel.operation.viewDetail')).toBeInTheDocument()
})
it('should always render remove option', () => {
render(<OperationDropdown {...defaultProps} />)
expect(screen.getByText('detailPanel.operation.remove')).toBeInTheDocument()
})
it('should not render info option for marketplace source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.marketplace} />)
expect(screen.queryByText('detailPanel.operation.info')).not.toBeInTheDocument()
})
it('should not render check update option for marketplace source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.marketplace} />)
expect(screen.queryByText('detailPanel.operation.checkUpdate')).not.toBeInTheDocument()
})
it('should not render view detail for local source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.local} />)
expect(screen.queryByText('detailPanel.operation.viewDetail')).not.toBeInTheDocument()
})
it('should not render view detail for debugging source', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.debugging} />)
expect(screen.queryByText('detailPanel.operation.viewDetail')).not.toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should toggle dropdown when trigger is clicked', () => {
render(<OperationDropdown {...defaultProps} />)
const trigger = screen.getByTestId('portal-trigger')
fireEvent.click(trigger)
// The portal-elem should reflect the open state
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
it('should call onInfo when info option is clicked', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
fireEvent.click(screen.getByText('detailPanel.operation.info'))
expect(mockOnInfo).toHaveBeenCalledTimes(1)
})
it('should call onCheckVersion when check update option is clicked', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
fireEvent.click(screen.getByText('detailPanel.operation.checkUpdate'))
expect(mockOnCheckVersion).toHaveBeenCalledTimes(1)
})
it('should call onRemove when remove option is clicked', () => {
render(<OperationDropdown {...defaultProps} />)
fireEvent.click(screen.getByText('detailPanel.operation.remove'))
expect(mockOnRemove).toHaveBeenCalledTimes(1)
})
it('should have correct href for view detail link', () => {
render(<OperationDropdown {...defaultProps} source={PluginSource.github} />)
const link = screen.getByText('detailPanel.operation.viewDetail').closest('a')
expect(link).toHaveAttribute('href', 'https://github.com/test/repo')
expect(link).toHaveAttribute('target', '_blank')
})
})
describe('Props Variations', () => {
it('should handle all plugin sources', () => {
const sources = [
PluginSource.github,
PluginSource.marketplace,
PluginSource.local,
PluginSource.debugging,
]
sources.forEach((source) => {
const { unmount } = render(
<OperationDropdown {...defaultProps} source={source} />,
)
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
expect(screen.getByText('detailPanel.operation.remove')).toBeInTheDocument()
unmount()
})
})
it('should handle different detail URLs', () => {
const urls = [
'https://github.com/owner/repo',
'https://marketplace.example.com/plugin/123',
]
urls.forEach((url) => {
const { unmount } = render(
<OperationDropdown {...defaultProps} detailUrl={url} source={PluginSource.github} />,
)
const link = screen.getByText('detailPanel.operation.viewDetail').closest('a')
expect(link).toHaveAttribute('href', url)
unmount()
})
})
})
describe('Memoization', () => {
it('should be wrapped with React.memo', () => {
// Verify the component is exported as a memo component
expect(OperationDropdown).toBeDefined()
// React.memo wraps the component, so it should have $$typeof
expect((OperationDropdown as { $$typeof?: symbol }).$$typeof).toBeDefined()
})
})
})

View File

@@ -0,0 +1,461 @@
import type { SimpleDetail } from './store'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it } from 'vitest'
import { usePluginStore } from './store'
// Factory function to create mock SimpleDetail
const createSimpleDetail = (overrides: Partial<SimpleDetail> = {}): SimpleDetail => ({
plugin_id: 'test-plugin-id',
name: 'Test Plugin',
plugin_unique_identifier: 'test-plugin-uid',
id: 'test-id',
provider: 'test-provider',
declaration: {
category: 'tool' as SimpleDetail['declaration']['category'],
name: 'test-declaration',
},
...overrides,
})
describe('usePluginStore', () => {
beforeEach(() => {
// Reset store state before each test
const { result } = renderHook(() => usePluginStore())
act(() => {
result.current.setDetail(undefined)
})
})
describe('Initial State', () => {
it('should have undefined detail initially', () => {
const { result } = renderHook(() => usePluginStore())
expect(result.current.detail).toBeUndefined()
})
it('should provide setDetail function', () => {
const { result } = renderHook(() => usePluginStore())
expect(typeof result.current.setDetail).toBe('function')
})
})
describe('setDetail', () => {
it('should set detail with valid SimpleDetail', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail()
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail).toEqual(detail)
})
it('should set detail to undefined', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail()
// First set a value
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail).toEqual(detail)
// Then clear it
act(() => {
result.current.setDetail(undefined)
})
expect(result.current.detail).toBeUndefined()
})
it('should update detail when called multiple times', () => {
const { result } = renderHook(() => usePluginStore())
const detail1 = createSimpleDetail({ plugin_id: 'plugin-1' })
const detail2 = createSimpleDetail({ plugin_id: 'plugin-2' })
act(() => {
result.current.setDetail(detail1)
})
expect(result.current.detail?.plugin_id).toBe('plugin-1')
act(() => {
result.current.setDetail(detail2)
})
expect(result.current.detail?.plugin_id).toBe('plugin-2')
})
it('should handle detail with trigger declaration', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
declaration: {
trigger: {
subscription_schema: [],
subscription_constructor: null,
},
},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration.trigger).toEqual({
subscription_schema: [],
subscription_constructor: null,
})
})
it('should handle detail with partial declaration', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
declaration: {
name: 'partial-plugin',
},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration.name).toBe('partial-plugin')
})
})
describe('Store Sharing', () => {
it('should share state across multiple hook instances', () => {
const { result: result1 } = renderHook(() => usePluginStore())
const { result: result2 } = renderHook(() => usePluginStore())
const detail = createSimpleDetail()
act(() => {
result1.current.setDetail(detail)
})
// Both hooks should see the same state
expect(result1.current.detail).toEqual(detail)
expect(result2.current.detail).toEqual(detail)
})
it('should update all hook instances when state changes', () => {
const { result: result1 } = renderHook(() => usePluginStore())
const { result: result2 } = renderHook(() => usePluginStore())
const detail1 = createSimpleDetail({ name: 'Plugin One' })
const detail2 = createSimpleDetail({ name: 'Plugin Two' })
act(() => {
result1.current.setDetail(detail1)
})
expect(result1.current.detail?.name).toBe('Plugin One')
expect(result2.current.detail?.name).toBe('Plugin One')
act(() => {
result2.current.setDetail(detail2)
})
expect(result1.current.detail?.name).toBe('Plugin Two')
expect(result2.current.detail?.name).toBe('Plugin Two')
})
})
describe('Selector Pattern', () => {
// Extract selectors to reduce nesting depth
const selectDetail = (state: ReturnType<typeof usePluginStore.getState>) => state.detail
const selectSetDetail = (state: ReturnType<typeof usePluginStore.getState>) => state.setDetail
it('should support selector to get specific field', () => {
const { result: setterResult } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({ plugin_id: 'selected-plugin' })
act(() => {
setterResult.current.setDetail(detail)
})
// Use selector to get only detail
const { result: selectorResult } = renderHook(() => usePluginStore(selectDetail))
expect(selectorResult.current?.plugin_id).toBe('selected-plugin')
})
it('should support selector to get setDetail function', () => {
const { result } = renderHook(() => usePluginStore(selectSetDetail))
expect(typeof result.current).toBe('function')
})
})
describe('Edge Cases', () => {
it('should handle empty string values in detail', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
plugin_id: '',
name: '',
plugin_unique_identifier: '',
provider: '',
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.plugin_id).toBe('')
expect(result.current.detail?.name).toBe('')
})
it('should handle detail with empty declaration', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
declaration: {},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration).toEqual({})
})
it('should handle rapid state updates', () => {
const { result } = renderHook(() => usePluginStore())
act(() => {
for (let i = 0; i < 10; i++)
result.current.setDetail(createSimpleDetail({ plugin_id: `plugin-${i}` }))
})
expect(result.current.detail?.plugin_id).toBe('plugin-9')
})
it('should handle setDetail called without arguments', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail()
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail).toBeDefined()
act(() => {
result.current.setDetail()
})
expect(result.current.detail).toBeUndefined()
})
})
describe('Type Safety', () => {
it('should preserve all SimpleDetail fields correctly', () => {
const { result } = renderHook(() => usePluginStore())
const detail: SimpleDetail = {
plugin_id: 'type-test-id',
name: 'Type Test Plugin',
plugin_unique_identifier: 'type-test-uid',
id: 'type-id',
provider: 'type-provider',
declaration: {
category: 'model' as SimpleDetail['declaration']['category'],
name: 'type-declaration',
version: '2.0.0',
author: 'test-author',
},
}
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail).toStrictEqual(detail)
expect(result.current.detail?.plugin_id).toBe('type-test-id')
expect(result.current.detail?.name).toBe('Type Test Plugin')
expect(result.current.detail?.plugin_unique_identifier).toBe('type-test-uid')
expect(result.current.detail?.id).toBe('type-id')
expect(result.current.detail?.provider).toBe('type-provider')
})
it('should handle declaration with subscription_constructor', () => {
const { result } = renderHook(() => usePluginStore())
const mockConstructor = {
credentials_schema: [],
oauth_schema: {
client_schema: [],
credentials_schema: [],
},
parameters: [],
}
const detail = createSimpleDetail({
declaration: {
trigger: {
subscription_schema: [],
subscription_constructor: mockConstructor as unknown as NonNullable<SimpleDetail['declaration']['trigger']>['subscription_constructor'],
},
},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration.trigger?.subscription_constructor).toBeDefined()
})
it('should handle declaration with subscription_schema', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
declaration: {
trigger: {
subscription_schema: [],
subscription_constructor: null,
},
},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration.trigger?.subscription_schema).toEqual([])
})
})
describe('State Persistence', () => {
it('should maintain state after multiple renders', () => {
const detail = createSimpleDetail({ name: 'Persistent Plugin' })
const { result, rerender } = renderHook(() => usePluginStore())
act(() => {
result.current.setDetail(detail)
})
// Rerender multiple times
rerender()
rerender()
rerender()
expect(result.current.detail?.name).toBe('Persistent Plugin')
})
it('should maintain reference equality for unchanged state', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail()
act(() => {
result.current.setDetail(detail)
})
const firstDetailRef = result.current.detail
// Get state again without changing
const { result: result2 } = renderHook(() => usePluginStore())
expect(result2.current.detail).toBe(firstDetailRef)
})
})
describe('Concurrent Updates', () => {
it('should handle updates from multiple sources correctly', () => {
const { result: hook1 } = renderHook(() => usePluginStore())
const { result: hook2 } = renderHook(() => usePluginStore())
const { result: hook3 } = renderHook(() => usePluginStore())
act(() => {
hook1.current.setDetail(createSimpleDetail({ name: 'From Hook 1' }))
})
act(() => {
hook2.current.setDetail(createSimpleDetail({ name: 'From Hook 2' }))
})
act(() => {
hook3.current.setDetail(createSimpleDetail({ name: 'From Hook 3' }))
})
// All hooks should reflect the last update
expect(hook1.current.detail?.name).toBe('From Hook 3')
expect(hook2.current.detail?.name).toBe('From Hook 3')
expect(hook3.current.detail?.name).toBe('From Hook 3')
})
it('should handle interleaved read and write operations', () => {
const { result } = renderHook(() => usePluginStore())
act(() => {
result.current.setDetail(createSimpleDetail({ plugin_id: 'step-1' }))
})
expect(result.current.detail?.plugin_id).toBe('step-1')
act(() => {
result.current.setDetail(createSimpleDetail({ plugin_id: 'step-2' }))
})
expect(result.current.detail?.plugin_id).toBe('step-2')
act(() => {
result.current.setDetail(undefined)
})
expect(result.current.detail).toBeUndefined()
act(() => {
result.current.setDetail(createSimpleDetail({ plugin_id: 'step-3' }))
})
expect(result.current.detail?.plugin_id).toBe('step-3')
})
})
describe('Declaration Variations', () => {
it('should handle declaration with all optional fields', () => {
const { result } = renderHook(() => usePluginStore())
const detail = createSimpleDetail({
declaration: {
category: 'extension' as SimpleDetail['declaration']['category'],
name: 'full-declaration',
version: '1.0.0',
author: 'full-author',
icon: 'icon.png',
verified: true,
tags: ['tag1', 'tag2'],
},
})
act(() => {
result.current.setDetail(detail)
})
const decl = result.current.detail?.declaration
expect(decl?.category).toBe('extension')
expect(decl?.name).toBe('full-declaration')
expect(decl?.version).toBe('1.0.0')
expect(decl?.author).toBe('full-author')
expect(decl?.icon).toBe('icon.png')
expect(decl?.verified).toBe(true)
expect(decl?.tags).toEqual(['tag1', 'tag2'])
})
it('should handle declaration with nested tool object', () => {
const { result } = renderHook(() => usePluginStore())
const mockTool = {
identity: {
author: 'tool-author',
name: 'tool-name',
icon: 'tool-icon.png',
tags: ['api', 'utility'],
},
credentials_schema: [],
}
const detail = createSimpleDetail({
declaration: {
tool: mockTool as unknown as SimpleDetail['declaration']['tool'],
},
})
act(() => {
result.current.setDetail(detail)
})
expect(result.current.detail?.declaration.tool?.identity.name).toBe('tool-name')
expect(result.current.detail?.declaration.tool?.identity.tags).toEqual(['api', 'utility'])
})
})
})

View File

@@ -0,0 +1,203 @@
import type { StrategyDetail as StrategyDetailType } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import StrategyDetail from './strategy-detail'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/hooks/use-i18n', () => ({
useRenderI18nObject: () => (obj: Record<string, string>) => obj?.en_US || '',
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
vi.mock('@/app/components/plugins/card/base/card-icon', () => ({
default: () => <span data-testid="card-icon" />,
}))
vi.mock('@/app/components/plugins/card/base/description', () => ({
default: ({ text }: { text: string }) => <div data-testid="description">{text}</div>,
}))
type ProviderType = Parameters<typeof StrategyDetail>[0]['provider']
const mockProvider = {
author: 'test-author',
name: 'test-provider',
description: { en_US: 'Provider desc' },
tenant_id: 'tenant-1',
icon: 'icon.png',
label: { en_US: 'Test Provider' },
tags: [],
} as unknown as ProviderType
const mockDetail = {
identity: {
author: 'author-1',
name: 'strategy-1',
icon: 'icon.png',
label: { en_US: 'Strategy Label' },
provider: 'provider-1',
},
parameters: [
{
name: 'param1',
label: { en_US: 'Parameter 1' },
type: 'text-input',
required: true,
human_description: { en_US: 'A text parameter' },
},
],
description: { en_US: 'Strategy description' },
output_schema: {
properties: {
result: { type: 'string', description: 'Result output' },
items: { type: 'array', items: { type: 'string' }, description: 'Array items' },
},
},
features: [],
} as unknown as StrategyDetailType
describe('StrategyDetail', () => {
const mockOnHide = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render drawer', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should render provider label', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByText('Test Provider')).toBeInTheDocument()
})
it('should render strategy label', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByText('Strategy Label')).toBeInTheDocument()
})
it('should render parameters section', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByText('setBuiltInTools.parameters')).toBeInTheDocument()
expect(screen.getByText('Parameter 1')).toBeInTheDocument()
})
it('should render output schema section', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByText('OUTPUT')).toBeInTheDocument()
expect(screen.getByText('result')).toBeInTheDocument()
expect(screen.getByText('String')).toBeInTheDocument()
})
it('should render BACK button', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
expect(screen.getByText('BACK')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call onHide when close button clicked', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
// Find the close button (ActionButton with action-btn class)
const closeButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (closeButton)
fireEvent.click(closeButton)
expect(mockOnHide).toHaveBeenCalledTimes(1)
})
it('should call onHide when BACK clicked', () => {
render(<StrategyDetail provider={mockProvider} detail={mockDetail} onHide={mockOnHide} />)
fireEvent.click(screen.getByText('BACK'))
expect(mockOnHide).toHaveBeenCalledTimes(1)
})
})
describe('Parameter Types', () => {
it('should display correct type for number-input', () => {
const detailWithNumber = {
...mockDetail,
parameters: [{ ...mockDetail.parameters[0], type: 'number-input' }],
}
render(<StrategyDetail provider={mockProvider} detail={detailWithNumber} onHide={mockOnHide} />)
expect(screen.getByText('setBuiltInTools.number')).toBeInTheDocument()
})
it('should display correct type for checkbox', () => {
const detailWithCheckbox = {
...mockDetail,
parameters: [{ ...mockDetail.parameters[0], type: 'checkbox' }],
}
render(<StrategyDetail provider={mockProvider} detail={detailWithCheckbox} onHide={mockOnHide} />)
expect(screen.getByText('boolean')).toBeInTheDocument()
})
it('should display correct type for file', () => {
const detailWithFile = {
...mockDetail,
parameters: [{ ...mockDetail.parameters[0], type: 'file' }],
}
render(<StrategyDetail provider={mockProvider} detail={detailWithFile} onHide={mockOnHide} />)
expect(screen.getByText('setBuiltInTools.file')).toBeInTheDocument()
})
it('should display correct type for array[tools]', () => {
const detailWithArrayTools = {
...mockDetail,
parameters: [{ ...mockDetail.parameters[0], type: 'array[tools]' }],
}
render(<StrategyDetail provider={mockProvider} detail={detailWithArrayTools} onHide={mockOnHide} />)
expect(screen.getByText('multiple-tool-select')).toBeInTheDocument()
})
it('should display original type for unknown types', () => {
const detailWithUnknown = {
...mockDetail,
parameters: [{ ...mockDetail.parameters[0], type: 'custom-type' }],
}
render(<StrategyDetail provider={mockProvider} detail={detailWithUnknown} onHide={mockOnHide} />)
expect(screen.getByText('custom-type')).toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should handle empty parameters', () => {
const detailEmpty = { ...mockDetail, parameters: [] }
render(<StrategyDetail provider={mockProvider} detail={detailEmpty} onHide={mockOnHide} />)
expect(screen.getByText('setBuiltInTools.parameters')).toBeInTheDocument()
})
it('should handle no output schema', () => {
const detailNoOutput = { ...mockDetail, output_schema: undefined as unknown as Record<string, unknown> }
render(<StrategyDetail provider={mockProvider} detail={detailNoOutput} onHide={mockOnHide} />)
expect(screen.queryByText('OUTPUT')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,102 @@
import type { StrategyDetail } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import StrategyItem from './strategy-item'
vi.mock('@/hooks/use-i18n', () => ({
useRenderI18nObject: () => (obj: Record<string, string>) => obj?.en_US || '',
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
vi.mock('./strategy-detail', () => ({
default: ({ onHide }: { onHide: () => void }) => (
<div data-testid="strategy-detail-panel">
<button data-testid="hide-btn" onClick={onHide}>Hide</button>
</div>
),
}))
const mockProvider = {
author: 'test-author',
name: 'test-provider',
description: { en_US: 'Provider desc' } as Record<string, string>,
tenant_id: 'tenant-1',
icon: 'icon.png',
label: { en_US: 'Test Provider' } as Record<string, string>,
tags: [] as string[],
}
const mockDetail = {
identity: {
author: 'author-1',
name: 'strategy-1',
icon: 'icon.png',
label: { en_US: 'Strategy Label' } as Record<string, string>,
provider: 'provider-1',
},
parameters: [],
description: { en_US: 'Strategy description' } as Record<string, string>,
output_schema: {},
features: [],
} as StrategyDetail
describe('StrategyItem', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render strategy label', () => {
render(<StrategyItem provider={mockProvider} detail={mockDetail} />)
expect(screen.getByText('Strategy Label')).toBeInTheDocument()
})
it('should render strategy description', () => {
render(<StrategyItem provider={mockProvider} detail={mockDetail} />)
expect(screen.getByText('Strategy description')).toBeInTheDocument()
})
it('should not show detail panel initially', () => {
render(<StrategyItem provider={mockProvider} detail={mockDetail} />)
expect(screen.queryByTestId('strategy-detail-panel')).not.toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should show detail panel when clicked', () => {
render(<StrategyItem provider={mockProvider} detail={mockDetail} />)
fireEvent.click(screen.getByText('Strategy Label'))
expect(screen.getByTestId('strategy-detail-panel')).toBeInTheDocument()
})
it('should hide detail panel when hide is called', () => {
render(<StrategyItem provider={mockProvider} detail={mockDetail} />)
fireEvent.click(screen.getByText('Strategy Label'))
expect(screen.getByTestId('strategy-detail-panel')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('hide-btn'))
expect(screen.queryByTestId('strategy-detail-panel')).not.toBeInTheDocument()
})
})
describe('Props', () => {
it('should handle empty description', () => {
const detailWithEmptyDesc = {
...mockDetail,
description: { en_US: '' } as Record<string, string>,
} as StrategyDetail
render(<StrategyItem provider={mockProvider} detail={detailWithEmptyDesc} />)
expect(screen.getByText('Strategy Label')).toBeInTheDocument()
})
})
})

View File

@@ -1874,4 +1874,187 @@ describe('CommonCreateModal', () => {
expect(screen.getByTestId('modal')).toHaveAttribute('data-disabled', 'true')
})
})
describe('normalizeFormType Additional Branches', () => {
it('should handle "text" type by returning textInput', () => {
const detailWithText = createMockPluginDetail({
declaration: {
trigger: {
subscription_constructor: {
credentials_schema: [],
parameters: [
{ name: 'text_type_field', type: 'text' },
],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithText)
const builder = createMockSubscriptionBuilder()
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.OAUTH} builder={builder} />)
expect(screen.getByTestId('form-field-text_type_field')).toBeInTheDocument()
})
it('should handle "secret" type by returning secretInput', () => {
const detailWithSecret = createMockPluginDetail({
declaration: {
trigger: {
subscription_constructor: {
credentials_schema: [],
parameters: [
{ name: 'secret_type_field', type: 'secret' },
],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithSecret)
const builder = createMockSubscriptionBuilder()
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.OAUTH} builder={builder} />)
expect(screen.getByTestId('form-field-secret_type_field')).toBeInTheDocument()
})
})
describe('HandleManualPropertiesChange Provider Fallback', () => {
it('should not call updateBuilder when provider is empty', async () => {
const detailWithEmptyProvider = createMockPluginDetail({
provider: '',
declaration: {
trigger: {
subscription_schema: [
{ name: 'webhook_url', type: 'text', required: true },
],
subscription_constructor: {
credentials_schema: [],
parameters: [],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithEmptyProvider)
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />)
const input = screen.getByTestId('form-field-webhook_url')
fireEvent.change(input, { target: { value: 'https://example.com/webhook' } })
// updateBuilder should not be called when provider is empty
expect(mockUpdateBuilder).not.toHaveBeenCalled()
})
})
describe('Configuration Step Without Endpoint', () => {
it('should handle builder without endpoint', async () => {
const builderWithoutEndpoint = createMockSubscriptionBuilder({
endpoint: '',
})
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} builder={builderWithoutEndpoint} />)
// Component should render without errors
expect(screen.getByTestId('modal')).toBeInTheDocument()
})
})
describe('ApiKeyStep Flow Additional Coverage', () => {
it('should handle verify when no builder created yet', async () => {
const detailWithCredentials = createMockPluginDetail({
declaration: {
trigger: {
subscription_constructor: {
credentials_schema: [
{ name: 'api_key', type: 'secret', required: true },
],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithCredentials)
// Make createBuilder slow
mockCreateBuilder.mockImplementation(() => new Promise(resolve => setTimeout(resolve, 1000)))
render(<CommonCreateModal {...defaultProps} />)
// Click verify before builder is created
fireEvent.click(screen.getByTestId('modal-confirm'))
// Should still attempt to verify
expect(screen.getByTestId('modal')).toBeInTheDocument()
})
})
describe('Auto Parameters Not For APIKEY in Configuration', () => {
it('should include parameters for APIKEY in configuration step', async () => {
const detailWithParams = createMockPluginDetail({
declaration: {
trigger: {
subscription_constructor: {
credentials_schema: [
{ name: 'api_key', type: 'secret', required: true },
],
parameters: [
{ name: 'extra_param', type: 'string', required: true },
],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithParams)
// First verify credentials
mockVerifyCredentials.mockImplementation((params, { onSuccess }) => {
onSuccess()
})
const builder = createMockSubscriptionBuilder()
render(<CommonCreateModal {...defaultProps} builder={builder} />)
// Click verify
fireEvent.click(screen.getByTestId('modal-confirm'))
await waitFor(() => {
expect(mockVerifyCredentials).toHaveBeenCalled()
})
// Now in configuration step, should see extra_param
expect(screen.getByTestId('form-field-extra_param')).toBeInTheDocument()
})
})
describe('needCheckValidatedValues Option', () => {
it('should pass needCheckValidatedValues: false for manual properties', async () => {
const detailWithManualSchema = createMockPluginDetail({
declaration: {
trigger: {
subscription_schema: [
{ name: 'webhook_url', type: 'text', required: true },
],
subscription_constructor: {
credentials_schema: [],
parameters: [],
},
},
},
})
mockUsePluginStore.mockReturnValue(detailWithManualSchema)
render(<CommonCreateModal {...defaultProps} createType={SupportedCreationMethods.MANUAL} />)
await waitFor(() => {
expect(mockCreateBuilder).toHaveBeenCalled()
})
const input = screen.getByTestId('form-field-webhook_url')
fireEvent.change(input, { target: { value: 'test' } })
await waitFor(() => {
expect(mockUpdateBuilder).toHaveBeenCalled()
})
})
})
})

View File

@@ -1475,4 +1475,213 @@ describe('CreateSubscriptionButton', () => {
})
})
})
// ==================== OAuth Callback Edge Cases ====================
describe('OAuth Callback - Falsy Data', () => {
it('should not open modal when OAuth callback returns falsy data', async () => {
// Arrange
const { openOAuthPopup } = await import('@/hooks/use-oauth')
vi.mocked(openOAuthPopup).mockImplementation((url: string, callback: (data?: unknown) => void) => {
callback(undefined) // falsy callback data
return null
})
const mockBuilder: TriggerSubscriptionBuilder = {
id: 'oauth-builder',
name: 'OAuth Builder',
provider: 'test-provider',
credential_type: TriggerCredentialTypeEnum.Oauth2,
credentials: {},
endpoint: 'https://test.com',
parameters: {},
properties: {},
workflows_in_use: 0,
}
mockInitiateOAuth.mockImplementation((_provider: string, callbacks: { onSuccess: (response: { authorization_url: string, subscription_builder: TriggerSubscriptionBuilder }) => void }) => {
callbacks.onSuccess({
authorization_url: 'https://oauth.test.com/authorize',
subscription_builder: mockBuilder,
})
})
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.OAUTH, SupportedCreationMethods.MANUAL],
}),
oauthConfig: createOAuthConfig({ configured: true }),
})
const props = createDefaultProps()
// Act
render(<CreateSubscriptionButton {...props} />)
// Click on OAuth option
const oauthOption = screen.getByTestId(`option-${SupportedCreationMethods.OAUTH}`)
fireEvent.click(oauthOption)
// Assert - modal should NOT open because callback data was falsy
await waitFor(() => {
expect(screen.queryByTestId('common-create-modal')).not.toBeInTheDocument()
})
})
})
// ==================== TriggerProps ClassName Branches ====================
describe('TriggerProps ClassName Branches', () => {
it('should apply pointer-events-none when non-default method with multiple supported methods', () => {
// Arrange - Single APIKEY method (methodType = APIKEY, not DEFAULT_METHOD)
// But we need multiple methods to test this branch
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.APIKEY, SupportedCreationMethods.MANUAL],
}),
})
const props = createDefaultProps()
// Act
render(<CreateSubscriptionButton {...props} />)
// The methodType will be DEFAULT_METHOD since multiple methods
// This verifies the render doesn't crash with multiple methods
expect(screen.getByTestId('custom-select')).toHaveAttribute('data-value', 'default')
})
})
// ==================== Tooltip Disabled Branches ====================
describe('Tooltip Disabled Branches', () => {
it('should enable tooltip when single method and not at max count', () => {
// Arrange
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.MANUAL],
}),
subscriptions: [createSubscription()], // Not at max
})
const props = createDefaultProps({ buttonType: CreateButtonType.ICON_BUTTON })
// Act
render(<CreateSubscriptionButton {...props} />)
// Assert - tooltip should be enabled (disabled prop = false for single method)
expect(screen.getByTestId('custom-trigger')).toBeInTheDocument()
})
it('should disable tooltip when multiple methods and not at max count', () => {
// Arrange
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.MANUAL, SupportedCreationMethods.APIKEY],
}),
subscriptions: [createSubscription()], // Not at max
})
const props = createDefaultProps({ buttonType: CreateButtonType.ICON_BUTTON })
// Act
render(<CreateSubscriptionButton {...props} />)
// Assert - tooltip should be disabled (neither single method nor at max)
expect(screen.getByTestId('custom-trigger')).toBeInTheDocument()
})
})
// ==================== Tooltip PopupContent Branches ====================
describe('Tooltip PopupContent Branches', () => {
it('should show max count message when at max subscriptions', () => {
// Arrange
const maxSubscriptions = createMaxSubscriptions()
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.MANUAL],
}),
subscriptions: maxSubscriptions,
})
const props = createDefaultProps({ buttonType: CreateButtonType.ICON_BUTTON })
// Act
render(<CreateSubscriptionButton {...props} />)
// Assert - component renders with max subscriptions
expect(screen.getByTestId('custom-trigger')).toBeInTheDocument()
})
it('should show method description when not at max', () => {
// Arrange
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.MANUAL],
}),
subscriptions: [], // Not at max
})
const props = createDefaultProps({ buttonType: CreateButtonType.ICON_BUTTON })
// Act
render(<CreateSubscriptionButton {...props} />)
// Assert - component renders without max subscriptions
expect(screen.getByTestId('custom-trigger')).toBeInTheDocument()
})
})
// ==================== Provider Info Fallbacks ====================
describe('Provider Info Fallbacks', () => {
it('should handle undefined supported_creation_methods', () => {
// Arrange - providerInfo with undefined supported_creation_methods
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: {
...createProviderInfo(),
supported_creation_methods: undefined as unknown as SupportedCreationMethods[],
},
})
const props = createDefaultProps()
// Act
const { container } = render(<CreateSubscriptionButton {...props} />)
// Assert - should render null when supported methods fallback to empty
expect(container).toBeEmptyDOMElement()
})
it('should handle providerInfo with null supported_creation_methods', () => {
// Arrange
mockProviderInfo = { data: { ...createProviderInfo(), supported_creation_methods: null as unknown as SupportedCreationMethods[] } }
mockOAuthConfig = { data: undefined, refetch: vi.fn() }
mockStoreDetail = createStoreDetail()
const props = createDefaultProps()
// Act
const { container } = render(<CreateSubscriptionButton {...props} />)
// Assert - should render null
expect(container).toBeEmptyDOMElement()
})
})
// ==================== Method Type Logic ====================
describe('Method Type Logic', () => {
it('should use single method as methodType when only one supported', () => {
// Arrange
setupMocks({
storeDetail: createStoreDetail(),
providerInfo: createProviderInfo({
supported_creation_methods: [SupportedCreationMethods.APIKEY],
}),
})
const props = createDefaultProps()
// Act
render(<CreateSubscriptionButton {...props} />)
// Assert
const customSelect = screen.getByTestId('custom-select')
expect(customSelect).toHaveAttribute('data-value', SupportedCreationMethods.APIKEY)
})
})
})

View File

@@ -1240,4 +1240,60 @@ describe('OAuthClientSettingsModal', () => {
vi.useRealTimers()
})
})
describe('OAuth Client Schema Params Fallback', () => {
it('should handle schema when params is truthy but schema name not in params', () => {
const configWithSchemaNotInParams = createMockOAuthConfig({
system_configured: false,
custom_enabled: true,
params: {
client_id: 'test-id',
client_secret: 'test-secret',
},
oauth_client_schema: [
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown },
{ name: 'client_secret', type: 'secret-input' as unknown, required: true, label: { 'en-US': 'Client Secret' } as unknown },
{ name: 'extra_field', type: 'text-input' as unknown, required: false, label: { 'en-US': 'Extra' } as unknown },
] as TriggerOAuthConfig['oauth_client_schema'],
})
render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithSchemaNotInParams} />)
// extra_field should be rendered but without default value
const extraInput = screen.getByTestId('form-field-extra_field') as HTMLInputElement
expect(extraInput.defaultValue).toBe('')
})
it('should handle oauth_client_schema with undefined params', () => {
const configWithUndefinedParams = createMockOAuthConfig({
system_configured: false,
custom_enabled: true,
params: undefined as unknown as TriggerOAuthConfig['params'],
oauth_client_schema: [
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown },
] as TriggerOAuthConfig['oauth_client_schema'],
})
render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithUndefinedParams} />)
// Form should not render because params is undefined (schema condition fails)
expect(screen.queryByTestId('base-form')).not.toBeInTheDocument()
})
it('should handle oauth_client_schema with null params', () => {
const configWithNullParams = createMockOAuthConfig({
system_configured: false,
custom_enabled: true,
params: null as unknown as TriggerOAuthConfig['params'],
oauth_client_schema: [
{ name: 'client_id', type: 'text-input' as unknown, required: true, label: { 'en-US': 'Client ID' } as unknown },
] as TriggerOAuthConfig['oauth_client_schema'],
})
render(<OAuthClientSettingsModal {...defaultProps} oauthConfig={configWithNullParams} />)
// Form should not render because params is null
expect(screen.queryByTestId('base-form')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,287 @@
import type { TriggerEvent } from '@/app/components/plugins/types'
import type { TriggerProviderApiEntity } from '@/app/components/workflow/block-selector/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { EventDetailDrawer } from './event-detail-drawer'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
useLanguage: () => 'en_US',
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
vi.mock('@/app/components/plugins/card/base/card-icon', () => ({
default: () => <span data-testid="card-icon" />,
}))
vi.mock('@/app/components/plugins/card/base/description', () => ({
default: ({ text }: { text: string }) => <div data-testid="description">{text}</div>,
}))
vi.mock('@/app/components/plugins/card/base/org-info', () => ({
default: ({ orgName }: { orgName: string }) => <div data-testid="org-info">{orgName}</div>,
}))
vi.mock('@/app/components/tools/utils/to-form-schema', () => ({
triggerEventParametersToFormSchemas: (params: Array<Record<string, unknown>>) =>
params.map(p => ({
label: (p.label as Record<string, string>) || { en_US: p.name as string },
type: (p.type as string) || 'text-input',
required: (p.required as boolean) || false,
description: p.description as Record<string, string> | undefined,
})),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show/field', () => ({
default: ({ name }: { name: string }) => <div data-testid="output-field">{name}</div>,
}))
const mockEventInfo = {
name: 'test-event',
identity: {
author: 'test-author',
name: 'test-event',
label: { en_US: 'Test Event' },
},
description: { en_US: 'Test event description' },
parameters: [
{
name: 'param1',
label: { en_US: 'Parameter 1' },
type: 'text-input',
auto_generate: null,
template: null,
scope: null,
required: true,
multiple: false,
default: null,
min: null,
max: null,
precision: null,
description: { en_US: 'A test parameter' },
},
],
output_schema: {
properties: {
result: { type: 'string', description: 'Result' },
},
required: ['result'],
},
} as unknown as TriggerEvent
const mockProviderInfo = {
provider: 'test-provider',
author: 'test-author',
name: 'test-provider/test-name',
icon: 'icon.png',
description: { en_US: 'Provider desc' },
supported_creation_methods: [],
} as unknown as TriggerProviderApiEntity
describe('EventDetailDrawer', () => {
const mockOnClose = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render drawer', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should render event title', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('Test Event')).toBeInTheDocument()
})
it('should render event description', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByTestId('description')).toHaveTextContent('Test event description')
})
it('should render org info', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByTestId('org-info')).toBeInTheDocument()
})
it('should render parameters section', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('setBuiltInTools.parameters')).toBeInTheDocument()
expect(screen.getByText('Parameter 1')).toBeInTheDocument()
})
it('should render output section', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
expect(screen.getByTestId('output-field')).toHaveTextContent('result')
})
it('should render back button', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('detailPanel.operation.back')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call onClose when close button clicked', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
// Find the close button (ActionButton with action-btn class)
const closeButton = screen.getAllByRole('button').find(btn => btn.classList.contains('action-btn'))
if (closeButton)
fireEvent.click(closeButton)
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
it('should call onClose when back clicked', () => {
render(<EventDetailDrawer eventInfo={mockEventInfo} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
fireEvent.click(screen.getByText('detailPanel.operation.back'))
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
})
describe('Edge Cases', () => {
it('should handle no parameters', () => {
const eventWithNoParams = { ...mockEventInfo, parameters: [] }
render(<EventDetailDrawer eventInfo={eventWithNoParams} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.item.noParameters')).toBeInTheDocument()
})
it('should handle no output schema', () => {
const eventWithNoOutput = { ...mockEventInfo, output_schema: {} }
render(<EventDetailDrawer eventInfo={eventWithNoOutput} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
expect(screen.queryByTestId('output-field')).not.toBeInTheDocument()
})
})
describe('Parameter Types', () => {
it('should display correct type for number-input', () => {
const eventWithNumber = {
...mockEventInfo,
parameters: [{ ...mockEventInfo.parameters[0], type: 'number-input' }],
}
render(<EventDetailDrawer eventInfo={eventWithNumber} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('setBuiltInTools.number')).toBeInTheDocument()
})
it('should display correct type for checkbox', () => {
const eventWithCheckbox = {
...mockEventInfo,
parameters: [{ ...mockEventInfo.parameters[0], type: 'checkbox' }],
}
render(<EventDetailDrawer eventInfo={eventWithCheckbox} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('boolean')).toBeInTheDocument()
})
it('should display correct type for file', () => {
const eventWithFile = {
...mockEventInfo,
parameters: [{ ...mockEventInfo.parameters[0], type: 'file' }],
}
render(<EventDetailDrawer eventInfo={eventWithFile} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('setBuiltInTools.file')).toBeInTheDocument()
})
it('should display original type for unknown types', () => {
const eventWithUnknown = {
...mockEventInfo,
parameters: [{ ...mockEventInfo.parameters[0], type: 'custom-type' }],
}
render(<EventDetailDrawer eventInfo={eventWithUnknown} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('custom-type')).toBeInTheDocument()
})
})
describe('Output Schema Conversion', () => {
it('should handle array type in output schema', () => {
const eventWithArrayOutput = {
...mockEventInfo,
output_schema: {
properties: {
items: { type: 'array', items: { type: 'string' }, description: 'Array items' },
},
required: [],
},
}
render(<EventDetailDrawer eventInfo={eventWithArrayOutput} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
})
it('should handle nested properties in output schema', () => {
const eventWithNestedOutput = {
...mockEventInfo,
output_schema: {
properties: {
nested: {
type: 'object',
properties: { inner: { type: 'string' } },
required: ['inner'],
},
},
required: [],
},
}
render(<EventDetailDrawer eventInfo={eventWithNestedOutput} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
})
it('should handle enum in output schema', () => {
const eventWithEnumOutput = {
...mockEventInfo,
output_schema: {
properties: {
status: { type: 'string', enum: ['active', 'inactive'], description: 'Status' },
},
required: [],
},
}
render(<EventDetailDrawer eventInfo={eventWithEnumOutput} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
})
it('should handle array type schema', () => {
const eventWithArrayType = {
...mockEventInfo,
output_schema: {
properties: {
multi: { type: ['string', 'null'], description: 'Multi type' },
},
required: [],
},
}
render(<EventDetailDrawer eventInfo={eventWithArrayType} providerInfo={mockProviderInfo} onClose={mockOnClose} />)
expect(screen.getByText('events.output')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,146 @@
import type { TriggerEvent } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { TriggerEventsList } from './event-list'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.num !== undefined)
return `${options.num} ${options.event || 'events'}`
return key
},
}),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
useLanguage: () => 'en_US',
}))
vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
const mockTriggerEvents = [
{
name: 'event-1',
identity: {
author: 'author-1',
name: 'event-1',
label: { en_US: 'Event One' },
},
description: { en_US: 'Event one description' },
parameters: [],
output_schema: {},
},
] as unknown as TriggerEvent[]
let mockDetail: { plugin_id: string, provider: string } | undefined
let mockProviderInfo: { events: TriggerEvent[] } | undefined
vi.mock('../store', () => ({
usePluginStore: (selector: (state: { detail: typeof mockDetail }) => typeof mockDetail) =>
selector({ detail: mockDetail }),
}))
vi.mock('@/service/use-triggers', () => ({
useTriggerProviderInfo: () => ({ data: mockProviderInfo }),
}))
vi.mock('./event-detail-drawer', () => ({
EventDetailDrawer: ({ onClose }: { onClose: () => void }) => (
<div data-testid="event-detail-drawer">
<button data-testid="close-drawer" onClick={onClose}>Close</button>
</div>
),
}))
describe('TriggerEventsList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockDetail = { plugin_id: 'test-plugin', provider: 'test-provider' }
mockProviderInfo = { events: mockTriggerEvents }
})
describe('Rendering', () => {
it('should render event count', () => {
render(<TriggerEventsList />)
expect(screen.getByText('1 events.event')).toBeInTheDocument()
})
it('should render event cards', () => {
render(<TriggerEventsList />)
expect(screen.getByText('Event One')).toBeInTheDocument()
expect(screen.getByText('Event one description')).toBeInTheDocument()
})
it('should return null when no provider info', () => {
mockProviderInfo = undefined
const { container } = render(<TriggerEventsList />)
expect(container).toBeEmptyDOMElement()
})
it('should return null when no events', () => {
mockProviderInfo = { events: [] }
const { container } = render(<TriggerEventsList />)
expect(container).toBeEmptyDOMElement()
})
it('should return null when no detail', () => {
mockDetail = undefined
mockProviderInfo = undefined
const { container } = render(<TriggerEventsList />)
expect(container).toBeEmptyDOMElement()
})
})
describe('User Interactions', () => {
it('should show detail drawer when event card clicked', () => {
render(<TriggerEventsList />)
fireEvent.click(screen.getByText('Event One'))
expect(screen.getByTestId('event-detail-drawer')).toBeInTheDocument()
})
it('should hide detail drawer when close clicked', () => {
render(<TriggerEventsList />)
fireEvent.click(screen.getByText('Event One'))
expect(screen.getByTestId('event-detail-drawer')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('close-drawer'))
expect(screen.queryByTestId('event-detail-drawer')).not.toBeInTheDocument()
})
})
describe('Multiple Events', () => {
it('should render multiple event cards', () => {
const secondEvent = {
name: 'event-2',
identity: {
author: 'author-2',
name: 'event-2',
label: { en_US: 'Event Two' },
},
description: { en_US: 'Event two description' },
parameters: [],
output_schema: {},
} as unknown as TriggerEvent
mockProviderInfo = {
events: [...mockTriggerEvents, secondEvent],
}
render(<TriggerEventsList />)
expect(screen.getByText('Event One')).toBeInTheDocument()
expect(screen.getByText('Event Two')).toBeInTheDocument()
expect(screen.getByText('2 events.events')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,72 @@
import { describe, expect, it } from 'vitest'
import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { NAME_FIELD } from './utils'
describe('utils', () => {
describe('NAME_FIELD', () => {
it('should have correct type', () => {
expect(NAME_FIELD.type).toBe(FormTypeEnum.textInput)
})
it('should have correct name', () => {
expect(NAME_FIELD.name).toBe('name')
})
it('should have label translations', () => {
expect(NAME_FIELD.label).toBeDefined()
expect(NAME_FIELD.label.en_US).toBe('Endpoint Name')
expect(NAME_FIELD.label.zh_Hans).toBe('端点名称')
expect(NAME_FIELD.label.ja_JP).toBe('エンドポイント名')
expect(NAME_FIELD.label.pt_BR).toBe('Nome do ponto final')
})
it('should have placeholder translations', () => {
expect(NAME_FIELD.placeholder).toBeDefined()
expect(NAME_FIELD.placeholder.en_US).toBe('Endpoint Name')
expect(NAME_FIELD.placeholder.zh_Hans).toBe('端点名称')
expect(NAME_FIELD.placeholder.ja_JP).toBe('エンドポイント名')
expect(NAME_FIELD.placeholder.pt_BR).toBe('Nome do ponto final')
})
it('should be required', () => {
expect(NAME_FIELD.required).toBe(true)
})
it('should have empty default value', () => {
expect(NAME_FIELD.default).toBe('')
})
it('should have null help', () => {
expect(NAME_FIELD.help).toBeNull()
})
it('should have all required field properties', () => {
const requiredKeys = ['type', 'name', 'label', 'placeholder', 'required', 'default', 'help']
requiredKeys.forEach((key) => {
expect(NAME_FIELD).toHaveProperty(key)
})
})
it('should match expected structure', () => {
expect(NAME_FIELD).toEqual({
type: FormTypeEnum.textInput,
name: 'name',
label: {
en_US: 'Endpoint Name',
zh_Hans: '端点名称',
ja_JP: 'エンドポイント名',
pt_BR: 'Nome do ponto final',
},
placeholder: {
en_US: 'Endpoint Name',
zh_Hans: '端点名称',
ja_JP: 'エンドポイント名',
pt_BR: 'Nome do ponto final',
},
required: true,
default: '',
help: null,
})
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,9 @@ import difyI18n from './eslint-rules/index.js'
export default antfu(
{
react: {
reactCompiler: true,
// This react compiler rules are pretty slow
// We can wait for https://github.com/Rel1cx/eslint-react/issues/1237
reactCompiler: false,
overrides: {
'react/no-context-provider': 'off',
'react/no-forward-ref': 'off',
@@ -57,47 +59,8 @@ export default antfu(
// sonar
{
rules: {
...sonar.configs.recommended.rules,
// code complexity
'sonarjs/cognitive-complexity': 'off',
'sonarjs/no-nested-functions': 'warn',
'sonarjs/no-nested-conditional': 'warn',
'sonarjs/nested-control-flow': 'warn', // 3 levels of nesting
'sonarjs/no-small-switch': 'off',
'sonarjs/no-nested-template-literals': 'warn',
'sonarjs/redundant-type-aliases': 'off',
'sonarjs/regex-complexity': 'warn',
// maintainability
'sonarjs/no-ignored-exceptions': 'off',
'sonarjs/no-commented-code': 'warn',
'sonarjs/no-unused-vars': 'warn',
'sonarjs/prefer-single-boolean-return': 'warn',
'sonarjs/duplicates-in-character-class': 'off',
'sonarjs/single-char-in-character-classes': 'off',
'sonarjs/anchor-precedence': 'warn',
'sonarjs/updated-loop-counter': 'off',
'sonarjs/no-dead-store': 'error',
'sonarjs/no-duplicated-branches': 'warn',
'sonarjs/max-lines': 'warn', // max 1000 lines
'sonarjs/no-variable-usage-before-declaration': 'error',
// security
'sonarjs/no-hardcoded-passwords': 'off', // detect the wrong code that is not password.
'sonarjs/no-hardcoded-secrets': 'off',
'sonarjs/pseudo-random': 'off',
// performance
'sonarjs/slow-regex': 'warn',
// others
'sonarjs/todo-tag': 'warn',
'sonarjs/table-header': 'off',
// new from this update
'sonarjs/unused-import': 'off',
'sonarjs/use-type-alias': 'warn',
'sonarjs/single-character-alternation': 'warn',
'sonarjs/no-os-command-from-path': 'warn',
'sonarjs/class-name': 'off',
'sonarjs/no-redundant-jump': 'warn',
// Manually pick rules that are actually useful and not slow.
// Or we can just drop the plugin entirely.
},
plugins: {
sonarjs: sonar,

View File

@@ -91,6 +91,7 @@
"apiBasedExtension.title": "API extensions provide centralized API management, simplifying configuration for easy use across Dify's applications.",
"apiBasedExtension.type": "Type",
"appMenus.apiAccess": "API Access",
"appMenus.apiAccessTip": "This knowledge base is accessible via the Service API",
"appMenus.logAndAnn": "Logs & Annotations",
"appMenus.logs": "Logs",
"appMenus.overview": "Monitoring",
@@ -281,7 +282,7 @@
"model.params.setToCurrentModelMaxTokenTip": "Max token is updated to the 80% maximum token of the current model {{maxToken}}.",
"model.params.stop_sequences": "Stop sequences",
"model.params.stop_sequencesPlaceholder": "Enter sequence and press Tab",
"model.params.stop_sequencesTip": "Up to four sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.",
"model.params.stop_sequencesTip": "Up to four sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.",
"model.params.temperature": "Temperature",
"model.params.temperatureTip": "Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.",
"model.params.top_p": "Top P",

View File

@@ -170,7 +170,7 @@
"serviceApi.card.endpoint": "Service API Endpoint",
"serviceApi.card.title": "Backend service api",
"serviceApi.disabled": "Disabled",
"serviceApi.enabled": "In Service",
"serviceApi.enabled": "Enabled",
"serviceApi.title": "Service API",
"unavailable": "Unavailable",
"updated": "Updated",

View File

@@ -91,6 +91,7 @@
"apiBasedExtension.title": "API 扩展提供了一个集中式的 API 管理,在此统一添加 API 配置后,方便在 Dify 上的各类应用中直接使用。",
"apiBasedExtension.type": "类型",
"appMenus.apiAccess": "访问 API",
"appMenus.apiAccessTip": "此知识库可通过服务 API 访问",
"appMenus.logAndAnn": "日志与标注",
"appMenus.logs": "日志",
"appMenus.overview": "监测",

View File

@@ -170,7 +170,7 @@
"serviceApi.card.endpoint": "API 端点",
"serviceApi.card.title": "后端服务 API",
"serviceApi.disabled": "已停用",
"serviceApi.enabled": "运行中",
"serviceApi.enabled": "已启用",
"serviceApi.title": "服务 API",
"unavailable": "不可用",
"updated": "更新于",

View File

@@ -0,0 +1,17 @@
import type { ICurrentWorkspace } from '@/models/common'
import { useQuery } from '@tanstack/react-query'
import { get } from './base'
type WorkspacePermissions = {
workspace_id: ICurrentWorkspace['id']
allow_member_invite: boolean
allow_owner_transfer: boolean
}
export function useWorkspacePermissions(workspaceId: ICurrentWorkspace['id'], enabled: boolean) {
return useQuery({
queryKey: ['workspace-permissions', workspaceId],
queryFn: () => get<WorkspacePermissions>('/workspaces/current/permission'),
enabled: enabled && !!workspaceId,
})
}