diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5..82d9c21cbb 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -20,11 +20,11 @@ ```typescript // ❌ WRONG: Don't mock base components vi.mock('@/app/components/base/loading', () => () =>
Loading
) -vi.mock('@/app/components/base/button', () => ({ children }: any) => ) +vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => ) // ✅ CORRECT: Import and use real base components import Loading from '@/app/components/base/loading' -import Button from '@/app/components/base/button' +import { Button } from '@/app/components/base/ui/button' // They will render normally in tests ``` diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 29f5b090f8..c32fc9d0cb 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -77,8 +77,6 @@ jobs: with: files: | web/** - e2e/** - sdks/nodejs-client/** packages/** package.json pnpm-lock.yaml @@ -97,14 +95,14 @@ jobs: id: eslint-cache-restore uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: - path: .eslintcache - key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: . + working-directory: ./web run: vp run lint:ci - name: Web tsslint @@ -114,7 +112,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: . + working-directory: ./web run: vp run type-check - name: Web dead code check @@ -126,7 +124,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: - path: .eslintcache + path: web/.eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: diff --git a/.gitignore b/.gitignore index 3493a7c756..53dea88899 100644 --- a/.gitignore +++ b/.gitignore @@ -203,7 +203,6 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template -!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code History Extension @@ -243,5 +242,3 @@ scripts/stress-test/reports/ # Code Agent Folder .qoder/* - -.eslintcache diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index d48381bce2..cced022568 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,9 +56,44 @@ if $api_modified; then fi fi -if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 -fi +if $web_modified; then + if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 + fi -vp staged + echo "Running ESLint on web module" + + if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then + web_ts_modified=false + else + ts_diff_status=$? + if [ $ts_diff_status -eq 1 ]; then + web_ts_modified=true + else + echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." + exit $ts_diff_status + fi + fi + + cd ./web || exit 1 + pnpm exec vp staged + + if $web_ts_modified; then + echo "Running TypeScript type-check:tsgo" + if ! npm run type-check:tsgo; then + echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." + exit 1 + fi + else + echo "No staged TypeScript changes detected, skipping type-check:tsgo" + fi + + echo "Running knip" + if ! npm run knip; then + echo "Knip check failed. Please run 'npm run knip' to fix the errors." + exit 1 + fi + + cd ../ +fi diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 0e9508d18a..23351beed9 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -120,6 +120,9 @@ from .explore import ( saved_message, trial, ) + +# Import snippet controllers +from .snippets import snippet_workflow, snippet_workflow_draft_variable from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import snippet controllers @@ -211,6 +214,9 @@ __all__ = [ "saved_message", "setup", "site", + "snippet_workflow", + "snippet_workflow_draft_variable", + "snippets", "socketio_workflow", "snippet_workflow", "snippet_workflow_draft_variable", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 27c8008e37..75a18a477a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,6 +5,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.enums import WorkflowExecutionStatus from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -29,7 +30,6 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db from fields.base import ResponseModel -from graphon.enums import WorkflowExecutionStatus from libs.helper import build_icon_url from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 91fbe4a85a..78ddb904e1 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource, fields +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,7 +23,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fe274e4c9a..d83925d173 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,6 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,7 +27,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index c720a5e074..7101d5df7b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -19,7 +20,6 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index d517f695b8..5b1abc98dc 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -18,6 +18,12 @@ from models.enums import AppMCPServerStatus from models.model import AppMCPServer +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict[str, Any] = Field(..., description="Server parameters configuration") @@ -30,25 +36,19 @@ class MCPServerUpdatePayload(BaseModel): status: str | None = Field(default=None, description="Server status") -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class AppMCPServerResponse(ResponseModel): id: str name: str server_code: str description: str - status: AppMCPServerStatus + status: str parameters: dict[str, Any] | list[Any] | str created_at: int | None = None updated_at: int | None = None @field_validator("parameters", mode="before") @classmethod - def _normalize_parameters(cls, value: Any) -> Any: + def _parse_json_string(cls, value: Any) -> Any: if isinstance(value, str): try: return json.loads(value) @@ -70,9 +70,7 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response( - 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] - ) + @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__]) @login_required @account_initialization_required @setup_required @@ -87,9 +85,7 @@ class AppMCPServerController(Resource): @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response( - 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] - ) + @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @@ -115,15 +111,13 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response( - 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] - ) + @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @@ -160,7 +154,7 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index e1e0b1eef0..16cad35f1c 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -5,12 +5,17 @@ from typing import Any, Literal from flask import abort, request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.enums import NodeType +from graphon.file import File +from graphon.file import helpers as file_helpers +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services -from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload +from controllers.common.controller_schemas import DefaultBlockConfigQuery from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.workflow_run import workflow_run_node_execution_model @@ -37,11 +42,6 @@ from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from graphon.enums import NodeType -from graphon.file import File -from graphon.file import helpers as file_helpers -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -171,7 +171,6 @@ class WorkflowTypeConvertQuery(BaseModel): target_type: Literal["workflow", "evaluation"] - class WorkflowFeaturesPayload(BaseModel): features: dict[str, Any] = Field(..., description="Workflow feature configuration") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 4b39590235..6b402898e8 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -4,6 +4,7 @@ from typing import Any from dateutil.parser import isoparse from flask import request from flask_restx import Resource +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker @@ -15,7 +16,6 @@ from extensions.ext_database import db from fields.base import ResponseModel from fields.end_user_fields import SimpleEndUser from fields.member_fields import SimpleAccount -from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 6748d95d6b..a1a075be71 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,6 +3,8 @@ from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -26,8 +28,6 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 727428c8e7..b55cda4244 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -5,11 +5,11 @@ from typing import Concatenate from flask import jsonify, request from flask.typing import ResponseReturnValue from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 0b493d2c71..14ca27acbd 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -4,6 +4,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -53,7 +54,6 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3372a967d9..98d4ad9412 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -3,19 +3,20 @@ import logging from argparse import ArgumentTypeError from collections.abc import Sequence from contextlib import ExitStack -from datetime import datetime from typing import Any, Literal, cast import sqlalchemy as sa from flask import request, send_file -from flask_restx import Resource, marshal -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from pydantic import BaseModel, Field from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -30,14 +31,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db -from fields.base import ResponseModel +from fields.dataset_fields import dataset_fields from fields.document_fields import ( + dataset_and_document_fields, document_fields, + document_metadata_fields, document_status_fields, document_with_segments_fields, ) -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile @@ -71,100 +72,27 @@ from ..wraps import ( logger = logging.getLogger(__name__) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value +# Register models for flask_restx to avoid dict type issues in Swagger +dataset_model = get_or_create_model("Dataset", dataset_fields) +document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) -def _normalize_enum(value: Any) -> Any: - if isinstance(value, str) or value is None: - return value - return getattr(value, "value", value) +document_fields_copy = document_fields.copy() +document_fields_copy["doc_metadata"] = fields.List( + fields.Nested(document_metadata_model), attribute="doc_metadata_details" +) +document_model = get_or_create_model("Document", document_fields_copy) +document_with_segments_fields_copy = document_with_segments_fields.copy() +document_with_segments_fields_copy["doc_metadata"] = fields.List( + fields.Nested(document_metadata_model), attribute="doc_metadata_details" +) +document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) -class DatasetResponse(ResponseModel): - id: str - name: str - description: str | None = None - permission: str | None = None - data_source_type: str | None = None - indexing_technique: str | None = None - created_by: str | None = None - created_at: int | None = None - - @field_validator("data_source_type", "indexing_technique", mode="before") - @classmethod - def _normalize_enum_fields(cls, value: Any) -> Any: - return _normalize_enum(value) - - @field_validator("created_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) - - -class DocumentMetadataResponse(ResponseModel): - id: str - name: str - type: str - value: str | None = None - - -class DocumentResponse(ResponseModel): - id: str - position: int | None = None - data_source_type: str | None = None - data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict") - data_source_detail_dict: Any = None - dataset_process_rule_id: str | None = None - name: str - created_from: str | None = None - created_by: str | None = None - created_at: int | None = None - tokens: int | None = None - indexing_status: str | None = None - error: str | None = None - enabled: bool | None = None - disabled_at: int | None = None - disabled_by: str | None = None - archived: bool | None = None - display_status: str | None = None - word_count: int | None = None - hit_count: int | None = None - doc_form: str | None = None - doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details") - summary_index_status: str | None = None - need_summary: bool | None = None - - @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before") - @classmethod - def _normalize_enum_fields(cls, value: Any) -> Any: - return _normalize_enum(value) - - @field_validator("doc_metadata", mode="before") - @classmethod - def _normalize_doc_metadata(cls, value: Any) -> list[Any]: - if value is None: - return [] - return value - - @field_validator("created_at", "disabled_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) - - -class DocumentWithSegmentsResponse(DocumentResponse): - process_rule_dict: Any = None - completed_segments: int | None = None - total_segments: int | None = None - - -class DatasetAndDocumentResponse(ResponseModel): - dataset: DatasetResponse - documents: list[DocumentResponse] - batch: str +dataset_and_document_fields_copy = dataset_and_document_fields.copy() +dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) +dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) +dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) class DocumentRetryPayload(BaseModel): @@ -179,11 +107,6 @@ class GenerateSummaryPayload(BaseModel): document_list: list[str] -class DocumentMetadataUpdatePayload(BaseModel): - doc_type: str | None = None - doc_metadata: Any = None - - class DocumentDatasetListParam(BaseModel): page: int = Field(1, title="Page", description="Page number.") limit: int = Field(20, title="Limit", description="Page size.") @@ -201,13 +124,7 @@ register_schema_models( DocumentRetryPayload, DocumentRenamePayload, GenerateSummaryPayload, - DocumentMetadataUpdatePayload, DocumentBatchDownloadZipPayload, - DatasetResponse, - DocumentMetadataResponse, - DocumentResponse, - DocumentWithSegmentsResponse, - DatasetAndDocumentResponse, ) @@ -440,10 +357,10 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required + @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -481,9 +398,7 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return DatasetAndDocumentResponse.model_validate( - {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True - ).model_dump(mode="json") + return {"dataset": dataset, "documents": documents, "batch": batch} @setup_required @login_required @@ -511,13 +426,12 @@ class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response( - 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__] - ) + @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required + @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self): @@ -565,9 +479,9 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return DatasetAndDocumentResponse.model_validate( - {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True - ).model_dump(mode="json") + response = {"dataset": dataset, "documents": documents, "batch": batch} + + return response @console_ns.route("/datasets//documents//indexing-estimate") @@ -1074,7 +988,15 @@ class DocumentMetadataApi(DocumentResource): @console_ns.doc("update_document_metadata") @console_ns.doc(description="Update document metadata") @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__]) + @console_ns.expect( + console_ns.model( + "UpdateDocumentMetadataRequest", + { + "doc_type": fields.String(description="Document type"), + "doc_metadata": fields.Raw(description="Document metadata"), + }, + ) + ) @console_ns.response(200, "Document metadata updated successfully") @console_ns.response(404, "Document not found") @console_ns.response(403, "Permission denied") @@ -1087,10 +1009,10 @@ class DocumentMetadataApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) + req_data = request.get_json() - doc_type = req_data.doc_type - doc_metadata = req_data.doc_metadata + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1272,7 +1194,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) + @marshal_with(document_model) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -1290,7 +1212,7 @@ class DocumentRenameApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json") + return document @console_ns.route("/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2647bb1f5a..354c299bef 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,6 +2,7 @@ import uuid from flask import request from flask_restx import Resource, marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -31,7 +32,6 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 699fa599c8..8fb3699849 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,6 +2,7 @@ import logging from typing import Any from flask_restx import marshal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -20,7 +21,6 @@ from core.errors.error import ( QuotaExceededError, ) from fields.hit_testing_fields import hit_testing_record_fields -from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index fd0a8b33bc..bdf83b991e 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,6 +2,8 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index b31d73f27d..3549f9542d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -4,6 +4,7 @@ from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -27,7 +28,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index ee146e8287..a8077d9eb0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -40,7 +41,6 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ab660d9dc3..a37077af42 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,6 +1,7 @@ import logging from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import services @@ -19,7 +20,6 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index ccdccceaa6..eacd7332fe 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,6 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 209667d1d0..64d55d7ca3 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -24,7 +25,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.enums import FeedbackRating diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1456301a24..0a3595454a 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,6 +3,8 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -59,8 +61,6 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 438cce4fd8..da88de6776 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,5 +1,7 @@ import logging +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -21,8 +23,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 2a46d2250a..551c86fd82 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,6 +2,7 @@ import urllib.parse import httpx from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -15,7 +16,6 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 764f488755..3fdcbc4710 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index f45b72f390..b6b9deb1f9 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a6f37aec8..e4cfca9fa4 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 4b10561fdb..cbb9677309 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b2d07ff8f9..f8f95304f0 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea..aa674a63b3 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -14,7 +15,6 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 471594f349..c9956501e2 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -27,7 +28,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index d11b66244f..7a28a09861 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,6 +3,7 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden @@ -15,7 +16,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 72cab3de73..83c8fa02fe 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -29,7 +30,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 2f309262cb..a5846e2815 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,13 +20,10 @@ class TenantUserPayload(BaseModel): def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ - Get current user. + Get current user NOTE: user_id is not trusted, it could be maliciously set to any value. - As a result, it could only be considered as an end user id. Even when a - concrete end-user ID is supplied, lookups must stay tenant-scoped so one - tenant cannot bind another tenant's user record into the plugin request - context. + As a result, it could only be considered as an end user id. """ if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -45,14 +42,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.scalar( - select(EndUser) - .where( - EndUser.id == user_id, - EndUser.tenant_id == tenant_id, - ) - .limit(1) - ) + user_model = session.get(EndUser, user_id) if not user_model: user_model = EndUser( diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index f652bbc581..8066f198bb 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,6 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -11,7 +12,6 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity, VariableEntityType from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index e818573b8f..907dd1b06d 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import services @@ -21,7 +22,6 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 31f2797d66..3142e5118e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c4353ca7b8..50851aea08 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,6 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, TypeAdapter, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound @@ -21,7 +22,6 @@ from fields.conversation_fields import ( ConversationInfiniteScrollPagination, SimpleConversation, ) -from graphon.variables.types import SegmentType from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 76519cad0a..fd954be6b1 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -18,7 +19,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 5992fa7410..971b63577c 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,6 +2,7 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -22,7 +23,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 5ac65fc4e6..c0a6cb0a76 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 3ad595f1f4..0ef4471018 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import field_validator from werkzeug.exceptions import InternalServerError @@ -21,7 +22,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0528184d79..e37f9af5f0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 07ecf8035b..39afdd843f 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.enums import FeedbackRating from models.model import AppMode diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index fe31e9d4ac..38aeccc642 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -13,7 +14,6 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 98211193a0..796e090976 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,5 +1,7 @@ import logging +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -22,8 +24,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 790602ef5d..06c746990d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,6 +4,20 @@ import uuid from decimal import Decimal from typing import Union, cast +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity @@ -29,20 +43,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0bc93ad34d..f07ac64498 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,6 +4,15 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any, TypedDict +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) + from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index a2186be100..2b2e26987e 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotChatAgentRunner(CotAgentRunner): def _organize_system_prompt(self) -> SystemPromptMessage: diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 51a30998ae..d4c52a8eb1 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -9,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotCompletionAgentRunner(CotAgentRunner): def _organize_instruction_prompt(self) -> str: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index d38d24d1e7..fdffde85d0 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,13 +4,6 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -26,6 +19,14 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index f341ca5a0b..8cccd2be6d 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,9 +3,10 @@ import re from collections.abc import Generator from typing import Any, Union -from core.agent.entities import AgentScratchpadUnit from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from core.agent.entities import AgentScratchpadUnit + class CotAgentOutputParser: @classmethod diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index dbd7527fc6..b7dd55632e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,13 +1,14 @@ from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 02498c23e1..9d980e5ca3 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.app.app_config.entities import ModelConfigEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 4c07445df3..57c6d1c496 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,7 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessageRole + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -7,7 +9,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 53563dc5da..819aca864c 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,14 +1,14 @@ from enum import StrEnum, auto from typing import Any, Literal -from pydantic import BaseModel, Field - -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.entities import MetadataFilteringCondition from graphon.file import FileUploadConfig from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity +from pydantic import BaseModel, Field + +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.entities import MetadataFilteringCondition from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 8f20ef2ff9..959c3868b4 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any -from constants import DEFAULT_FILE_NUMBER_LIMITS from graphon.file import FileUploadConfig +from constants import DEFAULT_FILE_NUMBER_LIMITS + class FileUploadConfigManager: @classmethod diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 13ace32fd6..62e0c31d1a 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,8 @@ import re -from core.app.app_config.entities import RagPipelineVariableEntity from graphon.variables.input_entities import VariableEntity + +from core.app.app_config.entities import RagPipelineVariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 9e64b471cb..985ded0f74 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -18,6 +18,11 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -43,10 +48,6 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 4e57b4dedc..7b4cb98bd4 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -37,12 +43,6 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 5cdc477028..5872f6b264 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 09ddce327e..a20d3f3c38 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,9 @@ import logging from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -16,9 +19,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index d5edfaeb25..406d07927e 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union +from graphon.model_runtime.errors.invoke import InvokeError + from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index d1771452c5..20bf81aeec 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,6 +7,7 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod +from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -21,7 +22,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from extensions.ext_redis import redis_client -from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 1251b397e2..4aebc0cb30 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -30,16 +41,6 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 58afefe296..891dcece73 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 077c5239f3..050f763e95 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -16,8 +18,6 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 2a90fbdad0..ab277857fe 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,9 +4,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.runtime import GraphRuntimeState +from core.workflow.system_variables import SystemVariableKey, get_system_text + if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index bd685d5189..a515531616 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -6,6 +6,19 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -55,19 +68,6 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import FILE_MODEL_IDENTITY, File -from graphon.runtime import GraphRuntimeState -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.variables import Variable -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 423bfdac51..61339b316a 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 6bb1ecdcb1..b216f7cf7b 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -14,8 +16,6 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 4b2f17189b..83c74b86e5 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, cast, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -41,8 +43,6 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 2ee0ae27eb..36daaf09e9 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,6 +2,12 @@ import logging import time from typing import cast +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -20,12 +26,6 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 1acb1acaf3..3421a13133 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, overload from flask import Flask, current_app +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -34,10 +38,6 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index cfb9208486..2cb8088971 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Sequence from typing import cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -15,11 +21,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 15645add57..96387133b1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,6 +4,9 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -58,9 +61,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 047b54c86c..437432611d 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,6 +3,39 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -49,39 +82,6 @@ from core.workflow.system_variables import ( from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph import Graph -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 09992f4bbf..a3fb7b4c5d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 221b7fb058..482f995d8e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any -from pydantic import BaseModel, ConfigDict, Field - -from core.app.entities.agent_strategy import AgentStrategyInfo -from core.rag.entities import RetrievalSourceMetadata from graphon.entities import WorkflowStartReason from graphon.entities.pause_reason import PauseReason from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from pydantic import BaseModel, ConfigDict, Field + +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.rag.entities import RetrievalSourceMetadata class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 6e4ca69cf0..88faf235d1 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any -from pydantic import BaseModel, ConfigDict, Field - -from core.app.entities.agent_strategy import AgentStrategyInfo -from core.rag.entities import RetrievalSourceMetadata from graphon.entities import WorkflowStartReason from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage from graphon.nodes.human_input.entities import FormInput, UserAction +from pydantic import BaseModel, ConfigDict, Field + +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.rag.entities import RetrievalSourceMetadata class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d59f5125e3..d2d2fea4fb 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,8 +1,9 @@ import logging +from graphon.model_runtime.entities.message_entities import PromptMessage + from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 9811f9f830..c027f42788 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index bb9fc1b6fa..8c8daf8712 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore - from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent + from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index b60fe82ffe..77c7bec67e 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index c49c4eb0ac..278d0cb30b 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,15 +2,16 @@ from __future__ import annotations from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider + from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index b6039e1e4e..0bb10190c4 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,3 +1,4 @@ +from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import sessionmaker @@ -7,7 +8,6 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db -from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 9e688589db..10b9c36d3e 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 4a7918032e..c577ce0754 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,16 +7,17 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final, override -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance + if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 87f005a250..ada065a943 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,13 +14,6 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -45,6 +38,14 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 9e3c187210..3d8a7a54f3 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,6 +6,9 @@ import re import threading from collections.abc import Iterable +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -15,8 +18,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index dc831e5cac..a5297fa33a 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,6 +3,9 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -28,9 +31,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 352e6bfd49..9c22d5e67c 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, TypedDict +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject, I18nObjectDict -from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 6a3f9e684a..c012e128f4 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,10 +2,11 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type +from graphon.file import File, FileTransferMethod, FileType + from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 04ae193396..d304c982cd 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field -from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index bfa4f56915..a440829b46 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e99a131500..84d95c38c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from enum import StrEnum, auto -from pydantic import BaseModel, ConfigDict - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity +from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 1ab66cceee..d07f6f913a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -8,6 +8,16 @@ from collections.abc import Iterator, Sequence from json import JSONDecodeError from typing import Any +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -24,16 +34,6 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 72b29c2277..95431c0e01 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Any, Union +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -12,7 +13,6 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 951e065b2c..35bfcfb6a5 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Any import httpx +from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index dc37a36943..a1e782a094 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,13 +2,14 @@ import logging import secrets from typing import cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 8bcb899b23..f8f56e12d2 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,12 +1,12 @@ from typing import Any from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index d2e375626f..a8ad7c9179 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,11 +5,6 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -26,6 +21,11 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance class ResponseFormat(StrEnum): diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 884610ca82..72171d1536 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,11 +3,12 @@ import logging from collections.abc import Mapping from typing import Any, NotRequired, TypedDict, cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7b5a7635f1..7e35044176 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d840ee213c..5809d6f74a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,14 +1,5 @@ from collections.abc import Sequence -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -19,6 +10,15 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index d8d8dfedd8..36beb55d7f 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,15 +2,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload -from configs import dify_config -from core.entities import PluginCredentialType -from core.entities.embedding_type import EmbeddingInputType -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.errors.error import ProviderTokenNotInitError -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from core.provider_manager import ProviderManager -from extensions.ext_redis import redis_client from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -24,6 +15,16 @@ from graphon.model_runtime.model_providers.__base.rerank_model import RerankMode from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel + +from configs import dify_config +from core.entities import PluginCredentialType +from core.entities.embedding_type import EmbeddingInputType +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 76e81242f4..70aaf2a07b 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -58,8 +60,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 2e02a186cc..aa35ac74c2 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any, TypedDict +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -15,8 +17,6 @@ from core.ops.aliyun_trace.entities.semconv import ( ) from core.rag.models.document import Document from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 78516e1a22..dd5edde630 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse +from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -40,7 +41,6 @@ from core.ops.entities.trace_entity import ( from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index d960038f15..490c64af84 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker @@ -29,7 +30,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 87fcaeabcc..c070a937be 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow +from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -25,7 +26,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 672efe45bd..e0c7b9bfe5 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import Any, cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker @@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 36878dc58f..f79095d966 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -6,6 +6,8 @@ import json import logging from datetime import datetime +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -41,8 +43,6 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index d681b9da80..84f54d8a5a 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -25,10 +29,6 @@ from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index f79544f1c7..8d9ba4694d 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -6,6 +6,7 @@ from typing import Any, cast import wandb import weave +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -32,7 +33,6 @@ from core.ops.entities.trace_entity import ( from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index c92438960a..a4b24ff849 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -3,6 +3,20 @@ from binascii import hexlify, unhexlify from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -19,19 +33,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9550e49992..9478997494 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,4 +1,3 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( @@ -9,6 +8,8 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) + +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 89e0e8881c..4d28032a57 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 257638ad77..e0ddb746c7 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,8 @@ from datetime import datetime from enum import StrEnum from typing import Any +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 1474883204..4a85952dcd 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,10 +4,6 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -25,6 +21,10 @@ from graphon.nodes.parameter_extractor.entities import ( from graphon.nodes.question_classifier.entities import ( ClassConfig, ) +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 9ee8469892..7f36560b49 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,6 +5,14 @@ from collections.abc import Callable, Generator from typing import Any, cast import httpx +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -29,14 +37,6 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 47608bdfa6..703af63f7c 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,6 +2,13 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -13,12 +20,6 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 12d8e282b2..90350f8400 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,8 +1,9 @@ from typing import Any -from core.tools.entities.tool_entities import ToolSelector from graphon.file import File +from core.tools.entities.tool_entities import ToolSelector + def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: for parameter_name, parameter in parameters.items(): diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 24e05ef865..19b5e9223a 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,13 +1,6 @@ from collections.abc import Mapping, Sequence from typing import cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -20,6 +13,14 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser + class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a..9be70199b7 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,10 +1,5 @@ from typing import cast -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.prompt_transform import PromptTransform from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, @@ -12,6 +7,12 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_transform import PromptTransform + class AgentHistoryPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 6ff2f44cdc..4539ae9f11 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 1665bdeb52..dc8391a6a5 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,12 +4,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, TypedDict, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -19,6 +13,13 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index ba76eb0c4e..dbda749925 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c3bbe8fc09..39ef31632e 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -6,6 +6,14 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -33,14 +41,6 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 7e71d67ec0..f978e072f3 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired, TypedDict from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only @@ -23,7 +24,6 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -195,23 +195,6 @@ class RetrievalService: ) return all_documents - @classmethod - def _filter_documents_by_vector_score_threshold( - cls, documents: list[Document], score_threshold: float | None - ) -> list[Document]: - """Keep documents whose stored retrieval score meets the threshold. - - Used when hybrid search skips early vector thresholding but no rerank - runner applies a threshold afterward (same rule as ``calculate_vector_score``). - """ - if score_threshold is None: - return documents - return [ - document - for document in documents - if document.metadata and document.metadata.get("score", 0) >= score_threshold - ] - @classmethod def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: """Deduplicate documents in O(n) while preserving first-seen order. @@ -311,20 +294,13 @@ class RetrievalService: vector = Vector(dataset=dataset) documents = [] - # Hybrid search merges keyword / full-text / vector hits and then reranks - # (weighted fusion or reranking model). Applying the user score threshold at - # vector retrieval time uses embedding similarity, which is not comparable to - # reranked or fused scores and incorrectly drops high-quality chunks (#35233). - embedding_score_threshold = ( - 0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold - ) if query_type == QueryType.TEXT_QUERY: documents.extend( vector.search_by_vector( query, search_type="similarity_score_threshold", top_k=top_k, - score_threshold=embedding_score_threshold, + score_threshold=score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -336,7 +312,7 @@ class RetrievalService: vector.search_by_file( file_id=query, top_k=top_k, - score_threshold=embedding_score_threshold, + score_threshold=score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -868,10 +844,6 @@ class RetrievalService: top_n=top_k, query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, ) - if not data_post_processor.rerank_runner and score_threshold: - all_documents_item = self._filter_documents_by_vector_score_threshold( - all_documents_item, score_threshold - ) all_documents.extend(all_documents_item) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 59d7f3c3c4..dddd5fc994 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -18,7 +19,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index f4699f6869..8e9ebdd17a 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 4926f44f16..9f1c73ec88 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,6 +4,8 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -13,8 +15,6 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index f8242efe31..a487c49053 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -7,6 +7,16 @@ from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -33,16 +43,6 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4ebf095904..087736d0b0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from pydantic import BaseModel, Field - from graphon.file import File +from pydantic import BaseModel, Field class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index bce08f998f..a8d37845a5 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,8 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -7,8 +10,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d0732b269a..49123e13d0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 1453fe020b..8ebc840b99 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select, update from sqlalchemy.orm import sessionmaker @@ -64,11 +69,6 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660e..dce7b6226c 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2581c354dd..3383c7f3bd 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -7,9 +7,10 @@ import re from collections.abc import Collection from typing import Any, Literal +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index e87d1cd6b2..b07c63fdf0 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -7,11 +7,11 @@ providing improved performance by offloading database operations to background w import logging +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 2451563317..cdb3af01a8 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -8,6 +8,7 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -15,7 +16,6 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) -from graphon.entities import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 4e83e70799..ce3ad15759 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 6be3902317..d74cc8f231 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -5,13 +5,13 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index b036687bc9..13e885672a 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any import psycopg2.errors +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -19,10 +23,6 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 95660ab93b..e539074303 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,14 +2,15 @@ import io from collections.abc import Generator from typing import Any +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from graphon.file import FileType -from graphon.file.file_manager import download -from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index ac3820f1ab..f49c669fe0 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,12 +2,13 @@ import io from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index d41503e1e6..14af63a962 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,12 @@ from __future__ import annotations +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 168e5f4493..0a2c37c563 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,6 +6,7 @@ from typing import Any, Union from urllib.parse import urlencode import httpx +from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 42a88c0003..410ec72baf 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 00fc8a8282..f6d09472b3 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,6 +6,8 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 3caacb8706..d060fa8b49 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast +from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -32,7 +33,6 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b3424cd9a5..d8674b3af9 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,6 +9,7 @@ from mimetypes import guess_extension, guess_type from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from sqlalchemy import select from configs import dify_config @@ -16,7 +17,6 @@ from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f4588904d3..be13d40f3e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,6 +8,7 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session @@ -28,13 +29,14 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db -from graphon.runtime import VariablePool from models.provider_ids import ToolProviderID from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -60,7 +62,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index b6890b2611..03e3c5918d 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -14,7 +15,6 @@ from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 79d0c114d4..81c85bc90d 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -9,11 +9,11 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39..8d6f83dc07 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,9 +8,6 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -23,6 +20,10 @@ from graphon.model_runtime.errors.invoke import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 45718cadb6..2159eb8638 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,12 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError + class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 5905fd919e..a01004448a 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 52ab605963..7c4f8ee03a 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,6 +5,8 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -20,8 +22,6 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 24c1271488..61d1cd8540 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -27,7 +28,6 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 51452c29a3..c52aad150b 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e4f6b3b470..d9247b2593 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,12 +1,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, @@ -18,6 +12,13 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment + from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 28966f2392..cad32f8d5b 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,10 +1,9 @@ from typing import Any, Literal, Union -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 260881e49c..04a10f9257 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,13 +1,13 @@ from typing import Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class RetrievalSetting(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d5cab05dbe..bb72fe3881 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 3825f526a2..460ec693ce 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,11 @@ from typing import Literal -from pydantic import BaseModel, Field - -from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from pydantic import BaseModel, Field + +from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig __all__ = ["Condition"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 47ad14b499..13624b27b3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,11 +8,6 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( @@ -32,6 +27,12 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference + from .entities import ( Condition, KnowledgeRetrievalNodeData, diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index 23ed2cd408..bf5be2379a 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index c848a86255..e50de11bb9 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,12 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID + from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 683c8d420f..04f1f7e6bb 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b46cc76a6e..a9753ab387 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,12 @@ from collections.abc import Mapping -from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID + from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index b261039448..a30f877e4b 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from pydantic import BaseModel, Field, field_validator - -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from graphon.entities.base_node_data import BaseNodeData from graphon.enums import NodeType from graphon.variables.types import SegmentType +from pydantic import BaseModel, Field, field_validator + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE _WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,)) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 13c4f05bfd..d942a718cc 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,10 +2,6 @@ import logging from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult @@ -14,6 +10,11 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type + from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 0c535a1c5b..b7e7a6e60f 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound -from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created +from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,25 +22,24 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - with session_factory.create_session() as session: - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, - ) + document = db.session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, ) + ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - session.add(document) - session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + db.session.add(document) + db.session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 5e2a456dce..84be592b1a 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ -from core.db.session_factory import session_factory from events.app_event import app_was_created +from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - with session_factory.create_session() as session: - session.add(site) - session.commit() + + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index ba9758175f..7bd8e88231 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,11 +1,12 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 6769b94cde..86b5b2bbf0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db -from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 69d1f1ab07..5cc58f27c4 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,12 +5,11 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from graphon.model_runtime.errors.invoke import InvokeRateLimitError - try: from langfuse._utils import parse_error diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 64ff0f0674..db599c5d49 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 7f77a0437a..2745141431 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string -from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 544109276d..d0f3e2e244 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -3,14 +3,14 @@ import logging import os import time +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index fbf379b3e5..23d324f9ea 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes -from graphon.enums import BuiltinNodeTypes -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index ec3c78a12d..335c5cc29e 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 56672d1fd4..6df5f62c15 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index 75ddbba448..b9fdd9e1ca 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. """ -from opentelemetry.trace import Span - -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps -from extensions.otel.semconv.gen_ai import ToolAttributes from graphon.enums import WorkflowNodeExecutionMetadataKey from graphon.graph_events import GraphNodeEventBase from graphon.nodes.base.node import Node from graphon.nodes.tool.entities import ToolNodeData +from opentelemetry.trace import Span + +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import ToolAttributes class ToolNodeOTelParser: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index fd7acb14d3..57205b5739 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,11 +8,6 @@ shared conversion functions for legacy callers and tests. from collections.abc import Mapping, Sequence from typing import Any, cast -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -36,6 +31,12 @@ from graphon.variables.variables import ( VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) + __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 67b320beaa..cfe0015918 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,10 +3,10 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields +from graphon.file import helpers as file_helpers from pydantic import computed_field, field_validator from fields.base import ResponseModel -from graphon.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index ca18f1c203..1a871204a0 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,12 +3,12 @@ from __future__ import annotations from datetime import datetime from uuid import uuid4 +from graphon.file import File from pydantic import Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.base import ResponseModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile -from graphon.file import File type JSONValueType = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index ee6f53b360..4c65cdab7a 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,5 +1,4 @@ from flask_restx import fields - from graphon.file import File diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 94549d2152..e103abee43 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/libs/helper.py b/api/libs/helper.py index ac69a11084..69bd483515 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,6 +16,8 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, TypeAdapter from pydantic.functional_validators import AfterValidator from typing_extensions import TypedDict @@ -23,8 +25,6 @@ from typing_extensions import TypedDict from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account diff --git a/api/models/comment.py b/api/models/comment.py index 308339e6f6..7018c7e1f2 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -7,7 +7,7 @@ from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship from .account import Account -from .base import Base +from .base import Base, gen_uuidv7_string from .engine import db from .types import StringUUID @@ -41,7 +41,7 @@ class WorkflowComment(Base): Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position_x: Mapped[float] = mapped_column(db.Float) @@ -148,7 +148,7 @@ class WorkflowCommentReply(Base): Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) comment_id: Mapped[str] = mapped_column( StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) @@ -193,7 +193,7 @@ class WorkflowCommentMention(Base): Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) comment_id: Mapped[str] = mapped_column( StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) diff --git a/api/models/dataset.py b/api/models/dataset.py index eee5c39a0e..50301dd2d7 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase): ) -class DocumentSegmentSummary(TypeBase): +class DocumentSegmentSummary(Base): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1725,40 +1725,25 @@ class DocumentSegmentSummary(TypeBase): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - insert_default=lambda: str(uuid4()), - default_factory=lambda: str(uuid4()), - init=False, - ) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) - summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) - summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) - status: Mapped[SummaryStatus] = mapped_column( - EnumText(SummaryStatus, length=32), - nullable=False, - server_default=sa.text("'generating'"), - default=SummaryStatus.GENERATING, - ) - error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) - disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - created_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), init=False + summary_content: Mapped[str] = mapped_column(LongText, nullable=True) + summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) + summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + status: Mapped[str] = mapped_column( + EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") ) + error: Mapped[str] = mapped_column(LongText, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, - nullable=False, - server_default=func.current_timestamp(), - onupdate=func.current_timestamp(), - init=False, + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) def __repr__(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index b4c7a634b6..79c5d62f6a 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship from core.workflow.human_input_compat import DeliveryMethodType -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 7fe0731098..8eabf45363 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,6 +14,9 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker @@ -21,9 +24,6 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping diff --git a/api/models/provider.py b/api/models/provider.py index 2bb67d605b..8270961b31 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,10 +6,10 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column -from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase diff --git a/api/models/workflow.py b/api/models/workflow.py index adedecfc9f..db2a647dc8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -31,19 +44,6 @@ from core.workflow.variable_prefixes import ( ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.file.constants import maybe_file_object -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -53,10 +53,11 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase + from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 72b38e7906..100589804c 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol, TypedDict +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 44735eb769..d5c6a203b1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 474b200fc5..b760696c5e 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,15 +28,15 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 67f8795d3f..feba5f7eb6 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,6 +7,9 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -18,9 +21,6 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 78806927bc..74b800606d 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -10,6 +10,12 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel @@ -29,12 +35,6 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 6e9d6b1c73..0842e9d3e7 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,10 +5,11 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ +from graphon.graph_engine.manager import GraphEngineManager + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 60948e652b..1c7027efb4 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/billing_service.py b/api/services/billing_service.py index eeaddfee2f..068c93d94f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -53,6 +53,8 @@ class QuotaReleaseResult(TypedDict): _quota_reserve_adapter = TypeAdapter(QuotaReserveResult) _quota_commit_adapter = TypeAdapter(QuotaCommitResult) _quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + class _BillingQuota(TypedDict): size: int limit: int diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index dcc93b4b0f..ea12e40420 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker @@ -13,7 +14,6 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ee8a1c4edd..f5085af59b 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable, Sequence from typing import Any +from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 287d513f48..95a8951951 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ +from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 416bc8cef9..364c4a86a0 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -17,7 +18,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 6679c08ebd..a944ef6acd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,15 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -15,15 +24,6 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from models.provider import ProviderType diff --git a/api/services/file_service.py b/api/services/file_service.py index 52da2a7951..79a935de4b 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile +from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -23,7 +24,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType -from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 68ef67dec1..77576fa4c0 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 76598d31ac..02a6620fc7 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index c269346f5f..b652e049ce 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,6 +2,12 @@ import json import logging from typing import Any, TypedDict +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -12,12 +18,6 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 968600d1bc..605689226a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -44,15 +53,6 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index f315d053cb..7dd86f1581 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -13,6 +13,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel from sqlalchemy import select @@ -27,12 +33,6 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 21be411bea..ab60986bfe 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,13 +27,13 @@ from dataclasses import dataclass, field from typing import Any, TypedDict import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 5ff2c21749..3bfa221528 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,9 +2,9 @@ import json import logging from typing import Any, TypedDict, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -16,13 +16,11 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ) -from core.tools.errors import ApiToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -118,85 +116,71 @@ class ApiToolManageService: privacy_policy: str, custom_disclaimer: str, labels: list[str], - ) -> dict[str, Any]: + ): """ - Create a new API tool provider. - - :param user_id: The ID of the user creating the provider. - :param tenant_id: The ID of the workspace/tenant. - :param provider_name: The name of the API tool provider. - :param icon: The icon configuration for the provider. - :param credentials: The credentials for the provider. - :param schema_type: The type of schema (e.g., OpenAPI). - :param schema: The raw schema string. - :param privacy_policy: The privacy policy URL or text. - :param custom_disclaimer: Custom disclaimer text. - :param labels: A list of labels for the provider. - :return: A dictionary indicating the result status. + create api tool provider """ - provider_name = provider_name.strip() # check if the provider exists - # Create new session with automatic transaction management - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - provider: ApiToolProvider | None = _session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ) - .limit(1) + provider = db.session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, ) + .limit(1) + ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create API tool provider - api_tool_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create db provider + db_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - _session.add(api_tool_provider) + db.session.add(db_provider) + db.session.commit() - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels, _session) + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) return {"result": "success"} @@ -228,25 +212,16 @@ class ApiToolManageService: @staticmethod def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ - List tools provided by a specific API tool provider. - - :param user_id: The ID of the user requesting the list. - :param tenant_id: The ID of the workspace/tenant. - :param provider_name: The name of the API tool provider. - :return: A list of ToolApiEntity objects. + list api tool provider tools """ - - # create new session with automatic transaction management - provider: ApiToolProvider | None = None - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - provider = _session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ) - .limit(1) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, ) + .limit(1) + ) if provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -276,133 +251,103 @@ class ApiToolManageService: privacy_policy: str | None, custom_disclaimer: str, labels: list[str], - ) -> dict[str, Any]: + ): """ - Update an existing API tool provider. - - :param user_id: The ID of the user updating the provider. - :param tenant_id: The ID of the workspace/tenant. - :param provider_name: The new name of the API tool provider. - :param original_provider: The original name of the API tool provider. - :param icon: The icon configuration for the provider. - :param credentials: The credentials for the provider. - :param _schema_type: The type of schema (e.g., OpenAPI). - :param schema: The raw schema string. - :param privacy_policy: The privacy policy URL or text. - :param custom_disclaimer: Custom disclaimer text. - :param labels: A list of labels for the provider. - :return: A dictionary indicating the result status. + update api tool provider """ - provider_name = provider_name.strip() # check if the provider exists - # create new session with automatic transaction management - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - provider: ApiToolProvider | None = _session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ) - .limit(1) + provider = db.session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, ) + .limit(1) + ) - if provider is None: - raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) + if provider is None: + raise ValueError(f"api provider {provider_name} does not exists") + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) + credentials = dict(encrypter.encrypt(credentials)) + provider.credentials_str = json.dumps(credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) - - _session.add(provider) - - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels, _session) + db.session.add(provider) + db.session.commit() # delete cache cache.delete() + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) + return {"result": "success"} @staticmethod def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - Delete an API tool provider. - - :param user_id: The ID of the user performing the deletion operation. - :param tenant_id: The ID of the workspace/tenant where the provider belongs. - :param provider_name: The unique name of the API tool provider to be deleted. - :raises ValueError: If the specified provider does not exist in the tenant. - :return: A dictionary indicating the result status. + delete tool provider """ - - # create new session with automatic transaction management - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - provider: ApiToolProvider | None = _session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ) - .limit(1) + provider = db.session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, ) + .limit(1) + ) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") - _session.delete(provider) + db.session.delete(provider) + db.session.commit() return {"result": "success"} @staticmethod - def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]: + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - Get API tool provider details. - - :param user_id: The ID of the user requesting the provider. - :param tenant_id: The ID of the workspace/tenant. - :param provider: The name of the API tool provider. - :return: A dictionary containing the provider details. + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) @@ -415,20 +360,10 @@ class ApiToolManageService: parameters: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, - ) -> dict[str, Any]: + ): """ - Test an API tool before adding the API tool provider. - - :param tenant_id: The ID of the workspace/tenant. - :param provider_name: The name of the API tool provider. - :param tool_name: The name of the specific tool to test. - :param credentials: The credentials for the provider. - :param parameters: The parameters to pass to the tool. - :param schema_type: The type of schema (e.g., OpenAPI). - :param schema: The raw schema string. - :return: A dictionary containing the result or error message. + test api tool before adding api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -442,21 +377,18 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - # create new session with automatic transaction management to get the provider - provider: ApiToolProvider | None = None - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - provider = _session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ) - .limit(1) + db_provider = db.session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, ) + .limit(1) + ) - if provider is None: + if not db_provider: # create a fake db provider - provider = ApiToolProvider( + db_provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -475,12 +407,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if provider.id: + if db_provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -511,21 +443,14 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - List all API tools for a specific tenant. - - :param tenant_id: The ID of the workspace/tenant. - :return: A list of ToolProviderApiEntity objects. + list api tools """ # get all api providers - # create new session with automatic transaction management - providers: list[ApiToolProvider] = [] - with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: - providers = list( - _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() - ) + db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() result: list[ToolProviderApiEntity] = [] - for provider in providers: + + for provider in db_providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 911331e357..5a5d13b96d 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -20,7 +21,6 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index c96050ce13..4d58a9cf12 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,7 +5,6 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, overload -from configs import dify_config from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( @@ -22,6 +21,8 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments +from configs import dify_config + _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 58193d75a9..9827c8dfbc 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,6 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager @@ -12,7 +13,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1658bdf99f..8e92449c6e 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,6 +1,11 @@ import json from typing import Any, TypedDict +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from sqlalchemy import select from core.app.app_config.entities import ( @@ -19,11 +24,6 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 8afb565955..fae5dea3cb 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -7,6 +7,19 @@ from datetime import datetime from enum import StrEnum from typing import Any, ClassVar, NotRequired, TypedDict +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -27,19 +40,6 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable -from graphon.enums import NodeType -from graphon.file import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 5fca444723..601e9261fc 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index a24d121c76..0caa066485 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,32 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from sqlalchemy import and_, exists, or_, select +from graphon.entities import WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from sqlalchemy import and_, exists, or_, select from sqlalchemy.orm import Session, sessionmaker @@ -39,31 +65,6 @@ from events.app_event import app_draft_workflow_was_synced, app_published_workfl from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings -from graphon.entities import WorkflowNodeExecution -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission -from graphon.nodes.human_input.enums import HumanInputFormKind -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import load_into_variable_pool -from graphon.variables import VariableBase -from graphon.variables.input_entities import VariableEntityType -from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from libs.helper import escape_like_pattern from models import Account diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index c22e7e9918..8f2f5f261e 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any from celery import shared_task from flask import current_app, json +from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index beb23d8354..4db551c73c 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -8,6 +8,7 @@ from typing import Any import click import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.db.session_factory import session_factory @@ -15,7 +16,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index fd743205a1..ca73b4d374 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,6 +2,8 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -9,8 +11,6 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e..a316eec7b9 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,6 +6,7 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -14,7 +15,6 @@ from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d9df2733fd..b9f382eccf 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,8 +29,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType -from services.quota_service import unlimited -from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, @@ -43,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -299,10 +299,10 @@ def dispatch_triggered_workflow( icon_dark_filename=trigger_entity.identity.icon_dark or "", ) - # consume quota before invoking trigger + # reserve quota before invoking trigger quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) logger.info( @@ -388,6 +388,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 5ca04fd7c2..b4f975f4da 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,11 +10,11 @@ import logging from typing import Any from celery import shared_task +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 0d5475a56d..128cdd72e1 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -10,13 +10,13 @@ import logging from typing import Any from celery import shared_task -from sqlalchemy import select - -from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from sqlalchemy import select + +from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index a876b0c4aa..91245e879e 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,8 +1,9 @@ from collections.abc import Generator +from graphon.node_events import StreamCompletedEvent + from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index b5318aaa2b..3fdea10976 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,9 @@ -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.workflow.nodes.datasource.datasource_node import DatasourceNode from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamCompletedEvent +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.workflow.nodes.datasource.datasource_node import DatasourceNode + class _Seg: def __init__(self, v): diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index c4146d5ccd..ce04a158a8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,9 +4,6 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -26,6 +23,9 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index e130644338..c7bb90f019 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,6 +3,10 @@ import unittest import uuid import pytest +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from sqlalchemy import delete, func, select from sqlalchemy.orm import Session @@ -11,10 +15,6 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 4f444598b1..3dfedd811d 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest +from graphon.variables.segments import StringSegment from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -209,6 +209,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -452,6 +453,7 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..c0143faa85 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,11 +1,12 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3476c292b..4f41396c22 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest - -from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index aa6cf1e021..b1f937e738 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,6 +3,11 @@ import uuid from urllib.parse import urlencode import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -11,11 +16,6 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -192,7 +192,6 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -203,6 +202,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool + from core.workflow.system_variables import build_system_variables + # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index fa5d63cfbf..f0f3fcead1 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,11 +4,6 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent from graphon.nodes.llm.file_saver import LLMFileSaver @@ -17,6 +12,12 @@ from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 52886855b8..fe512c2585 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,16 +3,17 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9e3e1a47e3..2d728569be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,14 +1,15 @@ import time import uuid -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5a22f81a69..ea95959a82 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,11 +4,11 @@ import json import uuid from flask.testing import FlaskClient +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index c342e8994b..b4b65abdb6 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,6 +22,13 @@ import uuid from time import time import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -33,13 +40,6 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 14d5740072..3b1570a9a8 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,6 +4,7 @@ from __future__ import annotations from uuid import uuid4 +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import ( Account, AccountStatus, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index da4f8847d6..3ecf621095 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,17 +4,6 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowType from graphon.graph import Graph from graphon.graph_engine import GraphEngine @@ -27,6 +16,17 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 2e207ddc67..cc72dc1cf3 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 641399c7f9..a68b3a08c7 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index aebe87839c..64c93ac07c 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 00a2f9a59f..4f3c0e4200 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,6 +842,7 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType + from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 77ce28b999..6c15587058 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest import yaml from faker import Faker +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -16,7 +17,6 @@ from core.trigger.constants import ( TRIGGER_WEBHOOK_NODE_TYPE, ) from extensions.ext_redis import redis_client -from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import AppModelConfig, IconType from services import app_dsl_service diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..f9bfa570cb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index ac0483a45d..2974e00897 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -3,10 +3,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.account import ( Account, AccountStatus, diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..c8f04e9215 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select -from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 18c5320d0a..c46b8fba0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,6 +3,8 @@ import uuid from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from core.workflow.human_input_compat import ( EmailDeliveryConfig, @@ -10,8 +12,6 @@ from core.workflow.human_input_compat import ( EmailRecipients, ExternalRecipient, ) -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8955a3b5f2..ba926bf675 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,10 +405,11 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity + # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -643,9 +644,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", model_type=ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 1e57b5603d..749c6fff5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session -from graphon.enums import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..0c281c8c33 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker +from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..7c43bf676b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index fa3ac12cf0..2fb62e0fc0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,8 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker -from sqlalchemy import ColumnElement, func, select -from sqlalchemy.orm import Session +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -22,14 +21,6 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password -def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: - return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 - - -def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: - return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 - - class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -155,14 +146,29 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 - assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(document_ids)) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ) + == 6 + ) # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ) + == 0 + ) # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -322,7 +328,12 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 0 + ) # Reset mock for next iteration mock_index_processor_factory.reset_mock() @@ -405,7 +416,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 0 + ) # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. @@ -491,8 +507,18 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 - assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) + == 5 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) + == 10 + ) # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -502,12 +528,29 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) + .where(DocumentSegment.document_id.in_(documents_to_clean)) + ) + == 0 + ) # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 - assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) + ) + == 2 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) + ) + == 4 + ) # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. @@ -591,13 +634,23 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 4 + ) # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 0 + ) # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. @@ -767,9 +820,16 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) + == num_documents + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == num_documents * num_segments_per_doc ) @@ -778,7 +838,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) + == 0 + ) # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. @@ -885,12 +950,29 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) + .where(DocumentSegment.document_id == target_document.id) + ) + == 0 + ) # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 - assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) + ) + == 2 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) + ) + == 6 + ) # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. @@ -985,9 +1067,13 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) + assert db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) == len(document_statuses) assert ( - _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == len(document_statuses) * 2 ) @@ -996,7 +1082,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) + == 0 + ) # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. @@ -1094,14 +1185,29 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id == document.id) + ) + == 1 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 0 + ) # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 328bdbf055..1b4dcf28ea 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,6 +3,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete from configs import dify_config @@ -18,9 +21,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b43b622870..b5bef145d5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,12 +2,12 @@ import uuid from unittest.mock import ANY, call, patch import pytest +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b00d827e37..6e98c0855a 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,12 +24,12 @@ from dataclasses import dataclass from datetime import timedelta import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 9c20118e27..7c4553d4a0 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,6 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient +from graphon.enums import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session @@ -24,7 +25,6 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index c4a8148446..e11102acb1 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask - -from controllers.console import wraps as console_wraps -from controllers.console.app import workflow_run as workflow_run_module -from controllers.web.error import NotFoundError from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType + +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_run as workflow_run_module +from controllers.web.error import NotFoundError from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index b19a1740eb..740da1f1df 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal +from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -15,7 +16,6 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index b2f949c6e2..9c42ee9529 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,10 +11,9 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView -from werkzeug.exceptions import Forbidden - from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index c16ebad739..a26fea8fbd 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -29,7 +30,6 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 3364c07e62..57681d8f5b 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -34,7 +35,6 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 14c35a9ed5..97fdf1a011 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.variables.types import SegmentType from werkzeug.exceptions import BadRequest, NotFound import services @@ -37,7 +38,6 @@ from controllers.service_api.app.conversation import ( ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError -from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index da09ec13ce..74a3c75839 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -36,7 +37,6 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index eda270258d..4b8e3a738c 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,9 @@ from types import SimpleNamespace -from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField from graphon.enums import WorkflowExecutionStatus +from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField + def test_workflow_run_status_field_with_enum() -> None: field = WorkflowRunStatusField() diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 11b53dd0f9..8bde9c1f97 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,7 +1,8 @@ -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager + def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 45d4b0e321..1fb0dc6cf1 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory -from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index b3ea1a464f..f255d2c7df 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 201923e0e4..4a94a2b4f1 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 3ab63aed25..328cd12f12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,9 +1,10 @@ from collections.abc import Mapping, Sequence -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter + class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 1bef6f69cd..bc11bf4174 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,12 +1,13 @@ from datetime import UTC, datetime from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 936ac37e55..c9e146ff12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,10 +1,11 @@ from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index b3c0eb74fa..0fde7565d2 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,6 +10,8 @@ from typing import Any from unittest.mock import Mock import pytest +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -25,8 +27,6 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index b0f8b423e1..6167be3bbd 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,8 +476,9 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from core.app.entities.app_invoke_entities import InvokeFrom from graphon.enums import BuiltinNodeTypes + + from core.app.entities.app_invoke_entities import InvokeFrom from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 10fb2271f4..aa789d9ff3 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 620a153204..9e30faecf2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index a3ab379b66..8a717e1dcc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -11,11 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 1f6e7e12ef..29df903aa8 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,14 +2,15 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index ba55e8f695..a78c1b428f 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,9 +1,6 @@ from collections.abc import Sequence from unittest.mock import Mock -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -11,6 +8,10 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment + +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 539944d683..035e64325b 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,16 +4,6 @@ from time import time from unittest.mock import Mock import pytest - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -25,6 +15,16 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 1c1bf391d3..4aaa10a81a 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -26,8 +28,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 81315d2508..d338cadb77 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 57456085c3..9a815fb94d 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 704b82adc0..a3b1e5f6b0 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -37,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index d49b6e4b71..90730dff5a 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,12 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from graphon.file import File, FileTransferMethod, FileType class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d392127..2b280dd674 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,13 +2,6 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -18,6 +11,13 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d7..4a54649b28 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,11 +1,5 @@ from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.entities.provider_configuration import ProviderModelBundle -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, @@ -13,6 +7,13 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 5d865d934c..a4b3960b0a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,5 +1,3 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -9,6 +7,9 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 4b8175b0b4..408cf14a51 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,10 +49,6 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -60,6 +56,10 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 7c4defc180..641c5d9ba0 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -63,7 +64,6 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index 3d3322094e..e229d5fc1a 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from graphon.entities import WorkflowExecution from graphon.enums import WorkflowType + +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 05b4f3a053..7dbf78d0f0 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 8be1ac318c..0fc82dda53 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -21,11 +26,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index abdbc72085..84fe522388 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 5af1376a0a..27729e7f06 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b..ac65d0c02b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c0..f5efb78b61 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index b19a21d7f4..331166fe63 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 43f3fbd5c9..0e3a7e623a 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 72a73dd936..c20edd7400 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 72052c8c05..7406b88270 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -2,11 +2,6 @@ import dataclasses import orjson import pytest -from pydantic import BaseModel - -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup @@ -47,6 +42,11 @@ from graphon.variables.variables import ( StringVariable, Variable, ) +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool def _build_variable_pool( diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index d4e862220a..37ecd2890b 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,5 +1,4 @@ import pytest - from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 94e788edb2..09254e17a3 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import Any import pytest - from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dae5e1ce98..75b01bf42e 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,6 +1,4 @@ import pytest -from pydantic import ValidationError - from graphon.variables import ( ArrayFileVariable, ArrayVariable, @@ -12,6 +10,7 @@ from graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase +from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 025d79b25d..41627f5e0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from graphon.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,9 +61,10 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from core.tools.entities.tool_entities import ToolProviderType from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType + node = MagicMock() node.id = "test-tool-node-id" node.title = "Test Tool Node" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 919f15efd0..9cf72763ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b..88989db856 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,11 +7,12 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any -from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory + from .test_mock_nodes import ( MockAgentNode, MockCodeNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbf..8b7fbd1b30 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,10 +10,6 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -31,6 +27,11 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode + if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 55a329eba9..8311a1e847 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,13 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -30,6 +23,14 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 7d23b63049..b11f957677 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,11 +19,6 @@ from functools import lru_cache from pathlib import Path from typing import Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -44,6 +39,12 @@ from graphon.variables import ( StringVariable, ) +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58..7195471eb6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,14 +2,15 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index ec4cef1955..343bcd3919 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest - -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping + # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index ef0df55995..b9371a34f4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,7 +1,6 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -14,6 +13,8 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index ce0c9b79c6..d155124c50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,4 +1,3 @@ -from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -9,6 +8,8 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType +from configs import dify_config + CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 9cceadde49..fb03ae9998 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,8 +1,9 @@ -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.workflow.nodes.datasource.datasource_node import DatasourceNode from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.workflow.nodes.datasource.datasource_node import DatasourceNode + class _VarSeg: def __init__(self, v): diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index be7cc073db..a5026b40cf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,8 +1,4 @@ import pytest - -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -16,6 +12,10 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables + HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index a3cadc0681..4705b3f76e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index 1d6a4da7c4..d16e1233ac 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,6 +1,7 @@ -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients + def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index bc98028d5b..52802c7ce1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,9 +1,6 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -14,6 +11,10 @@ from graphon.graph_events import ( from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45e8ae7d20..ab64be59ad 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,6 +3,10 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -17,10 +21,6 @@ from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index eca34f05be..fdf1706765 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY + class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index b1f81b6c48..7841bf05ad 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,19 +5,6 @@ from collections.abc import Sequence from unittest import mock import pytest - -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import ( - DifyCredentialsProvider, - DifyModelFactory, - build_dify_model_access, - fetch_model_config, -) -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.common_entities import I18nObject @@ -80,6 +67,19 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 8f8ec49f14..1c362a0a03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,8 +6,6 @@ from dataclasses import dataclass from typing import Any import pytest - -from factories.variable_factory import build_segment_with_type from graphon.model_runtime.entities import LLMMode from graphon.nodes.llm import ModelConfig, VisionConfig from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData @@ -20,6 +18,8 @@ from graphon.nodes.parameter_extractor.exc import ( from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.variables.types import SegmentType +from factories.variable_factory import build_segment_with_type + @dataclass class ValidTestCase: diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0522dd9d14..e11ebf6eb8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,16 @@ from collections.abc import Mapping import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_runtime import resolve_dify_run_context -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bce..555ff0c945 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,8 +4,6 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod @@ -21,6 +19,8 @@ from graphon.nodes.document_extractor.node import ( from graphon.variables import ArrayFileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e..1b14f0ab13 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,11 +3,6 @@ import uuid from unittest.mock import MagicMock, Mock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -16,6 +11,11 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7..d28c3e01e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -18,6 +16,8 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + @pytest.fixture def list_operator_node(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de..833c303052 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 617554ee17..f1132af02b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index dddd6eb00c..e7b2b2914a 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -2,15 +2,6 @@ import uuid from collections import defaultdict import pytest - -from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from factories.variable_factory import build_segment, segment_to_variable from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables import FileSegment, StringSegment @@ -36,6 +27,15 @@ from graphon.variables.variables import ( Variable, ) +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from factories.variable_factory import build_segment, segment_to_variable + @pytest.fixture def pool(): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 041c5cc612..d8361d06c4 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,6 +1,12 @@ from types import SimpleNamespace import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import VariablePool +from graphon.variables.variables import StringVariable from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage @@ -10,12 +16,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, ) from core.workflow.workflow_entry import WorkflowEntry -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.file import File, FileTransferMethod, FileType -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import VariablePool -from graphon.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 80dc8927fa..4b2f98aeff 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,11 +2,12 @@ from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.workflow_entry import WorkflowEntry from graphon.graph_engine.command_channels import RedisChannel from graphon.runtime import GraphRuntimeState, VariablePool +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.workflow_entry import WorkflowEntry + class TestWorkflowEntryRedisChannel: """Test suite for WorkflowEntry with Redis command channel.""" diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 511192001e..4fe3f2cb28 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -2,13 +2,13 @@ import uuid from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from httpx import Response from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from factories.file_factory.builders import build_from_mapping as _build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile # Test Data diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index c35e80a826..a06c42507d 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,11 +4,6 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type from graphon.file import File, FileTransferMethod, FileType from graphon.variables import ( ArrayNumberVariable, @@ -36,6 +31,11 @@ from graphon.variables.segments import ( StringSegment, ) from graphon.variables.types import SegmentType +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index fa2c02020b..f1ce1a2c1c 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -5,7 +5,6 @@ Unit tests for FormService. from datetime import timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import ( diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 866ee61b3e..0babfbb315 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,7 +5,6 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index bb3a6db1a1..86163f1554 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,7 +1,8 @@ from uuid import uuid4 -from factories import variable_factory from graphon.variables import SegmentType + +from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index a87dd7f15a..3f6d6bfbe3 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -2,9 +2,9 @@ import importlib import types import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from core.workflow.file_reference import build_file_reference -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from models.model import Conversation, Message diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index f7bdc97eb5..e7c0479757 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,13 +3,14 @@ import json from unittest import mock from uuid import uuid4 +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from graphon.variables.segments import IntegerSegment, Segment + from constants import HIDDEN_VALUE from core.helper import encrypter from core.workflow.file_reference import build_file_reference from factories.variable_factory import build_segment -from graphon.file import File, FileTransferMethod, FileType -from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from graphon.variables.segments import IntegerSegment, Segment from models.workflow import ( Workflow, WorkflowDraftVariable, diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index eb9fef7587..507e1c8c3a 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -13,12 +13,12 @@ from datetime import UTC, datetime from uuid import uuid4 import pytest - from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) + from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 71df8c4e20..6903c47a24 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,11 +109,11 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.entities import PreProcessingRule, Rule, Segmentation from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index 361e95a557..73fc399ac3 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -73,11 +73,6 @@ class TestAsyncWorkflowService: mock_dispatcher = MagicMock() mock_quota_service = MagicMock() - mock_get_workflow = MagicMock() - - mock_professional_task = MagicMock() - mock_team_task = MagicMock() - mock_sandbox_task = MagicMock() with ( patch.object( diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 2c7f13b79f..68f4c51afe 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,15 +6,26 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. """ +from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch +import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom from libs.datetime_utils import naive_utc_now +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService +from services.errors.conversation import ( + ConversationNotExistsError, + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) +from services.errors.message import MessageNotExistsError class ConversationServiceTestDataFactory: @@ -327,9 +338,330 @@ class TestConversationServiceHelpers: assert condition is not None +class TestConversationServiceGetConversation: + """Test conversation retrieval operations.""" + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_account(self, mock_db_session): + """ + Test successful conversation retrieval with account user. + + Should return conversation when found with proper filters. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_account_id=user.id, from_source=ConversationFromSource.CONSOLE + ) + + mock_db_session.scalar.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_end_user(self, mock_db_session): + """ + Test successful conversation retrieval with end user. + + Should return conversation when found with proper filters for API user. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_end_user_id=user.id, from_source=ConversationFromSource.API + ) + + mock_db_session.scalar.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + + @patch("services.conversation_service.db.session") + def test_get_conversation_not_found_raises_error(self, mock_db_session): + """ + Test that get_conversation raises error when conversation not found. + + Should raise ConversationNotExistsError when no matching conversation found. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + mock_db_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationNotExistsError): + ConversationService.get_conversation(app_model, "conv-123", user) + + +class TestConversationServiceRename: + """Test conversation rename operations.""" + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_rename_with_manual_name(self, mock_get_conversation, mock_db_session): + """ + Test renaming conversation with manual name. + + Should update conversation name and timestamp when auto_generate is False. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id="conv-123", + user=user, + name="New Name", + auto_generate=False, + ) + + # Assert + assert result == conversation + assert conversation.name == "New Name" + mock_db_session.commit.assert_called_once() + + +class TestConversationServiceAutoGenerateName: + """Test conversation auto-name generation operations.""" + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_success(self, mock_llm_generator, mock_db_session): + """ + Test successful auto-generation of conversation name. + + Should generate name using LLMGenerator and update conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_db_session.scalar.return_value = message + + # Mock LLM generator + mock_llm_generator.generate_conversation_name.return_value = "Generated Name" + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + assert conversation.name == "Generated Name" + mock_llm_generator.generate_conversation_name.assert_called_once_with( + app_model.tenant_id, message.query, conversation.id, app_model.id + ) + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + def test_auto_generate_name_no_message_raises_error(self, mock_db_session): + """ + Test auto-generation fails when no message found. + + Should raise MessageNotExistsError when conversation has no messages. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + # Mock database query to return None + mock_db_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + ConversationService.auto_generate_name(app_model, conversation) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_handles_llm_exception(self, mock_llm_generator, mock_db_session): + """ + Test auto-generation handles LLM generator exceptions gracefully. + + Should continue without name when LLMGenerator fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_db_session.scalar.return_value = message + + # Mock LLM generator to raise exception + mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + # Name should remain unchanged due to exception + mock_db_session.commit.assert_called_once() + + +class TestConversationServiceDelete: + """Test conversation deletion operations.""" + + @patch("services.conversation_service.delete_conversation_related_data") + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_success(self, mock_get_conversation, mock_db_session, mock_delete_task): + """ + Test successful conversation deletion. + + Should delete conversation and schedule cleanup task. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock(name="Test App") + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Act + ConversationService.delete(app_model, "conv-123", user) + + # Assert + mock_db_session.delete.assert_called_once_with(conversation) + mock_db_session.commit.assert_called_once() + mock_delete_task.delay.assert_called_once_with(conversation.id) + + class TestConversationServiceConversationalVariable: """Test conversational variable operations.""" + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_success(self, mock_get_conversation, mock_session_factory): + """ + Test successful retrieval of conversational variables. + + Should return paginated list of variables for conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + variable1 = ConversationServiceTestDataFactory.create_conversation_variable_mock() + variable2 = ConversationServiceTestDataFactory.create_conversation_variable_mock(variable_id="var-456") + + mock_session.scalars.return_value.all.return_value = [variable1, variable2] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 2 + assert result.limit == 10 + assert result.has_more is False + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_with_last_id(self, mock_get_conversation, mock_session_factory): + """ + Test retrieval of variables with last_id pagination. + + Should filter variables created after last_id. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( + created_at=naive_utc_now() - timedelta(hours=1) + ) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) + + mock_session.scalar.return_value = last_variable + mock_session.scalars.return_value.all.return_value = [variable] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="var-123", + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + assert result.limit == 10 + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_last_id_not_found_raises_error( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that invalid last_id raises ConversationVariableNotExistsError. + + Should raise error when last_id doesn't exist. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="invalid-id", + ) + @patch("services.conversation_service.session_factory") @patch("services.conversation_service.ConversationService.get_conversation") @patch("services.conversation_service.dify_config") @@ -366,3 +698,466 @@ class TestConversationServiceConversationalVariable: # Assert - JSON filter should be applied assert mock_session.scalars.called + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_postgresql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for PostgreSQL databases. + + Should apply JSON extraction filter for variable names. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "postgresql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + +class TestConversationServiceUpdateVariable: + """Test conversation variable update operations.""" + + @patch("services.conversation_service.variable_factory") + @patch("services.conversation_service.ConversationVariableUpdater") + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_success( + self, mock_get_conversation, mock_session_factory, mock_updater_class, mock_variable_factory + ): + """ + Test successful update of conversation variable. + + Should update variable value and return updated data. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="string") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": "new_value"} + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="new_value", + ) + + # Assert + assert result["id"] == "var-123" + assert result["value"] == "new_value" + mock_updater.update.assert_called_once_with("conv-123", updated_variable) + mock_updater.flush.assert_called_once() + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_not_found_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when variable doesn't exist. + + Should raise ConversationVariableNotExistsError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="invalid-id", + user=user, + new_value="new_value", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_type_mismatch_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when value type doesn't match expected type. + + Should raise ConversationVariableTypeMismatchError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="number") + mock_session.scalar.return_value = existing_variable + + # Act & Assert - Try to set string value for number variable + with pytest.raises(ConversationVariableTypeMismatchError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="string_value", # Wrong type + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_integer_number_compatibility( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that integer type accepts number values. + + Should allow number values for integer type variables. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="integer") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": 42} + + with ( + patch("services.conversation_service.variable_factory") as mock_variable_factory, + patch("services.conversation_service.ConversationVariableUpdater") as mock_updater_class, + ): + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value=42, # Number value for integer type + ) + + # Assert + assert result["value"] == 42 + mock_updater.update.assert_called_once() + + +class TestConversationServicePaginationAdvanced: + """Advanced pagination tests for ConversationService.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_last_id_not_found(self, mock_session_factory): + """ + Test pagination with invalid last_id raises error. + + Should raise LastConversationNotExistsError when last_id doesn't exist. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act & Assert + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id="invalid-id", + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_exclude_ids(self, mock_session_factory): + """ + Test pagination with exclude_ids filter. + + Should exclude specified conversation IDs from results. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=["excluded-123"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_has_more_detection(self, mock_session_factory): + """ + Test pagination has_more detection logic. + + Should set has_more=True when there are more results beyond limit. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # Return exactly limit items to trigger has_more check + conversations = [ + ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=f"conv-{i}") for i in range(20) + ] + mock_session.scalars.return_value.all.return_value = conversations + mock_session.scalar.return_value = conversations[-1] + + # Mock count query to return > 0 + mock_session.scalar.return_value = 5 # Additional items exist + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is True + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_different_sort_by(self, mock_session_factory): + """ + Test pagination with different sort fields. + + Should handle various sort_by parameters correctly. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Test different sort fields + sort_fields = ["created_at", "-updated_at", "name", "-status"] + + for sort_by in sort_fields: + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + sort_by=sort_by, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + +class TestConversationServiceEdgeCases: + """Test edge cases and error scenarios.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_with_end_user_api_source(self, mock_session_factory): + """ + Test pagination correctly handles EndUser with API source. + + Should use 'api' as from_source for EndUser instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source=ConversationFromSource.API, from_end_user_id="user-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + @patch("services.conversation_service.session_factory") + def test_pagination_with_account_console_source(self, mock_session_factory): + """ + Test pagination correctly handles Account with console source. + + Should use 'console' as from_source for Account instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source=ConversationFromSource.CONSOLE, from_account_id="account-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + def test_pagination_with_include_ids_filter(self): + """ + Test pagination with include_ids filter. + + Should only return conversations with IDs in include_ids list. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv-123", "conv-456"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + # Verify that include_ids filter was applied + assert mock_session.scalars.called + + def test_pagination_with_empty_exclude_ids(self): + """ + Test pagination with empty exclude_ids list. + + Should handle empty exclude_ids gracefully. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=[], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is False diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 55af564821..9be475d043 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,18 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest - -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) from graphon.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 97f3bd6f01..1bd979b9ec 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest - -from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 4b864dd221..98ec6fb77c 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,7 +16,6 @@ from typing import Any from uuid import uuid4 import pytest - from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ( ArrayFileSegment, @@ -29,6 +28,7 @@ from graphon.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 02fbe473df..bf645f9795 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -561,18 +561,13 @@ class TestWebhookServiceUnit: assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None) - # === Merged from test_webhook_service_additional.py === from types import SimpleNamespace from typing import Any, cast -from unittest.mock import MagicMock -import pytest -from flask import Flask from graphon.variables.types import SegmentType -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import RequestEntityTooLarge from core.workflow.nodes.trigger_webhook.entities import ( @@ -587,7 +582,6 @@ from models.trigger import WorkflowWebhookTrigger from models.workflow import Workflow from services.errors.app import QuotaExceededError from services.trigger import webhook_service as service_module -from services.trigger.webhook_service import WebhookService class _FakeQuery: diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 239cc83518..a62c9f4555 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 60beec7964..8525672da8 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,12 +4,12 @@ import json from unittest.mock import Mock, patch import pytest -from sqlalchemy import Engine - -from core.workflow.file_reference import build_file_reference from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ObjectSegment, StringSegment from graphon.variables.types import SegmentType +from sqlalchemy import Engine + +from core.workflow.file_reference import build_file_reference from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index f6bdb6a60e..e7e72793a3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,6 +4,10 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session @@ -13,10 +17,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from graphon.enums import BuiltinNodeTypes -from graphon.file import File, FileTransferMethod, FileType -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index d570dce107..4146fd312b 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -6,13 +6,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index d7192994b2..98d057e41f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,6 +3,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker from core.workflow.human_input_compat import ( @@ -12,9 +15,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 591da56f49..7119217e94 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest - from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index 689b973097..544e89fcee 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -4,6 +4,7 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -18,7 +19,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict[str, Any] | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index c166a946d9..ffa6833524 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,9 +2,6 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -21,6 +18,9 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output + def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: """Create a mock LLMUsage with all required fields""" diff --git a/api/uv.lock b/api/uv.lock index 7ed1b9960c..db00ccf800 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -380,14 +380,14 @@ wheels = [ [[package]] name = "authlib" -version = "1.6.11" +version = "1.6.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/28/10/b325d58ffe86815b399334a101e63bc6fa4e1953921cb23703b48a0a0220/authlib-1.6.11.tar.gz", hash = "sha256:64db35b9b01aeccb4715a6c9a6613a06f2bd7be2ab9d2eb89edd1dfc7580a38f", size = 165359, upload-time = "2026-04-16T07:22:50.279Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/98/00d3dd826d46959ad8e32af2dbb2398868fd9fd0683c26e56d0789bd0e68/authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04", size = 165134, upload-time = "2026-03-02T07:44:01.998Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/2f/55fca558f925a51db046e5b929deb317ddb05afed74b22d89f4eca578980/authlib-1.6.11-py2.py3-none-any.whl", hash = "sha256:c8687a9a26451c51a34a06fa17bb97cb15bba46a6a626755e2d7f50da8bff3e3", size = 244469, upload-time = "2026-04-16T07:22:48.413Z" }, + { url = "https://files.pythonhosted.org/packages/53/23/b65f568ed0c22f1efacb744d2db1a33c8068f384b8c9b482b52ebdbc3ef6/authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3", size = 244197, upload-time = "2026-03-02T07:44:00.307Z" }, ] [[package]] @@ -3903,14 +3903,14 @@ wheels = [ [[package]] name = "mako" -version = "1.3.11" +version = "1.3.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" }, + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, ] [[package]] @@ -5501,11 +5501,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.10.2" +version = "6.10.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7b/3f/9f2167401c2e94833ca3b69535bad89e533b5de75fefe4197a2c224baec2/pypdf-6.10.2.tar.gz", hash = "sha256:7d09ce108eff6bf67465d461b6ef352dcb8d84f7a91befc02f904455c6eea11d", size = 5315679, upload-time = "2026-04-15T16:37:36.978Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/79/f2730c42ec7891a75a2fcea2eb4f356872bcbc671b711418060424796612/pypdf-6.10.1.tar.gz", hash = "sha256:62e6ca7f65aaa28b3d192addb44f97296e4be1748f57ed0f4efb2d4915841880", size = 5315704, upload-time = "2026-04-14T12:55:20.996Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d6/1d5c60cc17bbdf37c1552d9c03862fc6d32c5836732a0415b2d637edc2d0/pypdf-6.10.2-py3-none-any.whl", hash = "sha256:aa53be9826655b51c96741e5d7983ca224d898ac0a77896e64636810517624aa", size = 336308, upload-time = "2026-04-15T16:37:34.851Z" }, + { url = "https://files.pythonhosted.org/packages/f0/04/e3aa7f1f14dbc53429cae34666261eb935d99bd61d24756ab94d7e0309da/pypdf-6.10.1-py3-none-any.whl", hash = "sha256:6331940d3bfe75b7e6601d35db7adabab5fc1d716efaeb384e3c0c3957d033de", size = 335606, upload-time = "2026-04-14T12:55:18.941Z" }, ] [[package]] @@ -6562,6 +6562,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/3f/48af1e72e59d60481724b326317bd311615bdedc31f8f81f9508fb84cda6/tablestore-6.4.4-py3-none-any.whl", hash = "sha256:984f086fa7acabaa3558da93205ad6df562b266b85fd249bc5891f2dd1d65814", size = 5118758, upload-time = "2026-04-09T09:40:17.209Z" }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, +] + [[package]] name = "tcvdb-text" version = "1.1.2" diff --git a/packages/iconify-collections/assets/vender/workflow/input-field.svg b/packages/iconify-collections/assets/vender/workflow/input-field.svg new file mode 100644 index 0000000000..47ef58181e --- /dev/null +++ b/packages/iconify-collections/assets/vender/workflow/input-field.svg @@ -0,0 +1,3 @@ + + + diff --git a/packages/iconify-collections/custom-public/icons.json b/packages/iconify-collections/custom-public/icons.json index 347b6145e2..7c7d110be8 100644 --- a/packages/iconify-collections/custom-public/icons.json +++ b/packages/iconify-collections/custom-public/icons.json @@ -1,6 +1,6 @@ { "prefix": "custom-public", - "lastModified": 1775115796, + "lastModified": 1776313052, "icons": { "avatar-user": { "body": "", @@ -65,6 +65,9 @@ "width": 50, "height": 26 }, + "common-enter-key": { + "body": "" + }, "common-gdpr": { "body": "", "width": 23, @@ -407,6 +410,11 @@ "model-checked": { "body": "" }, + "other-comment": { + "body": "", + "width": 14, + "height": 12 + }, "other-default-tool-icon": { "body": "" }, diff --git a/packages/iconify-collections/custom-public/info.json b/packages/iconify-collections/custom-public/info.json index 8b5572de6f..115e9e25f9 100644 --- a/packages/iconify-collections/custom-public/info.json +++ b/packages/iconify-collections/custom-public/info.json @@ -1,7 +1,7 @@ { "prefix": "custom-public", "name": "Dify Custom Public", - "total": 142, + "total": 144, "version": "0.0.0-private", "author": { "name": "LangGenius, Inc.", diff --git a/packages/iconify-collections/custom-vender/icons.json b/packages/iconify-collections/custom-vender/icons.json index a7dc8e75e0..d588db650e 100644 --- a/packages/iconify-collections/custom-vender/icons.json +++ b/packages/iconify-collections/custom-vender/icons.json @@ -1,6 +1,6 @@ { "prefix": "custom-vender", - "lastModified": 1775115796, + "lastModified": 1776313052, "icons": { "features-citations": { "body": "" @@ -1025,6 +1025,11 @@ "workflow-if-else": { "body": "" }, + "workflow-input-field": { + "body": "", + "width": 16, + "height": 16 + }, "workflow-iteration": { "body": "" }, diff --git a/packages/iconify-collections/custom-vender/info.json b/packages/iconify-collections/custom-vender/info.json index 0a84c45bbd..ea5f666503 100644 --- a/packages/iconify-collections/custom-vender/info.json +++ b/packages/iconify-collections/custom-vender/info.json @@ -1,7 +1,7 @@ { "prefix": "custom-vender", "name": "Dify Custom Vender", - "total": 277, + "total": 278, "version": "0.0.0-private", "author": { "name": "LangGenius, Inc.", diff --git a/sdks/nodejs-client/tsconfig.json b/sdks/nodejs-client/tsconfig.json index 1e55007ed0..46055447be 100644 --- a/sdks/nodejs-client/tsconfig.json +++ b/sdks/nodejs-client/tsconfig.json @@ -1,14 +1,18 @@ { - "extends": "@dify/tsconfig/node.json", "compilerOptions": { - "lib": ["ES2023", "DOM", "DOM.Iterable"], + "target": "ES2022", + "module": "ESNext", + "moduleResolution": "Bundler", "rootDir": ".", "outDir": "dist", - "noEmit": false, "declaration": true, "declarationMap": true, "sourceMap": true, + "strict": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "skipLibCheck": true, "types": ["node"] }, - "include": ["src/**/*.ts", "tests/**/*.ts", "vite.config.ts"] + "include": ["src/**/*.ts", "tests/**/*.ts"] } diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index a5ed79a7bd..e0f09ad2ac 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -77,6 +77,13 @@ vi.mock('@/context/provider-context', () => ({ }), })) +vi.mock('@/hooks/use-snippet-and-evaluation-plan-access', () => ({ + useSnippetAndEvaluationPlanAccess: () => ({ + canAccess: true, + isReady: true, + }), +})) + vi.mock('@/app/components/base/tag-management/store', () => ({ useStore: (selector: (state: Record) => unknown) => { const state = { @@ -93,6 +100,16 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/use-common', () => ({ + useMembers: () => ({ + data: { + accounts: [ + { id: 'user-1', name: 'Current User', email: 'current@example.com', avatar: '', avatar_url: '', role: 'owner', last_login_at: '', created_at: '', status: 'active' }, + ], + }, + }), +})) + vi.mock('@/service/apps', () => ({ fetchWorkflowOnlineUsers: vi.fn().mockResolvedValue({}), })) @@ -114,6 +131,18 @@ vi.mock('@/service/use-apps', () => ({ }), })) +vi.mock('@/service/use-snippets', () => ({ + useInfiniteSnippetList: () => ({ + data: { pages: [] }, + isLoading: false, + isFetching: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + hasNextPage: false, + error: null, + }), +})) + vi.mock('@/hooks/use-pay', () => ({ CheckModal: () => null, })) @@ -323,16 +352,11 @@ describe('App List Browsing Flow', () => { // -- Tab navigation -- describe('Tab Navigation', () => { - it('should render all category tabs', () => { + it('should render the app type dropdown trigger', () => { mockPages = [createPage([createMockApp()])] renderList() - expect(screen.getByText('app.types.all')).toBeInTheDocument() - expect(screen.getByText('app.types.workflow')).toBeInTheDocument() - expect(screen.getByText('app.types.advanced')).toBeInTheDocument() - expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() - expect(screen.getByText('app.types.agent')).toBeInTheDocument() - expect(screen.getByText('app.types.completion')).toBeInTheDocument() + expect(screen.getByText('app.studio.filters.types')).toBeInTheDocument() }) }) @@ -358,21 +382,19 @@ describe('App List Browsing Flow', () => { // -- "Created by me" filter -- describe('Created By Me Filter', () => { - it('should render the "created by me" checkbox', () => { + it('should not render a standalone "created by me" checkbox in the current header layout', () => { mockPages = [createPage([createMockApp()])] renderList() - expect(screen.getByText('app.showMyCreatedAppsOnly')).toBeInTheDocument() + expect(screen.queryByText('app.showMyCreatedAppsOnly')).not.toBeInTheDocument() }) - it('should toggle the "created by me" filter on click', () => { + it('should keep the current layout stable without a "created by me" control', () => { mockPages = [createPage([createMockApp()])] renderList() - const checkbox = screen.getByText('app.showMyCreatedAppsOnly') - fireEvent.click(checkbox) - - expect(screen.getByText('app.showMyCreatedAppsOnly')).toBeInTheDocument() + expect(screen.getByText('app.studio.filters.types')).toBeInTheDocument() + expect(screen.queryByText('app.showMyCreatedAppsOnly')).not.toBeInTheDocument() }) }) diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 9abc870ecf..ab45860a07 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -64,6 +64,13 @@ vi.mock('@/context/provider-context', () => ({ }), })) +vi.mock('@/hooks/use-snippet-and-evaluation-plan-access', () => ({ + useSnippetAndEvaluationPlanAccess: () => ({ + canAccess: true, + isReady: true, + }), +})) + vi.mock('@/app/components/base/tag-management/store', () => ({ useStore: (selector: (state: Record) => unknown) => { const state = { @@ -80,6 +87,16 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/use-common', () => ({ + useMembers: () => ({ + data: { + accounts: [ + { id: 'user-1', name: 'Current User', email: 'current@example.com', avatar: '', avatar_url: '', role: 'owner', last_login_at: '', created_at: '', status: 'active' }, + ], + }, + }), +})) + vi.mock('@/service/apps', () => ({ fetchWorkflowOnlineUsers: vi.fn().mockResolvedValue({}), })) @@ -101,6 +118,18 @@ vi.mock('@/service/use-apps', () => ({ }), })) +vi.mock('@/service/use-snippets', () => ({ + useInfiniteSnippetList: () => ({ + data: { pages: [] }, + isLoading: false, + isFetching: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + hasNextPage: false, + error: null, + }), +})) + vi.mock('@/hooks/use-pay', () => ({ CheckModal: () => null, })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/evaluation/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/evaluation/page.tsx new file mode 100644 index 0000000000..899667ad3c --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/evaluation/page.tsx @@ -0,0 +1,16 @@ +import SnippetAndEvaluationPlanGuard from '@/app/components/billing/snippet-and-evaluation-plan-guard' +import Evaluation from '@/app/components/evaluation' + +const Page = async (props: { + params: Promise<{ appId: string }> +}) => { + const { appId } = await props.params + + return ( + + + + ) +} + +export default Page diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 8a1a6fd131..7f8d6d535e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -8,6 +8,8 @@ import { RiDashboard2Line, RiFileList3Fill, RiFileList3Line, + RiFlaskFill, + RiFlaskLine, RiTerminalBoxFill, RiTerminalBoxLine, RiTerminalWindowFill, @@ -25,6 +27,7 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { useSnippetAndEvaluationPlanAccess } from '@/hooks/use-snippet-and-evaluation-plan-access' import dynamic from '@/next/dynamic' import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' @@ -50,6 +53,7 @@ const AppDetailLayout: FC = (props) => { const pathname = usePathname() const media = useBreakpoints() const isMobile = media === MediaType.mobile + const { canAccess: canAccessSnippetsAndEvaluation } = useSnippetAndEvaluationPlanAccess() const { isCurrentWorkspaceEditor, isLoadingCurrentWorkspace, currentWorkspace } = useAppContext() const { appDetail, setAppDetail, setAppSidebarExpand } = useStore(useShallow(state => ({ appDetail: state.appDetail, @@ -67,42 +71,51 @@ const AppDetailLayout: FC = (props) => { }>>([]) const getNavigationConfig = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: AppModeEnum) => { - const navConfig = [ - ...(isCurrentWorkspaceEditor - ? [{ - name: t('appMenus.promptEng', { ns: 'common' }), - href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`, - icon: RiTerminalWindowLine, - selectedIcon: RiTerminalWindowFill, - }] - : [] - ), - { - name: t('appMenus.apiAccess', { ns: 'common' }), - href: `/app/${appId}/develop`, - icon: RiTerminalBoxLine, - selectedIcon: RiTerminalBoxFill, - }, - ...(isCurrentWorkspaceEditor - ? [{ - name: mode !== AppModeEnum.WORKFLOW - ? t('appMenus.logAndAnn', { ns: 'common' }) - : t('appMenus.logs', { ns: 'common' }), - href: `/app/${appId}/logs`, - icon: RiFileList3Line, - selectedIcon: RiFileList3Fill, - }] - : [] - ), - { - name: t('appMenus.overview', { ns: 'common' }), - href: `/app/${appId}/overview`, - icon: RiDashboard2Line, - selectedIcon: RiDashboard2Fill, - }, - ] + const navConfig = [] + + if (isCurrentWorkspaceEditor) { + navConfig.push({ + name: t('appMenus.promptEng', { ns: 'common' }), + href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`, + icon: RiTerminalWindowLine, + selectedIcon: RiTerminalWindowFill, + }) + if (canAccessSnippetsAndEvaluation) { + navConfig.push({ + name: t('appMenus.evaluation', { ns: 'common' }), + href: `/app/${appId}/evaluation`, + icon: RiFlaskLine, + selectedIcon: RiFlaskFill, + }) + } + } + + navConfig.push({ + name: t('appMenus.apiAccess', { ns: 'common' }), + href: `/app/${appId}/develop`, + icon: RiTerminalBoxLine, + selectedIcon: RiTerminalBoxFill, + }) + + if (isCurrentWorkspaceEditor) { + navConfig.push({ + name: mode !== AppModeEnum.WORKFLOW + ? t('appMenus.logAndAnn', { ns: 'common' }) + : t('appMenus.logs', { ns: 'common' }), + href: `/app/${appId}/logs`, + icon: RiFileList3Line, + selectedIcon: RiFileList3Fill, + }) + } + + navConfig.push({ + name: t('appMenus.overview', { ns: 'common' }), + href: `/app/${appId}/overview`, + icon: RiDashboard2Line, + selectedIcon: RiDashboard2Fill, + }) return navConfig - }, [t]) + }, [canAccessSnippetsAndEvaluation, t]) useDocumentTitle(appDetail?.name || t('menus.appDetail', { ns: 'common' })) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/__tests__/layout-main.spec.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/__tests__/layout-main.spec.tsx new file mode 100644 index 0000000000..2831b0d464 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/__tests__/layout-main.spec.tsx @@ -0,0 +1,230 @@ +import type { ReactNode } from 'react' +import type { DataSet } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import DatasetDetailLayout from '../layout-main' + +let mockPathname = '/datasets/test-dataset-id/documents' +let mockDataset: DataSet | undefined +let mockCanAccessSnippetsAndEvaluation = true + +const mockSetAppSidebarExpand = vi.fn() +const mockMutateDatasetRes = vi.fn() + +vi.mock('@/next/navigation', () => ({ + usePathname: () => mockPathname, +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: { setAppSidebarExpand: typeof mockSetAppSidebarExpand }) => unknown) => selector({ + setAppSidebarExpand: mockSetAppSidebarExpand, + }), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'desktop', + MediaType: { + mobile: 'mobile', + desktop: 'desktop', + }, +})) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + useSubscription: vi.fn(), + }, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceDatasetOperator: false, + }), +})) + +vi.mock('@/hooks/use-snippet-and-evaluation-plan-access', () => ({ + useSnippetAndEvaluationPlanAccess: () => ({ + canAccess: mockCanAccessSnippetsAndEvaluation, + isReady: true, + }), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useDatasetDetail: () => ({ + data: mockDataset, + error: null, + refetch: mockMutateDatasetRes, + }), + useDatasetRelatedApps: () => ({ + data: [], + }), +})) + +vi.mock('@/app/components/app-sidebar', () => ({ + default: ({ + navigation, + children, + }: { + navigation: Array<{ name: string, href: string, disabled?: boolean }> + children?: ReactNode + }) => ( +
+ {navigation.map(item => ( + + ))} + {children} +
+ ), +})) + +vi.mock('@/app/components/datasets/extra-info', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/base/loading', () => ({ + default: () =>
loading
, +})) + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'test-dataset-id', + name: 'Test Dataset', + indexing_status: 'completed', + icon_info: { + icon: 'book', + icon_background: '#fff', + icon_type: 'emoji', + icon_url: '', + }, + description: '', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: IndexingType.QUALIFIED, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 0, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + word_count: 0, + provider: 'vendor', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + pipeline_id: 'pipeline-1', + is_published: true, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +describe('DatasetDetailLayout', () => { + beforeEach(() => { + vi.clearAllMocks() + mockPathname = '/datasets/test-dataset-id/documents' + mockDataset = createDataset() + mockCanAccessSnippetsAndEvaluation = true + }) + + describe('Evaluation navigation', () => { + it('should hide the evaluation menu when the dataset is not a rag pipeline', () => { + mockDataset = createDataset({ + runtime_mode: 'general', + is_published: false, + }) + + render( + +
content
+
, + ) + + expect(screen.queryByRole('button', { name: 'common.datasetMenus.evaluation' })).not.toBeInTheDocument() + }) + + it('should disable the evaluation menu when the rag pipeline is unpublished', () => { + mockDataset = createDataset({ + is_published: false, + }) + + render( + +
content
+
, + ) + + expect(screen.getByRole('button', { name: 'common.datasetMenus.evaluation' })).toBeDisabled() + }) + + it('should enable the evaluation menu when the rag pipeline is published', () => { + render( + +
content
+
, + ) + + expect(screen.getByRole('button', { name: 'common.datasetMenus.evaluation' })).toBeEnabled() + }) + + it('should hide the evaluation menu when snippet and evaluation access is unavailable', () => { + mockCanAccessSnippetsAndEvaluation = false + + render( + +
content
+
, + ) + + expect(screen.queryByRole('button', { name: 'common.datasetMenus.evaluation' })).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/evaluation/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/evaluation/page.tsx new file mode 100644 index 0000000000..ea8fc0ea82 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/evaluation/page.tsx @@ -0,0 +1,16 @@ +import SnippetAndEvaluationPlanGuard from '@/app/components/billing/snippet-and-evaluation-plan-guard' +import Evaluation from '@/app/components/evaluation' + +const Page = async (props: { + params: Promise<{ datasetId: string }> +}) => { + const { datasetId } = await props.params + + return ( + + + + ) +} + +export default Page diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index ba3272c1a7..c5719d6a61 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -7,6 +7,8 @@ import { RiEqualizer2Line, RiFileTextFill, RiFileTextLine, + RiFlaskFill, + RiFlaskLine, RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' @@ -23,6 +25,7 @@ import DatasetDetailContext from '@/context/dataset-detail' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { useSnippetAndEvaluationPlanAccess } from '@/hooks/use-snippet-and-evaluation-plan-access' import { usePathname } from '@/next/navigation' import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' @@ -49,6 +52,7 @@ const DatasetDetailLayout: FC = (props) => { setHideHeader(v.payload) }) const { isCurrentWorkspaceDatasetOperator } = useAppContext() + const { canAccess: canAccessSnippetsAndEvaluation } = useSnippetAndEvaluationPlanAccess() const media = useBreakpoints() const isMobile = media === MediaType.mobile @@ -56,6 +60,7 @@ const DatasetDetailLayout: FC = (props) => { const { data: datasetRes, error, refetch: mutateDatasetRes } = useDatasetDetail(datasetId) const { data: relatedApps } = useDatasetRelatedApps(datasetId) + const isRagPipelineDataset = datasetRes?.runtime_mode === 'rag_pipeline' const isButtonDisabledWithPipeline = useMemo(() => { if (!datasetRes) @@ -86,24 +91,36 @@ const DatasetDetailLayout: FC = (props) => { ] if (datasetRes?.provider !== 'external') { - baseNavigation.unshift({ - name: t('datasetMenus.pipeline', { ns: 'common' }), - href: `/datasets/${datasetId}/pipeline`, - icon: PipelineLine as RemixiconComponentType, - selectedIcon: PipelineFill as RemixiconComponentType, - disabled: false, - }) - baseNavigation.unshift({ - name: t('datasetMenus.documents', { ns: 'common' }), - href: `/datasets/${datasetId}/documents`, - icon: RiFileTextLine, - selectedIcon: RiFileTextFill, - disabled: isButtonDisabledWithPipeline, - }) + return [ + { + name: t('datasetMenus.documents', { ns: 'common' }), + href: `/datasets/${datasetId}/documents`, + icon: RiFileTextLine, + selectedIcon: RiFileTextFill, + disabled: isButtonDisabledWithPipeline, + }, + { + name: t('datasetMenus.pipeline', { ns: 'common' }), + href: `/datasets/${datasetId}/pipeline`, + icon: PipelineLine as RemixiconComponentType, + selectedIcon: PipelineFill as RemixiconComponentType, + disabled: false, + }, + ...(isRagPipelineDataset && canAccessSnippetsAndEvaluation + ? [{ + name: t('datasetMenus.evaluation', { ns: 'common' }), + href: `/datasets/${datasetId}/evaluation`, + icon: RiFlaskLine, + selectedIcon: RiFlaskFill, + disabled: isButtonDisabledWithPipeline, + }] + : []), + ...baseNavigation, + ] } return baseNavigation - }, [t, datasetId, isButtonDisabledWithPipeline, datasetRes?.provider]) + }, [canAccessSnippetsAndEvaluation, t, datasetId, isButtonDisabledWithPipeline, isRagPipelineDataset, datasetRes?.provider]) useDocumentTitle(datasetRes?.name || t('menus.datasets', { ns: 'common' })) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 49e9431940..5ac39f1e39 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -5,6 +5,7 @@ import InSiteMessageNotification from '@/app/components/app/in-site-message/noti import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' +import GotoAnything from '@/app/components/goto-anything' import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' @@ -12,15 +13,10 @@ import { AppContextProvider } from '@/context/app-context-provider' import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { ModalContextProvider } from '@/context/modal-context-provider' import { ProviderContextProvider } from '@/context/provider-context-provider' -import dynamic from '@/next/dynamic' import PartnerStack from '../components/billing/partner-stack' import Splash from '../components/splash' import RoleRouteGuard from './role-route-guard' -const GotoAnything = dynamic(() => import('@/app/components/goto-anything'), { - ssr: false, -}) - const Layout = ({ children }: { children: ReactNode }) => { return ( <> diff --git a/web/app/(commonLayout)/role-route-guard.tsx b/web/app/(commonLayout)/role-route-guard.tsx index 483dfef095..6de5efb346 100644 --- a/web/app/(commonLayout)/role-route-guard.tsx +++ b/web/app/(commonLayout)/role-route-guard.tsx @@ -6,7 +6,7 @@ import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' import { usePathname, useRouter } from '@/next/navigation' -const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const +const datasetOperatorRedirectRoutes = ['/apps', '/app', '/snippets', '/explore', '/tools'] as const const isPathUnderRoute = (pathname: string, route: string) => pathname === route || pathname.startsWith(`${route}/`) diff --git a/web/app/(commonLayout)/snippets/[snippetId]/evaluation/page.tsx b/web/app/(commonLayout)/snippets/[snippetId]/evaluation/page.tsx new file mode 100644 index 0000000000..adc4bb6903 --- /dev/null +++ b/web/app/(commonLayout)/snippets/[snippetId]/evaluation/page.tsx @@ -0,0 +1,11 @@ +import SnippetEvaluationPage from '@/app/components/snippets/snippet-evaluation-page' + +const Page = async (props: { + params: Promise<{ snippetId: string }> +}) => { + const { snippetId } = await props.params + + return +} + +export default Page diff --git a/web/app/(commonLayout)/snippets/[snippetId]/orchestrate/page.tsx b/web/app/(commonLayout)/snippets/[snippetId]/orchestrate/page.tsx new file mode 100644 index 0000000000..8a39dc710b --- /dev/null +++ b/web/app/(commonLayout)/snippets/[snippetId]/orchestrate/page.tsx @@ -0,0 +1,11 @@ +import SnippetPage from '@/app/components/snippets' + +const Page = async (props: { + params: Promise<{ snippetId: string }> +}) => { + const { snippetId } = await props.params + + return +} + +export default Page diff --git a/web/app/(commonLayout)/snippets/[snippetId]/page.spec.ts b/web/app/(commonLayout)/snippets/[snippetId]/page.spec.ts new file mode 100644 index 0000000000..578c562848 --- /dev/null +++ b/web/app/(commonLayout)/snippets/[snippetId]/page.spec.ts @@ -0,0 +1,21 @@ +import Page from './page' + +const mockRedirect = vi.fn() + +vi.mock('next/navigation', () => ({ + redirect: (path: string) => mockRedirect(path), +})) + +describe('snippet detail redirect page', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should redirect legacy snippet detail routes to orchestrate', async () => { + await Page({ + params: Promise.resolve({ snippetId: 'snippet-1' }), + }) + + expect(mockRedirect).toHaveBeenCalledWith('/snippets/snippet-1/orchestrate') + }) +}) diff --git a/web/app/(commonLayout)/snippets/[snippetId]/page.tsx b/web/app/(commonLayout)/snippets/[snippetId]/page.tsx new file mode 100644 index 0000000000..3b35e29360 --- /dev/null +++ b/web/app/(commonLayout)/snippets/[snippetId]/page.tsx @@ -0,0 +1,11 @@ +import { redirect } from 'next/navigation' + +const Page = async (props: { + params: Promise<{ snippetId: string }> +}) => { + const { snippetId } = await props.params + + redirect(`/snippets/${snippetId}/orchestrate`) +} + +export default Page diff --git a/web/app/(commonLayout)/snippets/page.tsx b/web/app/(commonLayout)/snippets/page.tsx new file mode 100644 index 0000000000..73fadb0d27 --- /dev/null +++ b/web/app/(commonLayout)/snippets/page.tsx @@ -0,0 +1,12 @@ +import Apps from '@/app/components/apps' +import SnippetAndEvaluationPlanGuard from '@/app/components/billing/snippet-and-evaluation-plan-guard' + +const SnippetsPage = () => { + return ( + + + + ) +} + +export default SnippetsPage diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx index b2e1e92bbb..1b6046baee 100644 --- a/web/app/components/app-sidebar/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -165,6 +165,21 @@ describe('AppDetailNav', () => { ) expect(screen.queryByTestId('extra-info')).not.toBeInTheDocument() }) + + it('should render custom header and navigation when provided', () => { + render( +
} + renderNavigation={mode =>
} + />, + ) + + expect(screen.getByTestId('custom-header')).toHaveAttribute('data-mode', 'expand') + expect(screen.getByTestId('custom-navigation')).toHaveAttribute('data-mode', 'expand') + expect(screen.queryByTestId('app-info')).not.toBeInTheDocument() + expect(screen.queryByTestId('nav-link-Overview')).not.toBeInTheDocument() + }) }) describe('Workflow canvas mode', () => { diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx index 2171974253..58a1c5a3c6 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx @@ -2,7 +2,7 @@ import type { App, AppSSO } from '@/types/app' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' -import { AppModeEnum } from '@/types/app' +import { AppModeEnum, AppTypeEnum } from '@/types/app' import AppInfoDetailPanel from '../app-info-detail-panel' vi.mock('../../../base/app-icon', () => ({ @@ -135,6 +135,17 @@ describe('AppInfoDetailPanel', () => { expect(cardView).toHaveAttribute('data-app-id', 'app-1') }) + it('should not render CardView when app type is evaluation', () => { + render( + , + ) + + expect(screen.queryByTestId('card-view')).not.toBeInTheDocument() + }) + it('should render app icon with large size', () => { render() const icon = screen.getByTestId('app-icon') diff --git a/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx index 624630b179..4e90c238ea 100644 --- a/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx +++ b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx @@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next' import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view' import ContentDialog from '@/app/components/base/content-dialog' import { Button } from '@/app/components/base/ui/button' -import { AppModeEnum } from '@/types/app' +import { AppModeEnum, AppTypeEnum } from '@/types/app' import AppIcon from '../../base/app-icon' import { getAppModeLabel } from './app-mode-labels' import AppOperations from './app-operations' @@ -126,11 +126,13 @@ const AppInfoDetailPanel = ({ secondaryOperations={secondaryOperations} />
- + {appDetail.type !== AppTypeEnum.EVALUATION && ( + + )} {switchOperation && (
@@ -136,7 +141,8 @@ const AppDetailNav = ({ expand ? 'px-3 py-2' : 'p-3', )} > - {navigation.map((item, index) => { + {renderNavigation?.(appSidebarExpand)} + {!renderNavigation && navigation.map((item, index) => { return ( { expect(iconWrapper).toHaveClass('-ml-1') }) }) + + describe('Button Mode', () => { + it('should render as an interactive button when href is omitted', () => { + const onClick = vi.fn() + + render() + + const buttonElement = screen.getByText('Orchestrate').closest('button') + expect(buttonElement).not.toBeNull() + expect(buttonElement).toHaveClass('bg-components-menu-item-bg-active') + expect(buttonElement).toHaveClass('text-text-accent-light-mode-only') + + buttonElement?.click() + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) }) diff --git a/web/app/components/app-sidebar/nav-link/index.tsx b/web/app/components/app-sidebar/nav-link/index.tsx index a2f737da16..89ad74d3c2 100644 --- a/web/app/components/app-sidebar/nav-link/index.tsx +++ b/web/app/components/app-sidebar/nav-link/index.tsx @@ -14,13 +14,15 @@ export type NavIcon = React.ComponentType< export type NavLinkProps = { name: string - href: string + href?: string iconMap: { selected: NavIcon normal: NavIcon } mode?: string disabled?: boolean + active?: boolean + onClick?: () => void } const NavLink = ({ @@ -29,6 +31,8 @@ const NavLink = ({ iconMap, mode = 'expand', disabled = false, + active, + onClick, }: NavLinkProps) => { const segment = useSelectedLayoutSegment() const formattedSegment = (() => { @@ -39,8 +43,11 @@ const NavLink = ({ return res })() - const isActive = href.toLowerCase().split('/')?.pop() === formattedSegment + const isActive = active ?? (href ? href.toLowerCase().split('/')?.pop() === formattedSegment : false) const NavIcon = isActive ? iconMap.selected : iconMap.normal + const linkClassName = cn(isActive + ? 'border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only system-sm-semibold' + : 'text-components-menu-item-text system-sm-medium hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover', 'flex h-8 items-center rounded-lg pl-3 pr-1') const renderIcon = () => (
@@ -70,13 +77,32 @@ const NavLink = ({ ) } + if (!href) { + return ( + + ) + } + return ( {renderIcon()} diff --git a/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx b/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx new file mode 100644 index 0000000000..4c24af692b --- /dev/null +++ b/web/app/components/app-sidebar/snippet-info/__tests__/dropdown.spec.tsx @@ -0,0 +1,285 @@ +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { CreateSnippetDialogPayload } from '@/app/components/workflow/create-snippet-dialog' +import type { SnippetDetail } from '@/models/snippet' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import SnippetInfoDropdown from '../dropdown' + +const mockReplace = vi.fn() +const mockDownloadBlob = vi.fn() +const mockToastSuccess = vi.fn() +const mockToastError = vi.fn() +const mockUpdateMutate = vi.fn() +const mockExportMutateAsync = vi.fn() +const mockDeleteMutate = vi.fn() +let mockDropdownOpen = false +let mockDropdownOnOpenChange: ((open: boolean) => void) | undefined + +vi.mock('@/next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), +})) + +vi.mock('@/utils/download', () => ({ + downloadBlob: (args: { data: Blob, fileName: string }) => mockDownloadBlob(args), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (...args: unknown[]) => mockToastSuccess(...args), + error: (...args: unknown[]) => mockToastError(...args), + }, +})) + +vi.mock('@/app/components/base/ui/dropdown-menu', () => ({ + DropdownMenu: ({ + open, + onOpenChange, + children, + }: { + open?: boolean + onOpenChange?: (open: boolean) => void + children: React.ReactNode + }) => { + mockDropdownOpen = !!open + mockDropdownOnOpenChange = onOpenChange + return
{children}
+ }, + DropdownMenuTrigger: ({ + children, + className, + }: { + children: React.ReactNode + className?: string + }) => ( + + ), + DropdownMenuContent: ({ children }: { children: React.ReactNode }) => ( + mockDropdownOpen ?
{children}
: null + ), + DropdownMenuItem: ({ + children, + onClick, + }: { + children: React.ReactNode + onClick?: () => void + }) => ( + + ), + DropdownMenuSeparator: () =>
, +})) + +vi.mock('@/service/use-snippets', () => ({ + useUpdateSnippetMutation: () => ({ + mutate: mockUpdateMutate, + isPending: false, + }), + useExportSnippetMutation: () => ({ + mutateAsync: mockExportMutateAsync, + isPending: false, + }), + useDeleteSnippetMutation: () => ({ + mutate: mockDeleteMutate, + isPending: false, + }), +})) + +type MockCreateSnippetDialogProps = { + isOpen: boolean + title?: string + confirmText?: string + initialValue?: { + name?: string + description?: string + icon?: AppIconSelection + } + onClose: () => void + onConfirm: (payload: CreateSnippetDialogPayload) => void +} + +vi.mock('@/app/components/workflow/create-snippet-dialog', () => ({ + default: ({ + isOpen, + title, + confirmText, + initialValue, + onClose, + onConfirm, + }: MockCreateSnippetDialogProps) => { + if (!isOpen) + return null + + return ( +
+
{title}
+
{confirmText}
+
{initialValue?.name}
+
{initialValue?.description}
+ + +
+ ) + }, +})) + +const mockSnippet: SnippetDetail = { + id: 'snippet-1', + name: 'Social Media Repurposer', + description: 'Turn one blog post into multiple social media variations.', + author: 'Dify', + updatedAt: '2026-03-25 10:00', + usage: '12', + icon: '🤖', + iconBackground: '#F0FDF9', + status: undefined, +} + +describe('SnippetInfoDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDropdownOpen = false + mockDropdownOnOpenChange = undefined + }) + + // Rendering coverage for the menu trigger itself. + describe('Rendering', () => { + it('should render the dropdown trigger button', () => { + render() + + expect(screen.getByRole('button')).toBeInTheDocument() + }) + }) + + // Edit flow should seed the dialog with current snippet info and submit updates. + describe('Edit Snippet', () => { + it('should open the edit dialog and submit snippet updates', async () => { + const user = userEvent.setup() + mockUpdateMutate.mockImplementation((_variables: unknown, options?: { onSuccess?: () => void }) => { + options?.onSuccess?.() + }) + + render() + await user.click(screen.getByRole('button')) + await user.click(screen.getByText('snippet.menu.editInfo')) + + expect(screen.getByTestId('create-snippet-dialog')).toBeInTheDocument() + expect(screen.getByText('snippet.editDialogTitle')).toBeInTheDocument() + expect(screen.getByText('common.operation.save')).toBeInTheDocument() + expect(screen.getByText(mockSnippet.name)).toBeInTheDocument() + expect(screen.getByText(mockSnippet.description)).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'submit-edit' })) + + expect(mockUpdateMutate).toHaveBeenCalledWith({ + params: { snippetId: mockSnippet.id }, + body: { + name: 'Updated snippet', + description: 'Updated description', + icon_info: { + icon: '✨', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: undefined, + }, + }, + }, expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + })) + expect(mockToastSuccess).toHaveBeenCalledWith('snippet.editDone') + }) + }) + + // Export should call the export hook and download the returned YAML blob. + describe('Export Snippet', () => { + it('should export and download the snippet yaml', async () => { + const user = userEvent.setup() + mockExportMutateAsync.mockResolvedValue('yaml: content') + + render() + + await user.click(screen.getByRole('button')) + await user.click(screen.getByText('snippet.menu.exportSnippet')) + + await waitFor(() => { + expect(mockExportMutateAsync).toHaveBeenCalledWith({ snippetId: mockSnippet.id }) + }) + + expect(mockDownloadBlob).toHaveBeenCalledWith({ + data: expect.any(Blob), + fileName: `${mockSnippet.name}.yml`, + }) + }) + + it('should show an error toast when export fails', async () => { + const user = userEvent.setup() + mockExportMutateAsync.mockRejectedValue(new Error('export failed')) + + render() + + await user.click(screen.getByRole('button')) + await user.click(screen.getByText('snippet.menu.exportSnippet')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('snippet.exportFailed') + }) + }) + }) + + // Delete should require confirmation and redirect after a successful mutation. + describe('Delete Snippet', () => { + it('should confirm deletion and redirect to the snippets list', async () => { + const user = userEvent.setup() + mockDeleteMutate.mockImplementation((_variables: unknown, options?: { onSuccess?: () => void }) => { + options?.onSuccess?.() + }) + + render() + + await user.click(screen.getByRole('button')) + await user.click(screen.getByText('snippet.menu.deleteSnippet')) + + expect(screen.getByText('snippet.deleteConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('snippet.deleteConfirmContent')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'snippet.menu.deleteSnippet' })) + + expect(mockDeleteMutate).toHaveBeenCalledWith({ + params: { snippetId: mockSnippet.id }, + }, expect.objectContaining({ + onSuccess: expect.any(Function), + onError: expect.any(Function), + })) + expect(mockToastSuccess).toHaveBeenCalledWith('snippet.deleted') + expect(mockReplace).toHaveBeenCalledWith('/snippets') + }) + }) +}) diff --git a/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx new file mode 100644 index 0000000000..50754ffd23 --- /dev/null +++ b/web/app/components/app-sidebar/snippet-info/__tests__/index.spec.tsx @@ -0,0 +1,62 @@ +import type { SnippetDetail } from '@/models/snippet' +import { render, screen } from '@testing-library/react' +import * as React from 'react' +import SnippetInfo from '..' + +vi.mock('../dropdown', () => ({ + default: () =>
, +})) + +const mockSnippet: SnippetDetail = { + id: 'snippet-1', + name: 'Social Media Repurposer', + description: 'Turn one blog post into multiple social media variations.', + author: 'Dify', + updatedAt: '2026-03-25 10:00', + usage: '12', + icon: '🤖', + iconBackground: '#F0FDF9', + status: undefined, +} + +describe('SnippetInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Rendering tests for the collapsed and expanded sidebar header states. + describe('Rendering', () => { + it('should render the expanded snippet details and dropdown when expand is true', () => { + render() + + expect(screen.getByText(mockSnippet.name)).toBeInTheDocument() + expect(screen.getByText('snippet.typeLabel')).toBeInTheDocument() + expect(screen.getByText(mockSnippet.description)).toBeInTheDocument() + expect(screen.getByTestId('snippet-info-dropdown')).toBeInTheDocument() + }) + + it('should hide the expanded-only content when expand is false', () => { + render() + + expect(screen.queryByText(mockSnippet.name)).not.toBeInTheDocument() + expect(screen.queryByText('snippet.typeLabel')).not.toBeInTheDocument() + expect(screen.queryByText(mockSnippet.description)).not.toBeInTheDocument() + expect(screen.queryByTestId('snippet-info-dropdown')).not.toBeInTheDocument() + }) + }) + + // Edge cases around optional snippet fields should not break the header layout. + describe('Edge Cases', () => { + it('should omit the description block when the snippet has no description', () => { + render( + , + ) + + expect(screen.getByText(mockSnippet.name)).toBeInTheDocument() + expect(screen.queryByText(mockSnippet.description)).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/snippet-info/dropdown.tsx b/web/app/components/app-sidebar/snippet-info/dropdown.tsx new file mode 100644 index 0000000000..1e1bddf615 --- /dev/null +++ b/web/app/components/app-sidebar/snippet-info/dropdown.tsx @@ -0,0 +1,198 @@ +'use client' + +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { SnippetDetail } from '@/models/snippet' +import { cn } from '@langgenius/dify-ui/cn' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, +} from '@/app/components/base/ui/alert-dialog' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' +import { toast } from '@/app/components/base/ui/toast' +import CreateSnippetDialog from '@/app/components/workflow/create-snippet-dialog' +import { useRouter } from '@/next/navigation' +import { useDeleteSnippetMutation, useExportSnippetMutation, useUpdateSnippetMutation } from '@/service/use-snippets' + +import { downloadBlob } from '@/utils/download' + +type SnippetInfoDropdownProps = { + snippet: SnippetDetail +} + +const FALLBACK_ICON: AppIconSelection = { + type: 'emoji', + icon: '🤖', + background: '#FFEAD5', +} + +const SnippetInfoDropdown = ({ snippet }: SnippetInfoDropdownProps) => { + const { t } = useTranslation('snippet') + const { replace } = useRouter() + const [open, setOpen] = React.useState(false) + const [isEditDialogOpen, setIsEditDialogOpen] = React.useState(false) + const [isDeleteDialogOpen, setIsDeleteDialogOpen] = React.useState(false) + const updateSnippetMutation = useUpdateSnippetMutation() + const exportSnippetMutation = useExportSnippetMutation() + const deleteSnippetMutation = useDeleteSnippetMutation() + + const initialValue = React.useMemo(() => ({ + name: snippet.name, + description: snippet.description, + icon: snippet.icon + ? { + type: 'emoji' as const, + icon: snippet.icon, + background: snippet.iconBackground || FALLBACK_ICON.background, + } + : FALLBACK_ICON, + }), [snippet.description, snippet.icon, snippet.iconBackground, snippet.name]) + + const handleOpenEditDialog = React.useCallback(() => { + setOpen(false) + setIsEditDialogOpen(true) + }, []) + + const handleExportSnippet = React.useCallback(async () => { + setOpen(false) + try { + const data = await exportSnippetMutation.mutateAsync({ snippetId: snippet.id }) + const file = new Blob([data], { type: 'application/yaml' }) + downloadBlob({ data: file, fileName: `${snippet.name}.yml` }) + } + catch { + toast.error(t('exportFailed')) + } + }, [exportSnippetMutation, snippet.id, snippet.name, t]) + + const handleEditSnippet = React.useCallback(async ({ name, description, icon }: { + name: string + description: string + icon: AppIconSelection + }) => { + updateSnippetMutation.mutate({ + params: { snippetId: snippet.id }, + body: { + name, + description: description || undefined, + icon_info: { + icon: icon.type === 'emoji' ? icon.icon : icon.fileId, + icon_type: icon.type, + icon_background: icon.type === 'emoji' ? icon.background : undefined, + icon_url: icon.type === 'image' ? icon.url : undefined, + }, + }, + }, { + onSuccess: () => { + toast.success(t('editDone')) + setIsEditDialogOpen(false) + }, + onError: (error) => { + toast.error(error instanceof Error ? error.message : t('editFailed')) + }, + }) + }, [snippet.id, t, updateSnippetMutation]) + + const handleDeleteSnippet = React.useCallback(() => { + deleteSnippetMutation.mutate({ + params: { snippetId: snippet.id }, + }, { + onSuccess: () => { + toast.success(t('deleted')) + setIsDeleteDialogOpen(false) + replace('/snippets') + }, + onError: (error) => { + toast.error(error instanceof Error ? error.message : t('deleteFailed')) + }, + }) + }, [deleteSnippetMutation, replace, snippet.id, t]) + + return ( + <> + + + + + + + + {t('menu.editInfo')} + + + + {t('menu.exportSnippet')} + + + { + setOpen(false) + setIsDeleteDialogOpen(true) + }} + > + + {t('menu.deleteSnippet')} + + + + + {isEditDialogOpen && ( + setIsEditDialogOpen(false)} + onConfirm={handleEditSnippet} + /> + )} + + + +
+ + {t('deleteConfirmTitle')} + + + {t('deleteConfirmContent')} + +
+ + + {t('operation.cancel', { ns: 'common' })} + + + {t('menu.deleteSnippet')} + + +
+
+ + ) +} + +export default React.memo(SnippetInfoDropdown) diff --git a/web/app/components/app-sidebar/snippet-info/index.tsx b/web/app/components/app-sidebar/snippet-info/index.tsx new file mode 100644 index 0000000000..cc1f867026 --- /dev/null +++ b/web/app/components/app-sidebar/snippet-info/index.tsx @@ -0,0 +1,55 @@ +'use client' + +import type { SnippetDetail } from '@/models/snippet' +import { cn } from '@langgenius/dify-ui/cn' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import SnippetInfoDropdown from './dropdown' + +type SnippetInfoProps = { + expand: boolean + snippet: SnippetDetail +} + +const SnippetInfo = ({ + expand, + snippet, +}: SnippetInfoProps) => { + const { t } = useTranslation('snippet') + + return ( +
+
+
+
+ +
+ {expand && } +
+ {expand && ( +
+
+ {snippet.name} +
+
+ {t('typeLabel')} +
+
+ )} + {expand && snippet.description && ( +

+ {snippet.description} +

+ )} +
+
+ ) +} + +export default React.memo(SnippetInfo) diff --git a/web/app/components/app/app-publisher/__tests__/index.spec.tsx b/web/app/components/app/app-publisher/__tests__/index.spec.tsx index bba2f53afc..86b45a2a79 100644 --- a/web/app/components/app/app-publisher/__tests__/index.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/index.spec.tsx @@ -2,7 +2,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { AccessMode } from '@/models/access-control' -import { AppModeEnum } from '@/types/app' +import { AppModeEnum, AppTypeEnum } from '@/types/app' import { basePath } from '@/utils/var' import AppPublisher from '../index' @@ -15,6 +15,8 @@ const mockOpenAsyncWindow = vi.fn() const mockFetchInstalledAppList = vi.fn() const mockFetchAppDetailDirect = vi.fn() const mockToastError = vi.fn() +const mockConvertWorkflowType = vi.fn() +const mockRefetchEvaluationWorkflowAssociatedTargets = vi.fn() const mockInvalidateAppWorkflow = vi.fn() const sectionProps = vi.hoisted(() => ({ @@ -27,6 +29,7 @@ const ahooksMocks = vi.hoisted(() => ({ })) let mockAppDetail: Record | null = null +let mockEvaluationWorkflowAssociatedTargets: Record | undefined vi.mock('react-i18next', () => ({ useTranslation: () => ({ @@ -89,6 +92,21 @@ vi.mock('@/service/apps', () => ({ fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args), })) +vi.mock('@/service/use-apps', () => ({ + useConvertWorkflowTypeMutation: () => ({ + mutateAsync: (...args: unknown[]) => mockConvertWorkflowType(...args), + isPending: false, + }), +})) + +vi.mock('@/service/use-evaluation', () => ({ + useEvaluationWorkflowAssociatedTargets: () => ({ + data: mockEvaluationWorkflowAssociatedTargets, + refetch: mockRefetchEvaluationWorkflowAssociatedTargets, + isFetching: false, + }), +})) + vi.mock('@/service/use-workflow', () => ({ useInvalidateAppWorkflow: () => mockInvalidateAppWorkflow, })) @@ -129,15 +147,15 @@ vi.mock('@/app/components/base/portal-to-follow-elem', async () => { return { PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( - +
{children}
-
+ ), PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => (
{children}
), PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => { - const open = ReactModule.useContext(OpenContext) + const open = ReactModule.use(OpenContext) return open ?
{children}
: null }, } @@ -150,6 +168,7 @@ vi.mock('../sections', () => ({
+
) }, @@ -180,6 +199,7 @@ describe('AppPublisher', () => { name: 'Demo App', mode: AppModeEnum.CHAT, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + type: AppTypeEnum.WORKFLOW, site: { app_base_url: 'https://example.com', access_token: 'token-1', @@ -192,6 +212,12 @@ describe('AppPublisher', () => { id: 'app-1', access_mode: AccessMode.PUBLIC, }) + mockConvertWorkflowType.mockResolvedValue({}) + mockEvaluationWorkflowAssociatedTargets = { items: [] } + mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValue({ + data: { items: [] }, + isError: false, + }) mockOpenAsyncWindow.mockImplementation(async (resolver: () => Promise) => { await resolver() }) @@ -457,4 +483,352 @@ describe('AppPublisher', () => { }) expect(screen.getByTestId('access-control'))!.toBeInTheDocument() }) + + it('should switch workflow type, refresh app detail, and close the popover for published apps', async () => { + mockFetchAppDetailDirect.mockResolvedValueOnce({ + id: 'app-1', + type: AppTypeEnum.EVALUATION, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.EVALUATION }, + }) + expect(mockFetchAppDetailDirect).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' }) + expect(mockSetAppDetail).toHaveBeenCalledWith({ + id: 'app-1', + type: AppTypeEnum.EVALUATION, + }) + }) + expect(screen.queryByText('publisher-summary-publish')).not.toBeInTheDocument() + }) + + it('should hide access and actions sections for evaluation workflow apps', () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + + expect(screen.getByText('publisher-summary-publish')).toBeInTheDocument() + expect(screen.queryByText('publisher-access-control')).not.toBeInTheDocument() + expect(screen.queryByText('publisher-embed')).not.toBeInTheDocument() + expect(sectionProps.summary?.workflowTypeSwitchConfig).toEqual({ + targetType: AppTypeEnum.WORKFLOW, + publishLabelKey: 'common.publishAsStandardWorkflow', + switchLabelKey: 'common.switchToStandardWorkflow', + tipKey: 'common.switchToStandardWorkflowTip', + }) + }) + + it('should confirm before switching an evaluation workflow with associated targets to a standard workflow', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + mockEvaluationWorkflowAssociatedTargets = { + items: [ + { + target_type: 'app', + target_id: 'dependent-app-1', + target_name: 'Dependent App', + }, + { + target_type: 'knowledge_base', + target_id: 'knowledge-1', + target_name: 'Knowledge Base', + }, + ], + } + mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ + data: mockEvaluationWorkflowAssociatedTargets, + isError: false, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) + }) + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + expect(screen.getByText('Dependent App')).toBeInTheDocument() + expect(screen.getByText('Knowledge Base')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.switchToStandardWorkflowConfirm.switch' })) + + await waitFor(() => { + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.WORKFLOW }, + }) + }) + }) + + it('should switch an evaluation workflow directly when there are no associated targets', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.WORKFLOW }, + }) + }) + expect(screen.queryByText('common.switchToStandardWorkflowConfirm.title')).not.toBeInTheDocument() + }) + + it('should block switching an evaluation workflow when associated targets fail to load', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ + data: undefined, + isError: true, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.switchToStandardWorkflowConfirm.loadFailed') + }) + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + }) + + it('should block switching to evaluation workflow when restricted nodes exist', async () => { + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.switchToEvaluationWorkflowDisabledTip') + }) + + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + expect(sectionProps.summary?.workflowTypeSwitchDisabled).toBe(true) + expect(sectionProps.summary?.workflowTypeSwitchDisabledReason).toBe('common.switchToEvaluationWorkflowDisabledTip') + }) + + it('should switch workflow type, refresh app detail, and close the popover for published apps', async () => { + mockFetchAppDetailDirect.mockResolvedValueOnce({ + id: 'app-1', + type: AppTypeEnum.EVALUATION, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.EVALUATION }, + }) + expect(mockFetchAppDetailDirect).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' }) + expect(mockSetAppDetail).toHaveBeenCalledWith({ + id: 'app-1', + type: AppTypeEnum.EVALUATION, + }) + }) + expect(screen.queryByText('publisher-summary-publish')).not.toBeInTheDocument() + }) + + it('should hide access and actions sections for evaluation workflow apps', () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + + expect(screen.getByText('publisher-summary-publish')).toBeInTheDocument() + expect(screen.queryByText('publisher-access-control')).not.toBeInTheDocument() + expect(screen.queryByText('publisher-embed')).not.toBeInTheDocument() + expect(sectionProps.summary?.workflowTypeSwitchConfig).toEqual({ + targetType: AppTypeEnum.WORKFLOW, + publishLabelKey: 'common.publishAsStandardWorkflow', + switchLabelKey: 'common.switchToStandardWorkflow', + tipKey: 'common.switchToStandardWorkflowTip', + }) + }) + + it('should confirm before switching an evaluation workflow with associated targets to a standard workflow', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + mockEvaluationWorkflowAssociatedTargets = { + items: [ + { + target_type: 'app', + target_id: 'dependent-app-1', + target_name: 'Dependent App', + }, + { + target_type: 'knowledge_base', + target_id: 'knowledge-1', + target_name: 'Knowledge Base', + }, + ], + } + mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ + data: mockEvaluationWorkflowAssociatedTargets, + isError: false, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) + }) + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + expect(screen.getByText('Dependent App')).toBeInTheDocument() + expect(screen.getByText('Knowledge Base')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.switchToStandardWorkflowConfirm.switch' })) + + await waitFor(() => { + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.WORKFLOW }, + }) + }) + }) + + it('should switch an evaluation workflow directly when there are no associated targets', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockRefetchEvaluationWorkflowAssociatedTargets).toHaveBeenCalledTimes(1) + expect(mockConvertWorkflowType).toHaveBeenCalledWith({ + params: { appId: 'app-1' }, + query: { target_type: AppTypeEnum.WORKFLOW }, + }) + }) + expect(screen.queryByText('common.switchToStandardWorkflowConfirm.title')).not.toBeInTheDocument() + }) + + it('should block switching an evaluation workflow when associated targets fail to load', async () => { + mockAppDetail = { + ...mockAppDetail, + type: AppTypeEnum.EVALUATION, + } + mockRefetchEvaluationWorkflowAssociatedTargets.mockResolvedValueOnce({ + data: undefined, + isError: true, + }) + + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.switchToStandardWorkflowConfirm.loadFailed') + }) + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + }) + + it('should block switching to evaluation workflow when restricted nodes exist', async () => { + render( + , + ) + + fireEvent.click(screen.getByText('common.publish')) + fireEvent.click(screen.getByText('publisher-switch-workflow-type')) + + await waitFor(() => { + expect(mockToastError).toHaveBeenCalledWith('common.switchToEvaluationWorkflowDisabledTip') + }) + + expect(mockConvertWorkflowType).not.toHaveBeenCalled() + expect(sectionProps.summary?.workflowTypeSwitchDisabled).toBe(true) + expect(sectionProps.summary?.workflowTypeSwitchDisabledReason).toBe('common.switchToEvaluationWorkflowDisabledTip') + }) }) diff --git a/web/app/components/app/app-publisher/__tests__/sections.spec.tsx b/web/app/components/app/app-publisher/__tests__/sections.spec.tsx index 57e7a55b13..f242f1d6c2 100644 --- a/web/app/components/app/app-publisher/__tests__/sections.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/sections.spec.tsx @@ -45,12 +45,14 @@ describe('app-publisher sections', () => { handleRestore={handleRestore} isChatApp multipleModelConfigs={[]} + onWorkflowTypeSwitch={vi.fn()} publishDisabled={false} published={false} publishedAt={Date.now()} publishShortcut={['ctrl', '⇧', 'P']} startNodeLimitExceeded={false} upgradeHighlightStyle={{}} + workflowTypeSwitchDisabled={false} />, ) @@ -83,12 +85,14 @@ describe('app-publisher sections', () => { handleRestore={vi.fn()} isChatApp={false} multipleModelConfigs={[]} + onWorkflowTypeSwitch={vi.fn()} publishDisabled={false} published={false} publishedAt={undefined} publishShortcut={['ctrl', '⇧', 'P']} startNodeLimitExceeded={false} upgradeHighlightStyle={{}} + workflowTypeSwitchDisabled={false} />, ) @@ -107,12 +111,14 @@ describe('app-publisher sections', () => { handleRestore={vi.fn()} isChatApp={false} multipleModelConfigs={[{ id: '1' } as any]} + onWorkflowTypeSwitch={vi.fn()} publishDisabled={false} published={false} publishedAt={undefined} publishShortcut={['ctrl', '⇧', 'P']} startNodeLimitExceeded={false} upgradeHighlightStyle={{}} + workflowTypeSwitchDisabled={false} />, ) @@ -131,18 +137,85 @@ describe('app-publisher sections', () => { handleRestore={vi.fn()} isChatApp={false} multipleModelConfigs={[]} + onWorkflowTypeSwitch={vi.fn()} publishDisabled={false} published={false} publishedAt={undefined} publishShortcut={['ctrl', '⇧', 'P']} startNodeLimitExceeded upgradeHighlightStyle={{}} + workflowTypeSwitchDisabled={false} />, ) expect(screen.getByText('publishLimit.startNodeDesc')).toBeInTheDocument() }) + it('should render workflow type switch action and call switch handler', () => { + const onWorkflowTypeSwitch = vi.fn() + + render( + '1 minute ago'} + handlePublish={vi.fn()} + handleRestore={vi.fn()} + isChatApp={false} + multipleModelConfigs={[]} + onWorkflowTypeSwitch={onWorkflowTypeSwitch} + publishDisabled={false} + published={false} + publishedAt={undefined} + publishShortcut={['ctrl', '⇧', 'P']} + startNodeLimitExceeded={false} + upgradeHighlightStyle={{}} + workflowTypeSwitchConfig={{ + targetType: 'evaluation', + publishLabelKey: 'common.publishAsEvaluationWorkflow', + switchLabelKey: 'common.switchToEvaluationWorkflow', + tipKey: 'common.switchToEvaluationWorkflowTip', + }} + workflowTypeSwitchDisabled={false} + />, + ) + + fireEvent.click(screen.getByText('common.publishAsEvaluationWorkflow')) + + expect(onWorkflowTypeSwitch).toHaveBeenCalledTimes(1) + }) + + it('should disable workflow type switch when a disabled reason is provided', () => { + render( + '1 minute ago'} + handlePublish={vi.fn()} + handleRestore={vi.fn()} + isChatApp={false} + multipleModelConfigs={[]} + onWorkflowTypeSwitch={vi.fn()} + publishDisabled={false} + published={false} + publishedAt={undefined} + publishShortcut={['ctrl', '⇧', 'P']} + startNodeLimitExceeded={false} + upgradeHighlightStyle={{}} + workflowTypeSwitchConfig={{ + targetType: 'evaluation', + publishLabelKey: 'common.publishAsEvaluationWorkflow', + switchLabelKey: 'common.switchToEvaluationWorkflow', + tipKey: 'common.switchToEvaluationWorkflowTip', + }} + workflowTypeSwitchDisabled + workflowTypeSwitchDisabledReason="common.switchToEvaluationWorkflowDisabledTip" + />, + ) + + expect(screen.getByRole('button', { name: /common\.publishAsEvaluationWorkflow/i })).toBeDisabled() + }) + it('should render loading access state and access mode labels when enabled', () => { const { rerender } = render( void + onConfirm: () => void +} + +const TARGET_TYPE_META: Record + href: (targetId: string) => string +}> = { + app: { + icon: 'i-ri-flow-chart', + iconClassName: 'bg-components-icon-bg-teal-soft text-util-colors-teal-teal-600', + labelKey: 'common.switchToStandardWorkflowConfirm.targetTypes.app', + href: targetId => `/app/${targetId}/workflow`, + }, + snippets: { + icon: 'i-ri-edit-2-line', + iconClassName: 'bg-components-icon-bg-violet-soft text-util-colors-violet-violet-600', + labelKey: 'common.switchToStandardWorkflowConfirm.targetTypes.snippets', + href: targetId => `/snippets/${targetId}/orchestrate`, + }, + knowledge_base: { + icon: 'i-ri-book-2-line', + iconClassName: 'bg-components-icon-bg-indigo-soft text-util-colors-blue-blue-600', + labelKey: 'common.switchToStandardWorkflowConfirm.targetTypes.knowledge_base', + href: targetId => `/datasets/${targetId}/documents`, + }, +} + +const getTargetMeta = (targetType: EvaluationWorkflowAssociatedTargetType) => { + return TARGET_TYPE_META[targetType] ?? TARGET_TYPE_META.app +} + +const DependentTargetItem = ({ + target, +}: { + target: EvaluationWorkflowAssociatedTarget +}) => { + const { t } = useTranslation() + const meta = getTargetMeta(target.target_type) + const targetName = target.target_name || target.target_id + + return ( + +
{showAppAccessControl && { setShowAppAccessControl(false) }} />} + void performWorkflowTypeSwitch()} + /> ) } diff --git a/web/app/components/app/app-publisher/sections.tsx b/web/app/components/app/app-publisher/sections.tsx index d4864a3763..54fbcee672 100644 --- a/web/app/components/app/app-publisher/sections.tsx +++ b/web/app/components/app/app-publisher/sections.tsx @@ -1,10 +1,10 @@ import type { CSSProperties, ReactNode } from 'react' import type { ModelAndParameter } from '../configuration/debug/types' import type { AppPublisherProps } from './index' -import type { PublishWorkflowParams } from '@/types/workflow' +import type { I18nKeysWithPrefix } from '@/types/i18n' +import type { PublishWorkflowParams, WorkflowTypeConversionTarget } from '@/types/workflow' import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' -import { CodeBrowser } from '@/app/components/base/icons/src/vender/line/development' import Loading from '@/app/components/base/loading' import { Button } from '@/app/components/base/ui/button' import { @@ -21,6 +21,8 @@ import PublishWithMultipleModel from './publish-with-multiple-model' import SuggestedAction from './suggested-action' import { ACCESS_MODE_MAP } from './utils' +type WorkflowTypeSwitchLabelKey = I18nKeysWithPrefix<'workflow', 'common.'> + type SummarySectionProps = Pick Promise handleRestore: () => Promise isChatApp: boolean + onWorkflowTypeSwitch: () => Promise published: boolean publishShortcut: string[] upgradeHighlightStyle: CSSProperties + workflowTypeSwitchConfig?: { + targetType: WorkflowTypeConversionTarget + publishLabelKey: WorkflowTypeSwitchLabelKey + switchLabelKey: WorkflowTypeSwitchLabelKey + tipKey: WorkflowTypeSwitchLabelKey + } + workflowTypeSwitchDisabled: boolean + workflowTypeSwitchDisabledReason?: string } type AccessSectionProps = { @@ -90,6 +101,28 @@ export const AccessModeDisplay = ({ mode }: { mode?: keyof typeof ACCESS_MODE_MA ) } +const ActionTooltip = ({ + disabled, + tooltip, + children, +}: { + disabled: boolean + tooltip?: ReactNode + children: ReactNode +}) => { + if (!disabled || !tooltip) + return <>{children} + + return ( + + {children}
} /> + + {tooltip} + + + ) +} + export const PublisherSummarySection = ({ debugWithMultipleModel = false, draftUpdatedAt, @@ -98,12 +131,16 @@ export const PublisherSummarySection = ({ handleRestore, isChatApp, multipleModelConfigs = [], + onWorkflowTypeSwitch, publishDisabled = false, published, publishedAt, publishShortcut, startNodeLimitExceeded = false, upgradeHighlightStyle, + workflowTypeSwitchConfig, + workflowTypeSwitchDisabled, + workflowTypeSwitchDisabledReason, }: SummarySectionProps) => { const { t } = useTranslation() @@ -164,6 +201,47 @@ export const PublisherSummarySection = ({
)} + {workflowTypeSwitchConfig && ( + + + + )} {startNodeLimitExceeded && (

{ - if (!disabled || !tooltip) - return <>{children} - - return ( - - {children}

} /> - - {tooltip} - - - ) -} - export const PublisherActionsSection = ({ appDetail, appURL, @@ -305,7 +361,7 @@ export const PublisherActionsSection = ({ } + icon={} > {t('common.embedIntoSite', { ns: 'workflow' })} diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 1de6e6ce0c..2783f66c3f 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -96,8 +96,8 @@ const AdvancedPromptInput: FC = ({ }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i]!.key === newExternalDataTool.variable) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) + if (promptVariables[i].key === newExternalDataTool.variable) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 6b5c3acccb..0f6d2b94a6 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -94,8 +94,8 @@ const Prompt: FC = ({ }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i]!.key === newExternalDataTool.variable) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) + if (promptVariables[i].key === newExternalDataTool.variable) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index aca2817249..951680c035 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -165,8 +165,8 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar }, onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { - if (promptVariables[i]!.key === newExternalDataTool.variable && i !== index) { - toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i]!.key })) + if (promptVariables[i].key === newExternalDataTool.variable && i !== index) { + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } @@ -220,7 +220,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar const handleRemoveVar = useCallback((index: number) => { const removeVar = promptVariables[index] - if (mode === AppModeEnum.COMPLETION && dataSets.length > 0 && removeVar!.is_context_var) { + if (mode === AppModeEnum.COMPLETION && dataSets.length > 0 && removeVar.is_context_var) { showDeleteContextVarModal() setRemoveIndex(index) return diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index d9200af773..ec4ceb5ed4 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -96,7 +96,7 @@ const GetAutomaticRes: FC = ({ const [model, setModel] = React.useState(localModel || { name: '', provider: '', - mode: mode as unknown as ModelModeType.chat, + mode: mode as unknown as ModelModeType, completion_params: {} as CompletionParams, }) const { diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index eb1ee7e10c..b58520c379 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -78,7 +78,7 @@ export const GetCodeGeneratorResModal: FC = ( const [model, setModel] = React.useState(localModel || { name: '', provider: '', - mode: mode as unknown as ModelModeType.chat, + mode: mode as unknown as ModelModeType, completion_params: defaultCompletionParams, }) const { diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index e5080f26e4..d4ce935a4d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -1,11 +1,12 @@ +import type { CSSProperties } from 'react' import { noop } from 'es-toolkit/function' import { memo } from 'react' import { useTranslation } from 'react-i18next' import { Slider } from '@/app/components/base/ui/slider' -const weightedScoreSliderSlotClassNames = { - track: 'bg-util-colors-teal-teal-500', - indicator: 'bg-util-colors-blue-light-blue-light-500', +const weightedScoreSliderStyle: CSSProperties & Record<'--slider-track' | '--slider-range', string> = { + '--slider-track': 'var(--color-util-colors-teal-teal-500)', + '--slider-range': 'var(--color-util-colors-blue-light-blue-light-500)', } const formatNumber = (value: number) => { @@ -35,8 +36,8 @@ const WeightedScore = ({ return (
-
-
+
+
!readonly && onChange({ value: [v, (10 - v * 10) / 10] })} disabled={readonly} aria-label={t('weightedScore.semantic', { ns: 'dataset' })} - slotClassNames={weightedScoreSliderSlotClassNames} />
-
+
{t('weightedScore.semantic', { ns: 'dataset' })}
- {formatNumber(value.value[0]!)} + {formatNumber(value.value[0])}
-
- {formatNumber(value.value[1]!)} +
+ {formatNumber(value.value[1])}
{t('weightedScore.keyword', { ns: 'dataset' })}
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx index 2e535baeac..4b21616d46 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx @@ -84,7 +84,7 @@ const DebugItem: FC = ({ style={style} >
-
+
# {index + 1}
@@ -115,7 +115,7 @@ const DebugItem: FC = ({ {showRemove && ( <> {(showDuplicate || showDebugAsSingleModel) && } - + {t('operation.remove', { ns: 'common' })} diff --git a/web/app/components/app/workflow-log/__tests__/evaluation-cell.spec.tsx b/web/app/components/app/workflow-log/__tests__/evaluation-cell.spec.tsx new file mode 100644 index 0000000000..e51ffdaaac --- /dev/null +++ b/web/app/components/app/workflow-log/__tests__/evaluation-cell.spec.tsx @@ -0,0 +1,75 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import EvaluationCell from '../evaluation-cell' + +describe('EvaluationCell', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render a placeholder when evaluation data is empty', () => { + render() + + expect(screen.getByText('-')).toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'appLog.table.header.evaluation' })).not.toBeInTheDocument() + }) + + it('should render a trigger button when evaluation data is available', () => { + render( + , + ) + + expect(screen.getByRole('button', { name: 'appLog.table.header.evaluation' })).toBeInTheDocument() + }) + }) + + describe('Interactions', () => { + it('should render evaluation details when clicking the trigger', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'appLog.table.header.evaluation' })) + + expect(await screen.findByTestId('workflow-log-evaluation-popover')).toBeInTheDocument() + expect(screen.getByText('Faithfulness')).toBeInTheDocument() + expect(screen.getByText('0.98')).toBeInTheDocument() + expect(screen.getByText('Knowledge Retrieval')).toBeInTheDocument() + }) + + it('should render boolean values using readable text', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'appLog.table.header.evaluation' })) + + expect(await screen.findByText('True')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/workflow-log/__tests__/index.spec.tsx b/web/app/components/app/workflow-log/__tests__/index.spec.tsx index 1fe3db30db..4414843ed0 100644 --- a/web/app/components/app/workflow-log/__tests__/index.spec.tsx +++ b/web/app/components/app/workflow-log/__tests__/index.spec.tsx @@ -215,6 +215,7 @@ const createMockWorkflowLog = (overrides: Partial = {}): W }, created_at: Date.now(), ...overrides, + evaluation: overrides.evaluation ?? [], }) const createMockLogsResponse = ( diff --git a/web/app/components/app/workflow-log/__tests__/list.spec.tsx b/web/app/components/app/workflow-log/__tests__/list.spec.tsx index 35e6369e67..1246096cf5 100644 --- a/web/app/components/app/workflow-log/__tests__/list.spec.tsx +++ b/web/app/components/app/workflow-log/__tests__/list.spec.tsx @@ -146,6 +146,7 @@ const createMockWorkflowLog = (overrides: Partial = {}): W email: 'test@example.com', }, created_at: Date.now(), + evaluation: [], ...overrides, }) @@ -181,7 +182,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(container.querySelector('.spin-animation'))!.toBeInTheDocument() + expect(container.querySelector('.spin-animation')).toBeInTheDocument() }) it('should render loading state when appDetail is undefined', () => { @@ -191,7 +192,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(container.querySelector('.spin-animation'))!.toBeInTheDocument() + expect(container.querySelector('.spin-animation')).toBeInTheDocument() }) it('should render table when data is available', () => { @@ -201,7 +202,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render all table headers', () => { @@ -211,11 +212,12 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('appLog.table.header.startTime'))!.toBeInTheDocument() - expect(screen.getByText('appLog.table.header.status'))!.toBeInTheDocument() - expect(screen.getByText('appLog.table.header.runtime'))!.toBeInTheDocument() - expect(screen.getByText('appLog.table.header.tokens'))!.toBeInTheDocument() - expect(screen.getByText('appLog.table.header.user'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.evaluation')).toBeInTheDocument() }) it('should render trigger column for workflow apps', () => { @@ -226,7 +228,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('appLog.table.header.triggered_from'))!.toBeInTheDocument() + expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() }) it('should not render trigger column for non-workflow apps', () => { @@ -256,7 +258,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Success'))!.toBeInTheDocument() + expect(screen.getByText('Success')).toBeInTheDocument() }) it('should render failure status correctly', () => { @@ -270,7 +272,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Failure'))!.toBeInTheDocument() + expect(screen.getByText('Failure')).toBeInTheDocument() }) it('should render stopped status correctly', () => { @@ -284,7 +286,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Stop'))!.toBeInTheDocument() + expect(screen.getByText('Stop')).toBeInTheDocument() }) it('should render running status correctly', () => { @@ -298,7 +300,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Running'))!.toBeInTheDocument() + expect(screen.getByText('Running')).toBeInTheDocument() }) it('should render partial-succeeded status correctly', () => { @@ -312,7 +314,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('Partial Success'))!.toBeInTheDocument() + expect(screen.getByText('Partial Success')).toBeInTheDocument() }) }) @@ -332,7 +334,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('John Doe'))!.toBeInTheDocument() + expect(screen.getByText('John Doe')).toBeInTheDocument() }) it('should display end user session id when created by end user', () => { @@ -347,7 +349,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('session-abc-123'))!.toBeInTheDocument() + expect(screen.getByText('session-abc-123')).toBeInTheDocument() }) it('should display N/A when no user info', () => { @@ -362,7 +364,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('N/A'))!.toBeInTheDocument() + expect(screen.getByText('N/A')).toBeInTheDocument() }) }) @@ -404,8 +406,9 @@ describe('WorkflowAppLogList', () => { // Arrow should rotate (indicated by class change) // The sort icon should have rotate-180 class for ascending - const sortIcon = startTimeHeader.closest('div')?.querySelector('svg') - expect(sortIcon)!.toBeInTheDocument() + const sortIcon = startTimeHeader.closest('div')?.querySelector('.i-heroicons-arrow-down') + expect(sortIcon).toBeInTheDocument() + expect(sortIcon).toHaveClass('rotate-180') }) it('should render sort arrow icon', () => { @@ -416,8 +419,8 @@ describe('WorkflowAppLogList', () => { ) // Check for ArrowDownIcon presence - const sortArrow = container.querySelector('svg.ml-0\\.5') - expect(sortArrow)!.toBeInTheDocument() + const sortArrow = container.querySelector('.i-heroicons-arrow-down') + expect(sortArrow).toBeInTheDocument() }) }) @@ -440,11 +443,11 @@ describe('WorkflowAppLogList', () => { ) const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]!) // Click first data row + await user.click(dataRows[1]) // Click first data row const dialog = await screen.findByRole('dialog') - expect(dialog)!.toBeInTheDocument() - expect(screen.getByText('appLog.runDetail.workflowTitle'))!.toBeInTheDocument() + expect(dialog).toBeInTheDocument() + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() }) it('should close drawer and call onRefresh when closing', async () => { @@ -459,7 +462,7 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]!) + await user.click(dataRows[1]) await screen.findByRole('dialog') // Close drawer using Escape key @@ -482,46 +485,42 @@ describe('WorkflowAppLogList', () => { const dataRows = screen.getAllByRole('row') const dataRow = dataRows[1] - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight - // Before click - no highlight // Before click - no highlight expect(dataRow).not.toHaveClass('bg-background-default-hover') // After click - has highlight (via currentLog state) - await user.click(dataRow!) + await user.click(dataRow) // The row should have the selected class - // The row should have the selected class - expect(dataRow)!.toHaveClass('bg-background-default-hover') + expect(dataRow).toHaveClass('bg-background-default-hover') + }) + + it('should open evaluation popover without opening drawer when clicking evaluation trigger', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + evaluation: [{ + name: 'Faithfulness', + value: 0.98, + nodeInfo: { + node_id: 'node-1', + title: 'Knowledge Retrieval', + type: 'knowledge-retrieval', + }, + }], + }), + ]) + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'appLog.table.header.evaluation' })) + + expect(await screen.findByTestId('workflow-log-evaluation-popover')).toBeInTheDocument() + expect(screen.getByText('Faithfulness')).toBeInTheDocument() + expect(screen.getByText('Knowledge Retrieval')).toBeInTheDocument() + expect(screen.queryByRole('heading', { name: 'appLog.runDetail.workflowTitle' })).not.toBeInTheDocument() }) }) @@ -547,7 +546,7 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]!) + await user.click(dataRows[1]) await screen.findByRole('dialog') // Replay button should be present for app-run triggers @@ -575,12 +574,12 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]!) + await user.click(dataRows[1]) await screen.findByRole('dialog') // Replay button should be present for debugging triggers const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) - expect(replayButton)!.toBeInTheDocument() + expect(replayButton).toBeInTheDocument() }) it('should not show replay for webhook triggers', async () => { @@ -601,40 +600,9 @@ describe('WorkflowAppLogList', () => { // Open drawer const dataRows = screen.getAllByRole('row') - await user.click(dataRows[1]!) + await user.click(dataRows[1]) await screen.findByRole('dialog') - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers - // Replay button should not be present for webhook triggers // Replay button should not be present for webhook triggers expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() }) @@ -657,7 +625,7 @@ describe('WorkflowAppLogList', () => { // Unread indicator is a small blue dot const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') - expect(unreadDot)!.toBeInTheDocument() + expect(unreadDot).toBeInTheDocument() }) it('should not show unread indicator for read logs', () => { @@ -692,7 +660,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('1.235s'))!.toBeInTheDocument() + expect(screen.getByText('1.235s')).toBeInTheDocument() }) it('should display 0 elapsed time with special styling', () => { @@ -707,8 +675,8 @@ describe('WorkflowAppLogList', () => { ) const zeroTime = screen.getByText('0.000s') - expect(zeroTime)!.toBeInTheDocument() - expect(zeroTime)!.toHaveClass('text-text-quaternary') + expect(zeroTime).toBeInTheDocument() + expect(zeroTime).toHaveClass('text-text-quaternary') }) }) @@ -727,7 +695,7 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('12345'))!.toBeInTheDocument() + expect(screen.getByText('12345')).toBeInTheDocument() }) }) @@ -743,7 +711,7 @@ describe('WorkflowAppLogList', () => { ) const table = screen.getByRole('table') - expect(table)!.toBeInTheDocument() + expect(table).toBeInTheDocument() // Should only have header row const rows = screen.getAllByRole('row') @@ -784,8 +752,8 @@ describe('WorkflowAppLogList', () => { , ) - expect(screen.getByText('0.000s'))!.toBeInTheDocument() - expect(screen.getByText('0'))!.toBeInTheDocument() + expect(screen.getByText('0.000s')).toBeInTheDocument() + expect(screen.getByText('0')).toBeInTheDocument() }) it('should handle null workflow_run.triggered_from for non-workflow apps', () => { @@ -802,37 +770,6 @@ describe('WorkflowAppLogList', () => { , ) - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column - // Should render without trigger column // Should render without trigger column expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() }) diff --git a/web/app/components/app/workflow-log/evaluation-cell.tsx b/web/app/components/app/workflow-log/evaluation-cell.tsx new file mode 100644 index 0000000000..3e49015454 --- /dev/null +++ b/web/app/components/app/workflow-log/evaluation-cell.tsx @@ -0,0 +1,100 @@ +'use client' + +import type { EvaluationLogItem } from '@/models/log' +import { cn } from '@langgenius/dify-ui/cn' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' +import { getNodeVisual, getToneClasses } from '@/app/components/evaluation/components/metric-selector/utils' + +type EvaluationCellProps = { + evaluation: EvaluationLogItem[] +} + +const formatEvaluationValue = (value: EvaluationLogItem['value']) => { + if (typeof value === 'boolean') + return value ? 'True' : 'False' + + return String(value) +} + +const EvaluationCell = ({ + evaluation, +}: EvaluationCellProps) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + + if (!evaluation.length) { + return ( +
+ - +
+ ) + } + + return ( + + + + ) +} + +export default EvaluationCell diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index e514962d9b..0f9b3f9620 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunTriggeredFrom } from '@/models/log' import type { App } from '@/types/app' -import { ArrowDownIcon } from '@heroicons/react/24/outline' import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useEffect, useState } from 'react' @@ -14,6 +13,7 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' import { AppModeEnum } from '@/types/app' import DetailPanel from './detail' +import EvaluationCell from './evaluation-cell' import TriggerByDisplay from './trigger-by-display' type ILogs = { @@ -118,22 +118,21 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { return (
- +
- + + {isWorkflow && } @@ -172,6 +171,9 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { {endUser} + {isWorkflow && (
{t('table.header.startTime', { ns: 'appLog' })} - +
{t('table.header.status', { ns: 'appLog' })} {t('table.header.runtime', { ns: 'appLog' })} {t('table.header.tokens', { ns: 'appLog' })}{t('table.header.user', { ns: 'appLog' })}{t('table.header.user', { ns: 'appLog' })}{t('table.header.evaluation', { ns: 'appLog' })}{t('table.header.triggered_from', { ns: 'appLog' })}
event.stopPropagation()}> + + diff --git a/web/app/components/apps/__tests__/empty.spec.tsx b/web/app/components/apps/__tests__/empty.spec.tsx index 8dbbbc3ffb..2536d61006 100644 --- a/web/app/components/apps/__tests__/empty.spec.tsx +++ b/web/app/components/apps/__tests__/empty.spec.tsx @@ -2,6 +2,8 @@ import { render, screen } from '@testing-library/react' import * as React from 'react' import Empty from '../empty' +const defaultMessage = 'workflow.tabs.noSnippetsFound' + describe('Empty', () => { beforeEach(() => { vi.clearAllMocks() @@ -9,32 +11,32 @@ describe('Empty', () => { describe('Rendering', () => { it('should render without crashing', () => { - render() - expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + render() + expect(screen.getByText(defaultMessage)).toBeInTheDocument() }) it('should render 36 placeholder cards', () => { - const { container } = render() + const { container } = render() const placeholderCards = container.querySelectorAll('.bg-background-default-lighter') expect(placeholderCards).toHaveLength(36) }) - it('should display the no apps found message', () => { - render() + it('should display the provided message', () => { + render() expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() }) }) describe('Styling', () => { it('should have correct container styling for overlay', () => { - const { container } = render() + const { container } = render() const overlay = container.querySelector('.pointer-events-none') expect(overlay).toBeInTheDocument() expect(overlay).toHaveClass('absolute', 'inset-0', 'z-20') }) it('should have correct styling for placeholder cards', () => { - const { container } = render() + const { container } = render() const card = container.querySelector('.bg-background-default-lighter') expect(card).toHaveClass('inline-flex', 'h-[160px]', 'rounded-xl') }) @@ -42,10 +44,10 @@ describe('Empty', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { - const { rerender } = render() - expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + const { rerender } = render() + expect(screen.getByText(defaultMessage)).toBeInTheDocument() - rerender() + rerender() expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() }) }) diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index eddcb31d60..7a421dfa0f 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -1,4 +1,4 @@ -import { act, fireEvent, screen } from '@testing-library/react' +import { act, fireEvent, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { renderWithNuqs } from '@/test/nuqs-testing' @@ -15,10 +15,14 @@ vi.mock('@/next/navigation', () => ({ const mockIsCurrentWorkspaceEditor = vi.fn(() => true) const mockIsCurrentWorkspaceDatasetOperator = vi.fn(() => false) +const mockIsLoadingCurrentWorkspace = vi.fn(() => false) +const mockCanAccessSnippetsAndEvaluation = vi.fn(() => true) + vi.mock('@/context/app-context', () => ({ useAppContext: () => ({ isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor(), isCurrentWorkspaceDatasetOperator: mockIsCurrentWorkspaceDatasetOperator(), + isLoadingCurrentWorkspace: mockIsLoadingCurrentWorkspace(), }), })) @@ -30,12 +34,21 @@ vi.mock('@/context/global-public-context', () => ({ }), })) +vi.mock('@/hooks/use-snippet-and-evaluation-plan-access', () => ({ + useSnippetAndEvaluationPlanAccess: () => ({ + canAccess: mockCanAccessSnippetsAndEvaluation(), + isReady: true, + }), +})) + const mockSetQuery = vi.fn() const mockQueryState = { tagIDs: [] as string[], + creatorIDs: [] as string[], keywords: '', isCreatedByMe: false, } + vi.mock('../hooks/use-apps-query-state', () => ({ default: () => ({ query: mockQueryState, @@ -45,6 +58,7 @@ vi.mock('../hooks/use-apps-query-state', () => ({ let mockOnDSLFileDropped: ((file: File) => void) | null = null let mockDragging = false + vi.mock('../hooks/use-dsl-drag-drop', () => ({ useDSLDragDrop: ({ onDSLFileDropped }: { onDSLFileDropped: (file: File) => void }) => { mockOnDSLFileDropped = onDSLFileDropped @@ -54,11 +68,15 @@ vi.mock('../hooks/use-dsl-drag-drop', () => ({ const mockRefetch = vi.fn() const mockFetchNextPage = vi.fn() +const mockFetchSnippetNextPage = vi.fn() +const mockUseInfiniteAppList = vi.fn() +const mockUseInfiniteSnippetList = vi.fn() const mockServiceState = { error: null as Error | null, hasNextPage: false, isLoading: false, + isFetching: false, isFetchingNextPage: false, } @@ -97,21 +115,85 @@ const defaultAppData = { } vi.mock('@/service/use-apps', () => ({ - useInfiniteAppList: () => ({ - data: defaultAppData, - isLoading: mockServiceState.isLoading, - isFetchingNextPage: mockServiceState.isFetchingNextPage, - fetchNextPage: mockFetchNextPage, - hasNextPage: mockServiceState.hasNextPage, - error: mockServiceState.error, - refetch: mockRefetch, - }), + useInfiniteAppList: (params: unknown, options: unknown) => { + mockUseInfiniteAppList(params, options) + return { + data: defaultAppData, + isLoading: mockServiceState.isLoading, + isFetching: mockServiceState.isFetching, + isFetchingNextPage: mockServiceState.isFetchingNextPage, + fetchNextPage: mockFetchNextPage, + hasNextPage: mockServiceState.hasNextPage, + error: mockServiceState.error, + refetch: mockRefetch, + } + }, useDeleteAppMutation: () => ({ mutateAsync: vi.fn(), isPending: false, }), })) +const mockSnippetServiceState = { + error: null as Error | null, + hasNextPage: false, + isLoading: false, + isFetching: false, + isFetchingNextPage: false, +} + +const defaultSnippetData = { + pages: [{ + data: [ + { + id: 'snippet-1', + name: 'Tone Rewriter', + description: 'Rewrites rough drafts into a concise, professional tone for internal stakeholder updates.', + type: 'node', + is_published: false, + use_count: 19, + icon_info: { + icon_type: 'emoji', + icon: '🪄', + icon_background: '#E0EAFF', + icon_url: '', + }, + created_at: 1704067200, + updated_at: '2024-01-02 10:00', + author: '', + }, + ], + total: 1, + }], +} + +vi.mock('@/service/use-snippets', () => ({ + useInfiniteSnippetList: (params: unknown, options: unknown) => { + mockUseInfiniteSnippetList(params, options) + return { + data: defaultSnippetData, + isLoading: mockSnippetServiceState.isLoading, + isFetching: mockSnippetServiceState.isFetching, + isFetchingNextPage: mockSnippetServiceState.isFetchingNextPage, + fetchNextPage: mockFetchSnippetNextPage, + hasNextPage: mockSnippetServiceState.hasNextPage, + error: mockSnippetServiceState.error, + } + }, + useCreateSnippetMutation: () => ({ + mutate: vi.fn(), + isPending: false, + }), + useImportSnippetDSLMutation: () => ({ + mutate: vi.fn(), + isPending: false, + }), + useConfirmSnippetImportMutation: () => ({ + mutate: vi.fn(), + isPending: false, + }), +})) + vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([{ id: 'tag-1', name: 'Test Tag', type: 'app' }]), })) @@ -124,6 +206,17 @@ vi.mock('@/config', async (importOriginal) => { } }) +vi.mock('@/service/use-common', () => ({ + useMembers: () => ({ + data: { + accounts: [ + { id: 'user-1', name: 'Current User', email: 'current@example.com', avatar: '', avatar_url: '', role: 'owner', last_login_at: '', created_at: '', status: 'active' }, + { id: 'user-2', name: 'Alice', email: 'alice@example.com', avatar: '', avatar_url: '', role: 'admin', last_login_at: '', created_at: '', status: 'active' }, + ], + }, + }), +})) + vi.mock('@/hooks/use-pay', () => ({ CheckModal: () => null, })) @@ -137,13 +230,21 @@ vi.mock('@/next/dynamic', () => ({ return React.createElement('div', { 'data-testid': 'tag-management-modal' }) } } + if (fnString.includes('create-from-dsl-modal')) { return function MockCreateFromDSLModal({ show, onClose, onSuccess }: { show: boolean, onClose: () => void, onSuccess: () => void }) { if (!show) return null - return React.createElement('div', { 'data-testid': 'create-dsl-modal' }, React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-dsl-modal' }, 'Close'), React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'success-dsl-modal' }, 'Success')) + + return React.createElement( + 'div', + { 'data-testid': 'create-dsl-modal' }, + React.createElement('button', { 'data-testid': 'close-dsl-modal', 'onClick': onClose }, 'Close'), + React.createElement('button', { 'data-testid': 'success-dsl-modal', 'onClick': onSuccess }, 'Success'), + ) } } + return () => null }, })) @@ -161,8 +262,8 @@ vi.mock('../new-app-card', () => ({ })) vi.mock('../empty', () => ({ - default: () => { - return React.createElement('div', { 'data-testid': 'empty-state', 'role': 'status' }, 'No apps found') + default: ({ message }: { message: string }) => { + return React.createElement('div', { 'data-testid': 'empty-state', 'role': 'status' }, message) }, })) @@ -192,151 +293,105 @@ beforeAll(() => { } as unknown as typeof IntersectionObserver }) -// Render helper wrapping with shared nuqs testing helper. -const renderList = (searchParams = '') => { - return renderWithNuqs(, { searchParams }) +const renderList = (props: React.ComponentProps = {}, searchParams = '') => { + return renderWithNuqs(, { searchParams }) } describe('List', () => { beforeEach(() => { vi.clearAllMocks() + defaultSnippetData.pages[0].data = [ + { + id: 'snippet-1', + name: 'Tone Rewriter', + description: 'Rewrites rough drafts into a concise, professional tone for internal stakeholder updates.', + type: 'node', + is_published: false, + use_count: 19, + icon_info: { + icon_type: 'emoji', + icon: '🪄', + icon_background: '#E0EAFF', + icon_url: '', + }, + created_at: 1704067200, + updated_at: '2024-01-02 10:00', + author: '', + }, + ] + defaultSnippetData.pages[0].total = 1 useTagStore.setState({ tagList: [{ id: 'tag-1', name: 'Test Tag', type: 'app', binding_count: 0 }], showTagManagementModal: false, }) mockIsCurrentWorkspaceEditor.mockReturnValue(true) mockIsCurrentWorkspaceDatasetOperator.mockReturnValue(false) + mockIsLoadingCurrentWorkspace.mockReturnValue(false) + mockCanAccessSnippetsAndEvaluation.mockReturnValue(true) mockDragging = false mockOnDSLFileDropped = null mockServiceState.error = null mockServiceState.hasNextPage = false mockServiceState.isLoading = false + mockServiceState.isFetching = false mockServiceState.isFetchingNextPage = false mockQueryState.tagIDs = [] + mockQueryState.creatorIDs = [] mockQueryState.keywords = '' mockQueryState.isCreatedByMe = false + mockSnippetServiceState.error = null + mockSnippetServiceState.hasNextPage = false + mockSnippetServiceState.isLoading = false + mockSnippetServiceState.isFetching = false + mockSnippetServiceState.isFetchingNextPage = false + mockUseInfiniteAppList.mockClear() + mockUseInfiniteSnippetList.mockClear() intersectionCallback = null localStorage.clear() }) - describe('Rendering', () => { - it('should render without crashing', () => { - renderList() - expect(screen.getByText('app.types.all'))!.toBeInTheDocument() - }) - - it('should render tab slider with all app types', () => { + describe('Apps Mode', () => { + it('should render the apps route switch, dropdown filters, and app cards', () => { renderList() - expect(screen.getByText('app.types.all'))!.toBeInTheDocument() - expect(screen.getByText('app.types.workflow'))!.toBeInTheDocument() - expect(screen.getByText('app.types.advanced'))!.toBeInTheDocument() - expect(screen.getByText('app.types.chatbot'))!.toBeInTheDocument() - expect(screen.getByText('app.types.agent'))!.toBeInTheDocument() - expect(screen.getByText('app.types.completion'))!.toBeInTheDocument() + expect(screen.getByRole('link', { name: 'app.studio.apps' })).toHaveAttribute('href', '/apps') + expect(screen.getByRole('link', { name: 'workflow.tabs.snippets' })).toHaveAttribute('href', '/snippets') + expect(screen.getByText('app.studio.filters.types')).toBeInTheDocument() + expect(screen.getByText('app.studio.filters.allCreators')).toBeInTheDocument() + expect(screen.getByText('common.tag.placeholder')).toBeInTheDocument() + expect(screen.getByTestId('app-card-app-1')).toBeInTheDocument() + expect(screen.getByTestId('app-card-app-2')).toBeInTheDocument() + expect(screen.getByTestId('new-app-card')).toBeInTheDocument() }) - it('should render search input', () => { - renderList() - expect(screen.getByRole('textbox'))!.toBeInTheDocument() - }) - - it('should render tag filter', () => { - renderList() - expect(screen.getByText('common.tag.placeholder'))!.toBeInTheDocument() - }) - - it('should render created by me checkbox', () => { - renderList() - expect(screen.getByText('app.showMyCreatedAppsOnly'))!.toBeInTheDocument() - }) - - it('should render app cards when apps exist', () => { - renderList() - - expect(screen.getByTestId('app-card-app-1'))!.toBeInTheDocument() - expect(screen.getByTestId('app-card-app-2'))!.toBeInTheDocument() - }) - - it('should render new app card for editors', () => { - renderList() - expect(screen.getByTestId('new-app-card'))!.toBeInTheDocument() - }) - - it('should render footer when branding is disabled', () => { - renderList() - expect(screen.getByTestId('footer'))!.toBeInTheDocument() - }) - - it('should render drop DSL hint for editors', () => { - renderList() - expect(screen.getByText('app.newApp.dropDSLToCreateApp'))!.toBeInTheDocument() - }) - }) - - describe('Tab Navigation', () => { - it('should update URL when workflow tab is clicked', async () => { + it('should update the category query when selecting an app type from the dropdown', async () => { const { onUrlUpdate } = renderList() - fireEvent.click(screen.getByText('app.types.workflow')) + fireEvent.click(screen.getByText('app.studio.filters.types')) + fireEvent.click(await screen.findByText('app.types.workflow')) - await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] expect(lastCall.searchParams.get('category')).toBe(AppModeEnum.WORKFLOW) }) - it('should update URL when all tab is clicked', async () => { - const { onUrlUpdate } = renderList('?category=workflow') - - fireEvent.click(screen.getByText('app.types.all')) - - await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] - // nuqs removes the default value ('all') from URL params - expect(lastCall.searchParams.has('category')).toBe(false) - }) - }) - - describe('Search Functionality', () => { - it('should render search input field', () => { - renderList() - expect(screen.getByRole('textbox'))!.toBeInTheDocument() - }) - - it('should handle search input change', () => { + it('should update creatorIDs when selecting a creator from the dropdown', async () => { renderList() - const input = screen.getByRole('textbox') - fireEvent.change(input, { target: { value: 'test search' } }) + fireEvent.click(screen.getByText('app.studio.filters.allCreators')) + fireEvent.click(await screen.findByText('Current User')) - expect(mockSetQuery).toHaveBeenCalled() + expect(mockSetQuery).toHaveBeenCalledTimes(1) }) - it('should handle search clear button click', () => { - mockQueryState.keywords = 'existing search' + it('should pass creator_id to the app list query when creatorIDs are selected', () => { + mockQueryState.creatorIDs = ['user-1', 'user-2'] renderList() - const clearButton = document.querySelector('.group') - expect(clearButton)!.toBeInTheDocument() - if (clearButton) - fireEvent.click(clearButton) - - expect(mockSetQuery).toHaveBeenCalled() - }) - }) - - describe('Tag Filter', () => { - it('should render tag filter component', () => { - renderList() - expect(screen.getByText('common.tag.placeholder'))!.toBeInTheDocument() - }) - }) - - describe('Created By Me Filter', () => { - it('should render checkbox with correct label', () => { - renderList() - expect(screen.getByText('app.showMyCreatedAppsOnly'))!.toBeInTheDocument() + expect(mockUseInfiniteAppList).toHaveBeenCalledWith(expect.objectContaining({ + creator_id: 'user-1,user-2', + }), expect.any(Object)) }) it('should handle checkbox change', () => { @@ -391,39 +446,39 @@ describe('List', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { const { unmount } = renderWithNuqs() - expect(screen.getByText('app.types.all'))!.toBeInTheDocument() + expect(screen.getByText('app.types.all')).toBeInTheDocument() unmount() renderList() - expect(screen.getByText('app.types.all'))!.toBeInTheDocument() + expect(screen.getByText('app.types.all')).toBeInTheDocument() }) it('should render app cards correctly', () => { renderList() - expect(screen.getByText('Test App 1'))!.toBeInTheDocument() - expect(screen.getByText('Test App 2'))!.toBeInTheDocument() + expect(screen.getByText('Test App 1')).toBeInTheDocument() + expect(screen.getByText('Test App 2')).toBeInTheDocument() }) it('should render with all filter options visible', () => { renderList() - expect(screen.getByRole('textbox'))!.toBeInTheDocument() - expect(screen.getByText('common.tag.placeholder'))!.toBeInTheDocument() - expect(screen.getByText('app.showMyCreatedAppsOnly'))!.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByText('common.tag.placeholder')).toBeInTheDocument() + expect(screen.getByText('app.showMyCreatedAppsOnly')).toBeInTheDocument() }) }) describe('Dragging State', () => { it('should show drop hint when DSL feature is enabled for editors', () => { renderList() - expect(screen.getByText('app.newApp.dropDSLToCreateApp'))!.toBeInTheDocument() + expect(screen.getByText('app.newApp.dropDSLToCreateApp')).toBeInTheDocument() }) it('should render dragging state overlay when dragging', () => { mockDragging = true const { container } = renderList() - expect(container)!.toBeInTheDocument() + expect(container).toBeInTheDocument() }) }) @@ -431,12 +486,12 @@ describe('List', () => { it('should render all app type tabs', () => { renderList() - expect(screen.getByText('app.types.all'))!.toBeInTheDocument() - expect(screen.getByText('app.types.workflow'))!.toBeInTheDocument() - expect(screen.getByText('app.types.advanced'))!.toBeInTheDocument() - expect(screen.getByText('app.types.chatbot'))!.toBeInTheDocument() - expect(screen.getByText('app.types.agent'))!.toBeInTheDocument() - expect(screen.getByText('app.types.completion'))!.toBeInTheDocument() + expect(screen.getByText('app.types.all')).toBeInTheDocument() + expect(screen.getByText('app.types.workflow')).toBeInTheDocument() + expect(screen.getByText('app.types.advanced')).toBeInTheDocument() + expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() + expect(screen.getByText('app.types.agent')).toBeInTheDocument() + expect(screen.getByText('app.types.completion')).toBeInTheDocument() }) it('should update URL for each app type tab click', async () => { @@ -454,7 +509,7 @@ describe('List', () => { onUrlUpdate.mockClear() fireEvent.click(screen.getByText(text)) await vi.waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0] + const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] expect(lastCall.searchParams.get('category')).toBe(mode) } }) @@ -464,22 +519,22 @@ describe('List', () => { it('should display all app cards from data', () => { renderList() - expect(screen.getByTestId('app-card-app-1'))!.toBeInTheDocument() - expect(screen.getByTestId('app-card-app-2'))!.toBeInTheDocument() + expect(screen.getByTestId('app-card-app-1')).toBeInTheDocument() + expect(screen.getByTestId('app-card-app-2')).toBeInTheDocument() }) it('should display app names correctly', () => { renderList() - expect(screen.getByText('Test App 1'))!.toBeInTheDocument() - expect(screen.getByText('Test App 2'))!.toBeInTheDocument() + expect(screen.getByText('Test App 1')).toBeInTheDocument() + expect(screen.getByText('Test App 2')).toBeInTheDocument() }) }) describe('Footer Visibility', () => { it('should render footer when branding is disabled', () => { renderList() - expect(screen.getByTestId('footer'))!.toBeInTheDocument() + expect(screen.getByTestId('footer')).toBeInTheDocument() }) }) @@ -493,99 +548,79 @@ describe('List', () => { mockOnDSLFileDropped(mockFile) }) - expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() - }) - - it('should close DSL modal when onClose is called', () => { - renderList() - - const mockFile = new File(['test content'], 'test.yml', { type: 'application/yaml' }) - act(() => { - if (mockOnDSLFileDropped) - mockOnDSLFileDropped(mockFile) - }) - - expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() - + expect(screen.getByTestId('create-dsl-modal')).toBeInTheDocument() fireEvent.click(screen.getByTestId('close-dsl-modal')) - expect(screen.queryByTestId('create-dsl-modal')).not.toBeInTheDocument() }) - it('should close DSL modal and refetch when onSuccess is called', () => { + it('should hide the snippets route switch when snippet access is unavailable', () => { + mockCanAccessSnippetsAndEvaluation.mockReturnValue(false) + renderList() - const mockFile = new File(['test content'], 'test.yml', { type: 'application/yaml' }) + expect(screen.getByRole('link', { name: 'app.studio.apps' })).toHaveAttribute('href', '/apps') + expect(screen.queryByRole('link', { name: 'workflow.tabs.snippets' })).not.toBeInTheDocument() + }) + }) + + describe('Snippets Mode', () => { + it('should render the snippets create card and snippet card from the real query hook', () => { + renderList({ pageType: 'snippets' }) + + expect(screen.getByText('snippet.create')).toBeInTheDocument() + expect(screen.getByText('Tone Rewriter')).toBeInTheDocument() + expect(screen.getByText('Rewrites rough drafts into a concise, professional tone for internal stakeholder updates.')).toBeInTheDocument() + expect(screen.getByRole('link', { name: /Tone Rewriter/i })).toHaveAttribute('href', '/snippets/snippet-1/orchestrate') + expect(screen.queryByTestId('new-app-card')).not.toBeInTheDocument() + expect(screen.queryByTestId('app-card-app-1')).not.toBeInTheDocument() + }) + + it('should request the next snippet page when the infinite-scroll anchor intersects', () => { + mockSnippetServiceState.hasNextPage = true + renderList({ pageType: 'snippets' }) + act(() => { - if (mockOnDSLFileDropped) - mockOnDSLFileDropped(mockFile) + intersectionCallback?.([{ isIntersecting: true } as IntersectionObserverEntry], {} as IntersectionObserver) }) - expect(screen.getByTestId('create-dsl-modal'))!.toBeInTheDocument() - - fireEvent.click(screen.getByTestId('success-dsl-modal')) - - expect(screen.queryByTestId('create-dsl-modal')).not.toBeInTheDocument() - expect(mockRefetch).toHaveBeenCalled() - }) - }) - - describe('Infinite Scroll', () => { - it('should call fetchNextPage when intersection observer triggers', () => { - mockServiceState.hasNextPage = true - renderList() - - if (intersectionCallback) { - act(() => { - intersectionCallback!( - [{ isIntersecting: true } as IntersectionObserverEntry], - {} as IntersectionObserver, - ) - }) - } - - expect(mockFetchNextPage).toHaveBeenCalled() + expect(mockFetchSnippetNextPage).toHaveBeenCalled() }) - it('should not call fetchNextPage when not intersecting', () => { - mockServiceState.hasNextPage = true - renderList() + it('should not render app-only controls in snippets mode', () => { + renderList({ pageType: 'snippets' }) - if (intersectionCallback) { - act(() => { - intersectionCallback!( - [{ isIntersecting: false } as IntersectionObserverEntry], - {} as IntersectionObserver, - ) - }) - } - - expect(mockFetchNextPage).not.toHaveBeenCalled() + expect(screen.queryByText('app.studio.filters.types')).not.toBeInTheDocument() + expect(screen.queryByText('common.tag.placeholder')).not.toBeInTheDocument() + expect(screen.queryByText('app.newApp.dropDSLToCreateApp')).not.toBeInTheDocument() }) - it('should not call fetchNextPage when loading', () => { - mockServiceState.hasNextPage = true - mockServiceState.isLoading = true - renderList() + it('should pass creator_id to the snippet list query when creatorIDs are selected', () => { + mockQueryState.creatorIDs = ['user-1', 'user-2'] - if (intersectionCallback) { - act(() => { - intersectionCallback!( - [{ isIntersecting: true } as IntersectionObserverEntry], - {} as IntersectionObserver, - ) - }) - } + renderList({ pageType: 'snippets' }) - expect(mockFetchNextPage).not.toHaveBeenCalled() + expect(mockUseInfiniteSnippetList).toHaveBeenCalledWith(expect.objectContaining({ + creator_id: 'user-1,user-2', + }), expect.any(Object)) }) - }) - describe('Error State', () => { - it('should handle error state in useEffect', () => { - mockServiceState.error = new Error('Test error') - const { container } = renderList() - expect(container)!.toBeInTheDocument() + it('should not fetch the next snippet page when no more data is available', () => { + renderList({ pageType: 'snippets' }) + + act(() => { + intersectionCallback?.([{ isIntersecting: true } as IntersectionObserverEntry], {} as IntersectionObserver) + }) + + expect(mockFetchSnippetNextPage).not.toHaveBeenCalled() + }) + + it('should reuse the shared empty state when no snippets are available', () => { + defaultSnippetData.pages[0].data = [] + defaultSnippetData.pages[0].total = 0 + + renderList({ pageType: 'snippets' }) + + expect(screen.getByTestId('empty-state')).toHaveTextContent('workflow.tabs.noSnippetsFound') }) }) }) diff --git a/web/app/components/apps/app-type-filter-shared.ts b/web/app/components/apps/app-type-filter-shared.ts new file mode 100644 index 0000000000..26b279ae2f --- /dev/null +++ b/web/app/components/apps/app-type-filter-shared.ts @@ -0,0 +1,16 @@ +import { parseAsStringLiteral } from 'nuqs' +import { AppModes } from '@/types/app' + +const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const +type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] +export type { AppListCategory } + +const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) + +export const isAppListCategory = (value: string): value is AppListCategory => { + return appListCategorySet.has(value) +} + +export const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) + .withDefault('all') + .withOptions({ history: 'push' }) diff --git a/web/app/components/apps/app-type-filter.tsx b/web/app/components/apps/app-type-filter.tsx new file mode 100644 index 0000000000..a1401100ae --- /dev/null +++ b/web/app/components/apps/app-type-filter.tsx @@ -0,0 +1,72 @@ +'use client' + +import type { AppListCategory } from './app-type-filter-shared' +import { cn } from '@langgenius/dify-ui/cn' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuRadioItemIndicator, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' +import { AppModeEnum } from '@/types/app' +import { isAppListCategory } from './app-type-filter-shared' + +const chipClassName = 'flex h-8 items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2 text-[13px] leading-[18px] text-text-secondary hover:bg-components-input-bg-hover' + +type AppTypeFilterProps = { + activeTab: AppListCategory + onChange: (value: AppListCategory) => void +} + +const AppTypeFilter = ({ + activeTab, + onChange, +}: AppTypeFilterProps) => { + const { t } = useTranslation() + + const options = useMemo(() => ([ + { value: 'all', text: t('types.all', { ns: 'app' }), iconClassName: 'i-ri-apps-2-line' }, + { value: AppModeEnum.WORKFLOW, text: t('types.workflow', { ns: 'app' }), iconClassName: 'i-ri-exchange-2-line' }, + { value: AppModeEnum.ADVANCED_CHAT, text: t('types.advanced', { ns: 'app' }), iconClassName: 'i-ri-message-3-line' }, + { value: AppModeEnum.CHAT, text: t('types.chatbot', { ns: 'app' }), iconClassName: 'i-ri-message-3-line' }, + { value: AppModeEnum.AGENT_CHAT, text: t('types.agent', { ns: 'app' }), iconClassName: 'i-ri-robot-3-line' }, + { value: AppModeEnum.COMPLETION, text: t('types.completion', { ns: 'app' }), iconClassName: 'i-ri-file-4-line' }, + ]), [t]) + + const activeOption = options.find(option => option.value === activeTab) + const triggerLabel = activeTab === 'all' ? t('studio.filters.types', { ns: 'app' }) : activeOption?.text + + return ( + + + )} + > + + {triggerLabel} + + + + isAppListCategory(value) && onChange(value)}> + {options.map(option => ( + + + {option.text} + + + ))} + + + + ) +} + +export default AppTypeFilter diff --git a/web/app/components/apps/creators-filter.tsx b/web/app/components/apps/creators-filter.tsx new file mode 100644 index 0000000000..9a00ccab6f --- /dev/null +++ b/web/app/components/apps/creators-filter.tsx @@ -0,0 +1,219 @@ +'use client' + +import { cn } from '@langgenius/dify-ui/cn' +import { useCallback, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' +import Input from '@/app/components/base/input' +import { Avatar } from '@/app/components/base/ui/avatar' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' +import { useAppContext } from '@/context/app-context' +import { useMembers } from '@/service/use-common' + +type CreatorsFilterProps = { + value: string[] + onChange: (value: string[]) => void +} + +type CreatorOption = { + id: string + name: string + avatarUrl: string | null + isYou: boolean +} + +const baseChipClassName = 'flex h-8 items-center rounded-lg border-[0.5px] px-2 text-[13px] leading-4 transition-colors' + +const CreatorsFilter = ({ + value, + onChange, +}: CreatorsFilterProps) => { + const { t } = useTranslation() + const { userProfile } = useAppContext() + const { data: membersData } = useMembers() + const [keywords, setKeywords] = useState('') + + const creatorOptions = useMemo(() => { + const currentUserId = userProfile?.id + const members = membersData?.accounts ?? [] + + return [...members] + .filter(member => member.status !== 'pending') + .sort((left, right) => { + if (left.id === currentUserId) + return -1 + if (right.id === currentUserId) + return 1 + return left.name.localeCompare(right.name) + }) + .map(member => ({ + id: member.id, + name: member.name, + avatarUrl: member.avatar_url, + isYou: member.id === currentUserId, + })) + }, [membersData?.accounts, userProfile?.id]) + + const filteredCreators = useMemo(() => { + const normalizedKeywords = keywords.trim().toLowerCase() + if (!normalizedKeywords) + return creatorOptions + + return creatorOptions.filter((creator) => { + const keyword = normalizedKeywords + return creator.name.toLowerCase().includes(keyword) + }) + }, [creatorOptions, keywords]) + + const selectedCreators = useMemo(() => { + const creatorMap = new Map(creatorOptions.map(creator => [creator.id, creator])) + return value + .map(id => creatorMap.get(id)) + .filter((creator): creator is CreatorOption => Boolean(creator)) + }, [creatorOptions, value]) + + const toggleCreator = useCallback((creatorId: string) => { + if (value.includes(creatorId)) { + onChange(value.filter(id => id !== creatorId)) + return + } + + onChange([...value, creatorId]) + }, [onChange, value]) + + const resetCreators = useCallback(() => { + onChange([]) + setKeywords('') + }, [onChange]) + + const selectedCount = value.length + const selectedAvatarCreators = selectedCreators.slice(0, 3) + const isSelected = selectedCount > 0 + + return ( + + + )} + > + + {!isSelected && ( + <> + {t('studio.filters.allCreators', { ns: 'app' })} + + + )} + {isSelected && ( + <> + {t('studio.filters.creators', { ns: 'app' })} + + {selectedAvatarCreators.map((creator, index) => ( + 0 && '-ml-1', + )} + /> + ))} + + {`+${selectedCount}`} + { + event.stopPropagation() + resetCreators() + }} + onKeyDown={(event) => { + if (event.key !== 'Enter' && event.key !== ' ') + return + + event.preventDefault() + event.stopPropagation() + resetCreators() + }} + > + + + + )} + + +
+ setKeywords(e.target.value)} + onClear={() => setKeywords('')} + placeholder={t('studio.filters.searchCreators', { ns: 'app' })} + /> + {isSelected && ( + + )} +
+
+ {filteredCreators.map((creator) => { + const checked = value.includes(creator.id) + + return ( + + ) + })} +
+
+
+ ) +} + +export default CreatorsFilter diff --git a/web/app/components/apps/empty.tsx b/web/app/components/apps/empty.tsx index 0dee3c908a..0876101d79 100644 --- a/web/app/components/apps/empty.tsx +++ b/web/app/components/apps/empty.tsx @@ -1,5 +1,4 @@ import * as React from 'react' -import { useTranslation } from 'react-i18next' const DefaultCards = React.memo(() => { const renderArray = Array.from({ length: 36 }) @@ -17,15 +16,17 @@ const DefaultCards = React.memo(() => { ) }) -const Empty = () => { - const { t } = useTranslation() +type Props = { + message: string +} +const Empty = ({ message }: Props) => { return ( <>
- {t('newApp.noAppsFound', { ns: 'app' })} + {message}
diff --git a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx index 4b0c63f580..d5734cce07 100644 --- a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx +++ b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx @@ -23,6 +23,7 @@ describe('useAppsQueryState', () => { const { result } = renderWithAdapter() expect(result.current.query.tagIDs).toBeUndefined() + expect(result.current.query.creatorIDs).toBeUndefined() expect(result.current.query.keywords).toBeUndefined() expect(result.current.query.isCreatedByMe).toBe(false) }) @@ -41,6 +42,12 @@ describe('useAppsQueryState', () => { expect(result.current.query.keywords).toBe('search term') }) + it('should parse creatorIDs when URL includes creatorIDs', () => { + const { result } = renderWithAdapter('?creatorIDs=user-1;user-2') + + expect(result.current.query.creatorIDs).toEqual(['user-1', 'user-2']) + }) + it('should parse isCreatedByMe when URL includes true value', () => { const { result } = renderWithAdapter('?isCreatedByMe=true') @@ -49,10 +56,11 @@ describe('useAppsQueryState', () => { it('should parse all params when URL includes multiple filters', () => { const { result } = renderWithAdapter( - '?tagIDs=tag1;tag2&keywords=test&isCreatedByMe=true', + '?tagIDs=tag1;tag2&creatorIDs=user-1;user-2&keywords=test&isCreatedByMe=true', ) expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) + expect(result.current.query.creatorIDs).toEqual(['user-1', 'user-2']) expect(result.current.query.keywords).toBe('test') expect(result.current.query.isCreatedByMe).toBe(true) }) @@ -79,6 +87,16 @@ describe('useAppsQueryState', () => { expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) }) + it('should update creatorIDs when setQuery receives creatorIDs', () => { + const { result } = renderWithAdapter() + + act(() => { + result.current.setQuery({ creatorIDs: ['user-1', 'user-2'] }) + }) + + expect(result.current.query.creatorIDs).toEqual(['user-1', 'user-2']) + }) + it('should update isCreatedByMe when setQuery receives true', () => { const { result } = renderWithAdapter() @@ -131,6 +149,18 @@ describe('useAppsQueryState', () => { expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2') }) + it('should sync creatorIDs to URL when creatorIDs change', async () => { + const { result, onUrlUpdate } = renderWithAdapter() + + act(() => { + result.current.setQuery({ creatorIDs: ['user-1', 'user-2'] }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.get('creatorIDs')).toBe('user-1;user-2') + }) + it('should sync isCreatedByMe to URL when enabled', async () => { const { result, onUrlUpdate } = renderWithAdapter() @@ -167,6 +197,18 @@ describe('useAppsQueryState', () => { expect(update.searchParams.has('tagIDs')).toBe(false) }) + it('should remove creatorIDs from URL when creatorIDs are empty', async () => { + const { result, onUrlUpdate } = renderWithAdapter('?creatorIDs=user-1;user-2') + + act(() => { + result.current.setQuery({ creatorIDs: [] }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.searchParams.has('creatorIDs')).toBe(false) + }) + it('should remove isCreatedByMe from URL when disabled', async () => { const { result, onUrlUpdate } = renderWithAdapter('?isCreatedByMe=true') @@ -212,12 +254,17 @@ describe('useAppsQueryState', () => { result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] })) }) + act(() => { + result.current.setQuery(prev => ({ ...prev, creatorIDs: ['user-1'] })) + }) + act(() => { result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true })) }) expect(result.current.query.keywords).toBe('first') expect(result.current.query.tagIDs).toEqual(['tag1']) + expect(result.current.query.creatorIDs).toEqual(['user-1']) expect(result.current.query.isCreatedByMe).toBe(true) }) }) diff --git a/web/app/components/apps/hooks/use-apps-query-state.ts b/web/app/components/apps/hooks/use-apps-query-state.ts index ecf7707e8a..50ae13a425 100644 --- a/web/app/components/apps/hooks/use-apps-query-state.ts +++ b/web/app/components/apps/hooks/use-apps-query-state.ts @@ -3,6 +3,7 @@ import { useCallback, useMemo } from 'react' type AppsQuery = { tagIDs?: string[] + creatorIDs?: string[] keywords?: string isCreatedByMe?: boolean } @@ -13,6 +14,7 @@ function useAppsQueryState() { const [urlQuery, setUrlQuery] = useQueryStates( { tagIDs: parseAsArrayOf(parseAsString, ';'), + creatorIDs: parseAsArrayOf(parseAsString, ';'), keywords: parseAsString, isCreatedByMe: parseAsBoolean, }, @@ -23,15 +25,18 @@ function useAppsQueryState() { const query = useMemo(() => ({ tagIDs: urlQuery.tagIDs ?? undefined, + creatorIDs: urlQuery.creatorIDs ?? undefined, keywords: normalizeKeywords(urlQuery.keywords), isCreatedByMe: urlQuery.isCreatedByMe ?? false, - }), [urlQuery.isCreatedByMe, urlQuery.keywords, urlQuery.tagIDs]) + }), [urlQuery.creatorIDs, urlQuery.isCreatedByMe, urlQuery.keywords, urlQuery.tagIDs]) const setQuery = useCallback((next: AppsQuery | ((prev: AppsQuery) => AppsQuery)) => { const buildPatch = (patch: AppsQuery) => { const result: Partial = {} if ('tagIDs' in patch) result.tagIDs = patch.tagIDs && patch.tagIDs.length > 0 ? patch.tagIDs : null + if ('creatorIDs' in patch) + result.creatorIDs = patch.creatorIDs && patch.creatorIDs.length > 0 ? patch.creatorIDs : null if ('keywords' in patch) result.keywords = patch.keywords ? patch.keywords : null if ('isCreatedByMe' in patch) @@ -42,6 +47,7 @@ function useAppsQueryState() { if (typeof next === 'function') { setUrlQuery(prev => buildPatch(next({ tagIDs: prev.tagIDs ?? undefined, + creatorIDs: prev.creatorIDs ?? undefined, keywords: normalizeKeywords(prev.keywords), isCreatedByMe: prev.isCreatedByMe ?? false, }))) diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index 9bf07e81e6..9f23e42bb9 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -13,14 +13,24 @@ import { fetchAppDetail } from '@/service/explore' import { trackCreateApp } from '@/utils/create-app-tracking' import List from './list' +export type StudioPageType = 'apps' | 'snippets' + +type AppsProps = { + pageType?: StudioPageType +} + const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) -const Apps = () => { +const Apps = ({ + pageType = 'apps', +}: AppsProps) => { const { t } = useTranslation() - useDocumentTitle(t('menus.apps', { ns: 'common' })) + useDocumentTitle(pageType === 'apps' + ? t('menus.apps', { ns: 'common' }) + : t('tabs.snippets', { ns: 'workflow' })) useEducationInit() const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) @@ -116,7 +126,7 @@ const Apps = () => { }} >
- + {isShowTryAppPanel && ( import('@/app/components/base/tag-management'), { ssr: false, @@ -35,25 +43,18 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) -const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const -type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number] -const appListCategorySet = new Set(APP_LIST_CATEGORY_VALUES) - -const isAppListCategory = (value: string): value is AppListCategory => { - return appListCategorySet.has(value) -} - -const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) - .withDefault('all') - .withOptions({ history: 'push' }) - type Props = { controlRefreshList?: number + pageType?: StudioPageType } + const List: FC = ({ controlRefreshList = 0, + pageType = 'apps', }) => { const { t } = useTranslation() + const isAppsPage = pageType === 'apps' + const { canAccess: canAccessSnippetsAndEvaluation } = useSnippetAndEvaluationPlanAccess() const { systemFeatures } = useGlobalPublicStore() const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) @@ -62,20 +63,28 @@ const List: FC = ({ parseAsAppListCategory, ) - const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() - const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe) + const { query: { tagIDs = [], creatorIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() const [tagFilterValue, setTagFilterValue] = useState(tagIDs) - const [searchKeywords, setSearchKeywords] = useState(keywords) - const newAppCardRef = useRef(null) - const containerRef = useRef(null) + const [appKeywords, setAppKeywords] = useState(keywords) + const [snippetKeywordsInput, setSnippetKeywordsInput] = useState('') + const [snippetKeywords, setSnippetKeywords] = useState('') const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) const [droppedDSLFile, setDroppedDSLFile] = useState() + const containerRef = useRef(null) + const anchorRef = useRef(null) + const newAppCardRef = useRef(null) + const [workflowOnlineUsersMap, setWorkflowOnlineUsersMap] = useState>({}) const setKeywords = useCallback((keywords: string) => { setQuery(prev => ({ ...prev, keywords })) }, [setQuery]) - const setTagIDs = useCallback((tagIDs: string[]) => { - setQuery(prev => ({ ...prev, tagIDs })) + + const setTagIDs = useCallback((nextTagIDs: string[]) => { + setQuery(prev => ({ ...prev, tagIDs: nextTagIDs })) + }, [setQuery]) + + const setCreatorIDs = useCallback((nextCreatorIDs: string[]) => { + setQuery(prev => ({ ...prev, creatorIDs: nextCreatorIDs })) }, [setQuery]) const handleDSLFileDropped = useCallback((file: File) => { @@ -86,15 +95,16 @@ const List: FC = ({ const { dragging } = useDSLDragDrop({ onDSLFileDropped: handleDSLFileDropped, containerRef, - enabled: isCurrentWorkspaceEditor, + enabled: isAppsPage && isCurrentWorkspaceEditor, }) const appListQueryParams = { page: 1, limit: 30, - name: searchKeywords, + name: appKeywords, tag_ids: tagIDs, - is_created_by_me: isCreatedByMe, + is_created_by_me: queryIsCreatedByMe, + ...(creatorIDs.length > 0 ? { creator_id: creatorIDs.join(',') } : {}), ...(activeTab !== 'all' ? { mode: activeTab } : {}), } @@ -107,84 +117,125 @@ const List: FC = ({ hasNextPage, error, refetch, - } = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator }) + } = useInfiniteAppList(appListQueryParams, { + enabled: isAppsPage && !isCurrentWorkspaceDatasetOperator, + }) + + const { + data: snippetData, + isLoading: isSnippetListLoading, + isFetching: isSnippetListFetching, + isFetchingNextPage: isSnippetListFetchingNextPage, + fetchNextPage: fetchSnippetNextPage, + hasNextPage: hasSnippetNextPage, + error: snippetError, + } = useInfiniteSnippetList({ + page: 1, + limit: 30, + keyword: snippetKeywords || undefined, + creator_id: creatorIDs.length > 0 ? creatorIDs.join(',') : undefined, + }, { + enabled: !isAppsPage, + }) useEffect(() => { - if (controlRefreshList > 0) { + if (isAppsPage && controlRefreshList > 0) refetch() - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [controlRefreshList]) - - const anchorRef = useRef(null) - const options = [ - { value: 'all', text: t('types.all', { ns: 'app' }), icon: }, - { value: AppModeEnum.WORKFLOW, text: t('types.workflow', { ns: 'app' }), icon: }, - { value: AppModeEnum.ADVANCED_CHAT, text: t('types.advanced', { ns: 'app' }), icon: }, - { value: AppModeEnum.CHAT, text: t('types.chatbot', { ns: 'app' }), icon: }, - { value: AppModeEnum.AGENT_CHAT, text: t('types.agent', { ns: 'app' }), icon: }, - { value: AppModeEnum.COMPLETION, text: t('types.completion', { ns: 'app' }), icon: }, - ] + }, [controlRefreshList, isAppsPage, refetch]) useEffect(() => { + if (!isAppsPage) + return + if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') { localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY) refetch() } - }, [refetch]) + }, [isAppsPage, refetch]) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) return - const hasMore = hasNextPage ?? true + + const hasMore = isAppsPage ? (hasNextPage ?? true) : (hasSnippetNextPage ?? true) + const isPageLoading = isAppsPage ? isLoading : isSnippetListLoading + const isNextPageFetching = isAppsPage ? isFetchingNextPage : isSnippetListFetchingNextPage + const currentError = isAppsPage ? error : snippetError let observer: IntersectionObserver | undefined - if (error) { - if (observer) - observer.disconnect() + if (currentError) { + observer?.disconnect() return } if (anchorRef.current && containerRef.current) { - // Calculate dynamic rootMargin: clamps to 100-200px range, using 20% of container height as the base value for better responsiveness const containerHeight = containerRef.current.clientHeight - const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) // Clamps to 100-200px range, using 20% of container height as the base value + const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) observer = new IntersectionObserver((entries) => { - if (entries[0]!.isIntersecting && !isLoading && !isFetchingNextPage && !error && hasMore) - fetchNextPage() + if (entries[0].isIntersecting && !isPageLoading && !isNextPageFetching && !currentError && hasMore) { + if (isAppsPage) + fetchNextPage() + else + fetchSnippetNextPage() + } }, { root: containerRef.current, rootMargin: `${dynamicMargin}px`, - threshold: 0.1, // Trigger when 10% of the anchor element is visible + threshold: 0.1, }) observer.observe(anchorRef.current) } + return () => observer?.disconnect() - }, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator]) + }, [error, fetchNextPage, fetchSnippetNextPage, hasNextPage, hasSnippetNextPage, isAppsPage, isCurrentWorkspaceDatasetOperator, isFetchingNextPage, isLoading, isSnippetListFetchingNextPage, isSnippetListLoading, snippetError]) - const { run: handleSearch } = useDebounceFn(() => { - setSearchKeywords(keywords) + const { run: handleAppSearch } = useDebounceFn((value: string) => { + setAppKeywords(value) }, { wait: 500 }) - const handleKeywordsChange = (value: string) => { - setKeywords(value) - handleSearch() - } - const { run: handleTagsUpdate } = useDebounceFn(() => { - setTagIDs(tagFilterValue) + const { run: handleSnippetSearch } = useDebounceFn((value: string) => { + setSnippetKeywords(value) }, { wait: 500 }) - const handleTagsChange = (value: string[]) => { + + const handleKeywordsChange = useCallback((value: string) => { + if (isAppsPage) { + setKeywords(value) + handleAppSearch(value) + return + } + + setSnippetKeywordsInput(value) + handleSnippetSearch(value) + }, [handleAppSearch, handleSnippetSearch, isAppsPage, setKeywords]) + + const { run: handleTagsUpdate } = useDebounceFn((value: string[]) => { + setTagIDs(value) + }, { wait: 500 }) + + const handleTagsChange = useCallback((value: string[]) => { setTagFilterValue(value) - handleTagsUpdate() - } + handleTagsUpdate(value) + }, [handleTagsUpdate]) - const handleCreatedByMeChange = useCallback(() => { - const newValue = !isCreatedByMe - setIsCreatedByMe(newValue) - setQuery(prev => ({ ...prev, isCreatedByMe: newValue })) - }, [isCreatedByMe, setQuery]) + const appItems = useMemo(() => { + return (data?.pages ?? []).flatMap(({ data: apps }) => apps) + }, [data?.pages]) + const snippetItems = useMemo(() => { + return (snippetData?.pages ?? []).flatMap(({ data }) => data) + }, [snippetData?.pages]) + + const showSkeleton = isAppsPage + ? (isLoading || (isFetching && data?.pages?.length === 0)) + : (isSnippetListLoading || (isSnippetListFetching && snippetItems.length === 0)) + const hasAnyApp = (data?.pages?.[0]?.total ?? 0) > 0 + const hasAnySnippet = snippetItems.length > 0 + const currentKeywords = isAppsPage ? keywords : snippetKeywordsInput + const showEmptyState = !showSkeleton && (isAppsPage ? !hasAnyApp : !hasAnySnippet) + const emptyStateMessage = isAppsPage + ? t('newApp.noAppsFound', { ns: 'app' }) + : t('tabs.noSnippetsFound', { ns: 'workflow' }) const pages = data?.pages ?? [] const appIds = useMemo(() => { const ids = new Set() @@ -233,85 +284,99 @@ const List: FC = ({ return () => window.clearInterval(timer) }, [refetch, refreshWorkflowOnlineUsers, systemFeatures.enable_collaboration_mode]) - const hasAnyApp = (pages[0]?.total ?? 0) > 0 - // Show skeleton during initial load or when refetching with no previous data - const showSkeleton = isLoading || (isFetching && pages.length === 0) - return ( <>
{dragging && ( -
-
+
)}
- { - if (isAppListCategory(nextValue)) - setActiveTab(nextValue) - }} - options={options} - /> +
+ + {isAppsPage && ( + { + void setActiveTab(value) + }} + /> + )} + + {isAppsPage && ( + + )} +
+
- - handleKeywordsChange(e.target.value)} onClear={() => handleKeywordsChange('')} />
+
{(isCurrentWorkspaceEditor || isLoadingCurrentWorkspace) && ( - + isAppsPage + ? ( + + ) + : canAccessSnippetsAndEvaluation && )} - {(() => { - if (showSkeleton) - return - if (hasAnyApp) { - return pages.flatMap(({ data: apps }) => apps).map(app => ( - - )) - } + {showSkeleton && } - // No apps - show empty state - return - })()} - {isFetchingNextPage && ( + {!showSkeleton && isAppsPage && hasAnyApp && pages.flatMap(({ data: apps }) => apps).map(app => ( + + ))} + + {!showSkeleton && !isAppsPage && hasAnySnippet && snippetItems.map(snippet => ( + + ))} + + {showEmptyState && } + + {isAppsPage && isFetchingNextPage && ( + + )} + + {!isAppsPage && isSnippetListFetchingNextPage && ( )}
- {isCurrentWorkspaceEditor && ( + {isAppsPage && isCurrentWorkspaceEditor && (
@@ -319,17 +384,18 @@ const List: FC = ({ {t('newApp.dropDSLToCreateApp', { ns: 'app' })}
)} + {!systemFeatures.branding.enabled && (
)}
- {showTagManagementModal && ( + {isAppsPage && showTagManagementModal && ( )}
- {showCreateFromDSLModal && ( + {isAppsPage && showCreateFromDSLModal && ( { diff --git a/web/app/components/apps/studio-route-switch.tsx b/web/app/components/apps/studio-route-switch.tsx new file mode 100644 index 0000000000..18235f9b74 --- /dev/null +++ b/web/app/components/apps/studio-route-switch.tsx @@ -0,0 +1,48 @@ +'use client' + +import type { StudioPageType } from '.' +import { cn } from '@langgenius/dify-ui/cn' +import Link from '@/next/link' + +type Props = { + pageType: StudioPageType + appsLabel: string + snippetsLabel: string + showSnippets?: boolean +} + +const StudioRouteSwitch = ({ + pageType, + appsLabel, + snippetsLabel, + showSnippets = true, +}: Props) => { + return ( +
+ + {appsLabel} + + {showSnippets && ( + + {snippetsLabel} + + )} +
+ ) +} + +export default StudioRouteSwitch diff --git a/web/app/components/base/audio-gallery/AudioPlayer.tsx b/web/app/components/base/audio-gallery/AudioPlayer.tsx index c3b2056698..9174b13356 100644 --- a/web/app/components/base/audio-gallery/AudioPlayer.tsx +++ b/web/app/components/base/audio-gallery/AudioPlayer.tsx @@ -95,7 +95,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { for (let i = 0; i < samples; i++) { let sum = 0 for (let j = 0; j < blockSize; j++) - sum += Math.abs(channelData[i * blockSize + j]!) + sum += Math.abs(channelData[i * blockSize + j]) // Apply nonlinear scaling to enhance small amplitudes waveformData.push((sum / blockSize) * 5) } @@ -145,7 +145,7 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { e.preventDefault() const getClientX = (event: React.MouseEvent | React.TouchEvent): number => { if ('touches' in event) - return event.touches[0]!.clientX + return event.touches[0].clientX return event.clientX } const updateProgress = (clientX: number) => { diff --git a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx index 83a8666e79..bd5f01bcda 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/chat-wrapper.spec.tsx @@ -151,8 +151,8 @@ describe('ChatWrapper', () => { render() - expect(await screen.findByText('Welcome'))!.toBeInTheDocument() - expect(await screen.findByText('Q1'))!.toBeInTheDocument() + expect(await screen.findByText('Welcome')).toBeInTheDocument() + expect(await screen.findByText('Q1')).toBeInTheDocument() fireEvent.click(screen.getByText('Q1')) expect(handleSend).toHaveBeenCalled() @@ -170,7 +170,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Default opening statement'))!.toBeInTheDocument() + expect(screen.getByText('Default opening statement')).toBeInTheDocument() }) it('should render welcome screen without suggested questions', async () => { @@ -186,7 +186,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Welcome message'))!.toBeInTheDocument() + expect(await screen.findByText('Welcome message')).toBeInTheDocument() }) it('should show responding state', async () => { @@ -197,7 +197,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Bot thinking...'))!.toBeInTheDocument() + expect(await screen.findByText('Bot thinking...')).toBeInTheDocument() }) it('should handle manual message input and stop responding', async () => { @@ -320,9 +320,9 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const disabledContainer = chatInput!.closest('.pointer-events-none') - expect(disabledContainer)!.toBeInTheDocument() - expect(disabledContainer)!.toHaveClass('opacity-50') + const disabledContainer = chatInput.closest('.pointer-events-none') + expect(disabledContainer).toBeInTheDocument() + expect(disabledContainer).toHaveClass('opacity-50') }) it('should not disable input when required field has value', () => { @@ -337,7 +337,7 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') + const container = chatInput.closest('.pointer-events-none') expect(container).not.toBeInTheDocument() }) @@ -361,8 +361,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + const container = chatInput.closest('.pointer-events-none') + expect(container).toBeInTheDocument() }) it('should not disable input when file is fully uploaded', () => { @@ -411,8 +411,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + const container = chatInput.closest('.pointer-events-none') + expect(container).toBeInTheDocument() }) it('should not disable when all files are uploaded', () => { @@ -457,7 +457,7 @@ describe('ChatWrapper', () => { render() const textarea = screen.getByRole('textbox') const container = textarea.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + expect(container).toBeInTheDocument() }) it('should not disable input when allInputsHidden is true', () => { @@ -523,7 +523,7 @@ describe('ChatWrapper', () => { render() expect(handleSwitchSibling).toHaveBeenCalledWith('resume-node', expect.any(Object)) - const resumeOptions = handleSwitchSibling.mock.calls[0]![1] + const resumeOptions = handleSwitchSibling.mock.calls[0][1] resumeOptions.onGetSuggestedQuestions('response-from-resume') expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-from-resume', 'webApp', 'test-app-id') }) @@ -619,7 +619,7 @@ describe('ChatWrapper', () => { render() - const onStopCallback = vi.mocked(useChat).mock.calls[0]![3] as (taskId: string) => void + const onStopCallback = vi.mocked(useChat).mock.calls[0][3] as (taskId: string) => void onStopCallback('taskId-123') expect(stopChatMessageResponding).toHaveBeenCalledWith('', 'taskId-123', 'webApp', 'test-app-id') }) @@ -645,7 +645,7 @@ describe('ChatWrapper', () => { expect(handleSend).toHaveBeenCalled() // Get the options passed to handleSend - const options = handleSend.mock.calls[0]![2] + const options = handleSend.mock.calls[0][2] expect(options.isPublicAPI).toBe(true) // Call onGetSuggestedQuestions @@ -679,7 +679,7 @@ describe('ChatWrapper', () => { fireEvent.click(nextButton) expect(handleSwitchSibling).toHaveBeenCalled() - const options = handleSwitchSibling.mock.calls[0]![1] + const options = handleSwitchSibling.mock.calls[0][1] options.onGetSuggestedQuestions('response-id') expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-id', 'webApp', 'test-app-id') } @@ -708,8 +708,8 @@ describe('ChatWrapper', () => { expect(handleSend).toHaveBeenCalled() const args = handleSend.mock.calls[0] // args[1] is data - expect(args![1].query).toBe('Q1') - expect(args![1].parent_message_id).toBeNull() + expect(args[1].query).toBe('Q1') + expect(args[1].parent_message_id).toBeNull() } }) @@ -737,7 +737,7 @@ describe('ChatWrapper', () => { fireEvent.click(regenerateBtn) expect(handleSend).toHaveBeenCalled() const args = handleSend.mock.calls[0] - expect(args![1].parent_message_id).toBe('a0') + expect(args[1].parent_message_id).toBe('a0') } }) @@ -774,10 +774,10 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Node 1'))!.toBeInTheDocument() + expect(await screen.findByText('Node 1')).toBeInTheDocument() const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0] - fireEvent.change(input!, { target: { value: 'test' } }) + fireEvent.change(input, { target: { value: 'test' } }) const runButton = screen.getByText('Run') fireEvent.click(runButton) @@ -817,10 +817,10 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(await screen.findByText('Node Web 1'))!.toBeInTheDocument() + expect(await screen.findByText('Node Web 1')).toBeInTheDocument() const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0] - fireEvent.change(input!, { target: { value: 'web-test' } }) + fireEvent.change(input, { target: { value: 'web-test' } }) fireEvent.click(screen.getByText('Run')) await waitFor(() => { @@ -841,7 +841,7 @@ describe('ChatWrapper', () => { render() expect(document.querySelector('.chat-answer-container')).not.toBeInTheDocument() - expect(screen.getByText('Welcome'))!.toBeInTheDocument() + expect(screen.getByText('Welcome')).toBeInTheDocument() }) it('should show all messages including opening statement when there are multiple messages', () => { @@ -861,7 +861,7 @@ describe('ChatWrapper', () => { render() const welcomeElements = screen.getAllByText('Welcome') expect(welcomeElements.length).toBeGreaterThan(0) - expect(screen.getByText('User message'))!.toBeInTheDocument() + expect(screen.getByText('User message')).toBeInTheDocument() }) it('should show chatNode and inputs form on desktop for new conversation', () => { @@ -873,7 +873,7 @@ describe('ChatWrapper', () => { }) render() - expect(screen.getByText('Test'))!.toBeInTheDocument() + expect(screen.getByText('Test')).toBeInTheDocument() }) it('should show chatNode on mobile for new conversation only', () => { @@ -885,7 +885,7 @@ describe('ChatWrapper', () => { }) const { rerender } = render() - expect(screen.getByText('Test'))!.toBeInTheDocument() + expect(screen.getByText('Test')).toBeInTheDocument() vi.mocked(useChatWithHistoryContext).mockReturnValue({ ...defaultContextValue, @@ -974,8 +974,8 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Answer'))!.toBeInTheDocument() - expect(screen.getByAltText('answer icon'))!.toBeInTheDocument() + expect(screen.getByText('Answer')).toBeInTheDocument() + expect(screen.getByAltText('answer icon')).toBeInTheDocument() }) it('should render question icon fallback when user avatar is available', () => { @@ -993,7 +993,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('J'))!.toBeInTheDocument() + expect(screen.getByText('J')).toBeInTheDocument() }) it('should use fallback values for nullable appData, appMeta and avatar name', () => { @@ -1012,8 +1012,8 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Question with fallback avatar name'))!.toBeInTheDocument() - expect(screen.getByText('U'))!.toBeInTheDocument() + expect(screen.getByText('Question with fallback avatar name')).toBeInTheDocument() + expect(screen.getByText('U')).toBeInTheDocument() }) it('should set handleStop on currentChatInstanceRef', () => { @@ -1101,8 +1101,8 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + const container = chatInput.closest('.pointer-events-none') + expect(container).toBeInTheDocument() }) it('should call formatBooleanInputs when sending message', async () => { @@ -1223,8 +1223,7 @@ describe('ChatWrapper', () => { render() // This tests line 91 - using currentConversationItem.introduction - // This tests line 91 - using currentConversationItem.introduction - expect(screen.getByText('Custom introduction from conversation item'))!.toBeInTheDocument() + expect(screen.getByText('Custom introduction from conversation item')).toBeInTheDocument() }) it('should handle early return when hasEmptyInput is already set', () => { @@ -1243,8 +1242,8 @@ describe('ChatWrapper', () => { // This tests line 106 - early return when hasEmptyInput is set const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + const container = chatInput.closest('.pointer-events-none') + expect(container).toBeInTheDocument() }) it('should handle early return when fileIsUploading is already set', () => { @@ -1271,8 +1270,8 @@ describe('ChatWrapper', () => { // This tests line 109 - early return when fileIsUploading is set const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - expect(container)!.toBeInTheDocument() + const container = chatInput.closest('.pointer-events-none') + expect(container).toBeInTheDocument() }) it('should handle doSend with no parent message id', async () => { @@ -1562,7 +1561,7 @@ describe('ChatWrapper', () => { } as unknown as ChatHookReturn) render() - expect(screen.getByText('Default opening statement'))!.toBeInTheDocument() + expect(screen.getByText('Default opening statement')).toBeInTheDocument() }) it('should handle doSend when regenerating with null parentAnswer', async () => { @@ -1610,9 +1609,7 @@ describe('ChatWrapper', () => { // Just verify the component renders - the actual editedQuestion flow // is tested through the doRegenerate callback that's passed to Chat - // Just verify the component renders - the actual editedQuestion flow - // is tested through the doRegenerate callback that's passed to Chat - expect(screen.getByText('Answer'))!.toBeInTheDocument() + expect(screen.getByText('Answer')).toBeInTheDocument() expect(handleSend).toBeDefined() }) @@ -1632,9 +1629,7 @@ describe('ChatWrapper', () => { // The doRegenerate is passed to Chat component and would be called // This ensures lines 198-200 are covered - // The doRegenerate is passed to Chat component and would be called - // This ensures lines 198-200 are covered - expect(screen.getByText('A1'))!.toBeInTheDocument() + expect(screen.getByText('A1')).toBeInTheDocument() }) it('should handle doRegenerate when question has message_files', async () => { @@ -1814,38 +1809,7 @@ describe('ChatWrapper', () => { render() const textboxes = screen.getAllByRole('textbox') const chatInput = textboxes[textboxes.length - 1] - const container = chatInput!.closest('.pointer-events-none') - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required - // Should not be disabled because it's not required + const container = chatInput.closest('.pointer-events-none') // Should not be disabled because it's not required expect(container).not.toBeInTheDocument() }) diff --git a/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx b/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx index b1c23a129b..5feaccd191 100644 --- a/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/header/__tests__/index.spec.tsx @@ -108,7 +108,7 @@ describe('Header Component', () => { currentConversationItem: mockConv, sidebarCollapseState: true, }) - expect(screen.getByText('My Chat'))!.toBeInTheDocument() + expect(screen.getByText('My Chat')).toBeInTheDocument() }) it('should render ViewFormDropdown trigger when inputsForms are present', () => { @@ -133,7 +133,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat, ResetChat (3) const resetChatBtn = buttons[buttons.length - 1] - await userEvent.click(resetChatBtn!) + await userEvent.click(resetChatBtn) expect(handleNewConversation).toHaveBeenCalled() }) @@ -144,7 +144,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') const sidebarBtn = buttons[0] - await userEvent.click(sidebarBtn!) + await userEvent.click(sidebarBtn) expect(handleSidebarCollapse).toHaveBeenCalledWith(false) }) @@ -163,7 +163,7 @@ describe('Header Component', () => { await userEvent.click(trigger) const pinBtn = await screen.findByText('explore.sidebar.action.pin') - expect(pinBtn)!.toBeInTheDocument() + expect(pinBtn).toBeInTheDocument() await userEvent.click(pinBtn) @@ -225,7 +225,7 @@ describe('Header Component', () => { const renameMenuBtn = await screen.findByText('explore.sidebar.action.rename') await userEvent.click(renameMenuBtn) - expect(await screen.findByText('common.chat.renameConversation'))!.toBeInTheDocument() + expect(await screen.findByText('common.chat.renameConversation')).toBeInTheDocument() const input = screen.getByDisplayValue('My Chat') await userEvent.clear(input) @@ -236,7 +236,7 @@ describe('Header Component', () => { expect(handleRenameConversation).toHaveBeenCalledWith('conv-1', 'New Name', expect.any(Object)) - const successCallback = handleRenameConversation.mock.calls[0]![2].onSuccess + const successCallback = handleRenameConversation.mock.calls[0][2].onSuccess await act(async () => { successCallback() }) @@ -262,14 +262,14 @@ describe('Header Component', () => { await userEvent.click(deleteMenuBtn) expect(handleDeleteConversation).not.toHaveBeenCalled() - expect(await screen.findByText('share.chat.deleteConversation.title'))!.toBeInTheDocument() + expect(await screen.findByText('share.chat.deleteConversation.title')).toBeInTheDocument() const confirmBtn = await screen.findByText('common.operation.confirm') await userEvent.click(confirmBtn) expect(handleDeleteConversation).toHaveBeenCalledWith('conv-1', expect.any(Object)) - const successCallback = handleDeleteConversation.mock.calls[0]![1].onSuccess + const successCallback = handleDeleteConversation.mock.calls[0][1].onSuccess await act(async () => { successCallback() }) @@ -311,7 +311,7 @@ describe('Header Component', () => { await userEvent.click(screen.getByText('My Chat')) await userEvent.click(await screen.findByText('explore.sidebar.action.delete')) - expect(await screen.findByText('share.chat.deleteConversation.title'))!.toBeInTheDocument() + expect(await screen.findByText('share.chat.deleteConversation.title')).toBeInTheDocument() }) }) @@ -332,7 +332,7 @@ describe('Header Component', () => { it('should render system title if conversation id is missing', () => { setup({ currentConversationId: '', sidebarCollapseState: true }) const titleEl = screen.getByText('Test App') - expect(titleEl)!.toHaveClass('system-md-semibold') + expect(titleEl).toHaveClass('system-md-semibold') }) it('should render app icon from URL when icon_url is provided', () => { @@ -347,7 +347,7 @@ describe('Header Component', () => { }, }) const img = screen.getByAltText('app icon') - expect(img)!.toHaveAttribute('src', 'https://example.com/icon.png') + expect(img).toHaveAttribute('src', 'https://example.com/icon.png') }) it('should handle undefined appData gracefully (optional chaining)', () => { @@ -364,8 +364,7 @@ describe('Header Component', () => { sidebarCollapseState: true, }) // The separator is just a div with text content '/' - // The separator is just a div with text content '/' - expect(screen.getByText('/'))!.toBeInTheDocument() + expect(screen.getByText('/')).toBeInTheDocument() }) it('should handle New Chat button state when currentConversationId is present but isResponding is true', () => { @@ -378,7 +377,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat, ResetChat (3) const newChatBtn = buttons[1] - expect(newChatBtn)!.toBeDisabled() + expect(newChatBtn).toBeDisabled() }) it('should handle New Chat button state when currentConversationId is missing and isResponding is false', () => { @@ -391,7 +390,7 @@ describe('Header Component', () => { const buttons = screen.getAllByRole('button') // Sidebar, NewChat (2) const newChatBtn = buttons[1] - expect(newChatBtn)!.toBeDisabled() + expect(newChatBtn).toBeDisabled() }) it('should not render operation menu if conversation id is missing', () => { diff --git a/web/app/components/base/chat/chat-with-history/header/operation.tsx b/web/app/components/base/chat/chat-with-history/header/operation.tsx index d439a43c1f..a6dd6a0a9e 100644 --- a/web/app/components/base/chat/chat-with-history/header/operation.tsx +++ b/web/app/components/base/chat/chat-with-history/header/operation.tsx @@ -71,7 +71,7 @@ const Operation: FC = ({ )} {isShowDelete && ( handleDeferredAction(onDelete)} > diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index df261d750c..e6f5657ff5 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -452,7 +452,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) setOriginConversationList(produce((draft) => { const index = originConversationList.findIndex(item => item.id === conversationId) - const item = draft[index]! + const item = draft[index] draft[index] = { ...item, name: newName, diff --git a/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx b/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx index adda03fb55..611d2bb1b9 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/operation.tsx @@ -105,7 +105,7 @@ const Operation: FC = ({ )} {isShowDelete && ( { e.stopPropagation() diff --git a/web/app/components/base/chat/chat/citation/popup.tsx b/web/app/components/base/chat/chat/citation/popup.tsx index 51a73bc4b6..2b4070b69a 100644 --- a/web/app/components/base/chat/chat/citation/popup.tsx +++ b/web/app/components/base/chat/chat/citation/popup.tsx @@ -64,10 +64,10 @@ const Popup: FC = ({
-
+
-
+
{(data.dataSourceType === 'upload_file' || data.dataSourceType === 'file') && !!data.sources?.[0]?.dataset_id ? ( ) - expect(screen.getByRole('button', { name: 'Run' }))!.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Run' })).toBeInTheDocument() }) it('should render correctly without optional className props', () => { const { wrapper, canvasLayer, gradientLayer, contentLayer } = renderGridMask({}, Plain child) - expect(wrapper)!.toHaveClass('bg-saas-background') - expect(canvasLayer)!.toHaveClass('absolute') - expect(gradientLayer)!.toHaveClass('absolute') - expect(contentLayer)!.toHaveTextContent('Plain child') + expect(wrapper).toHaveClass('bg-saas-background') + expect(canvasLayer).toHaveClass('absolute') + expect(gradientLayer).toHaveClass('absolute') + expect(contentLayer).toHaveTextContent('Plain child') }) it('should render wrapper, canvas, gradient and content layers in order', () => { const { wrapper, canvasLayer, gradientLayer, contentLayer } = renderGridMask({}, Content) - expect(wrapper)!.toBeInTheDocument() + expect(wrapper).toBeInTheDocument() expect(wrapper.children).toHaveLength(3) - expect(canvasLayer)!.toHaveClass('z-0') - expect(gradientLayer)!.toHaveClass('z-1') - expect(contentLayer)!.toHaveClass('z-2') - expect(contentLayer)!.toHaveTextContent('Content') + expect(canvasLayer).toHaveClass('z-0') + expect(gradientLayer).toHaveClass('z-1') + expect(contentLayer).toHaveClass('z-2') + expect(contentLayer).toHaveTextContent('Content') }) }) describe('Props', () => { it('should apply wrapperClassName to wrapper element', () => { const { wrapper } = renderGridMask({ wrapperClassName: 'custom-wrapper' }, Child) - expect(wrapper)!.toHaveClass('custom-wrapper') - expect(wrapper)!.toHaveClass('relative') + expect(wrapper).toHaveClass('custom-wrapper') + expect(wrapper).toHaveClass('relative') }) it('should apply canvasClassName and grid background class to canvas layer', () => { const { canvasLayer } = renderGridMask({ canvasClassName: 'custom-canvas' }, Child) - expect(canvasLayer)!.toHaveClass('custom-canvas') - expect(canvasLayer)!.toHaveClass(Style.gridBg!) + expect(canvasLayer).toHaveClass('custom-canvas') + expect(canvasLayer).toHaveClass(Style.gridBg) }) it('should apply gradientClassName to gradient layer', () => { const { gradientLayer } = renderGridMask({ gradientClassName: 'custom-gradient' }, Child) - expect(gradientLayer)!.toHaveClass('custom-gradient') - expect(gradientLayer)!.toHaveClass('bg-grid-mask-background') + expect(gradientLayer).toHaveClass('custom-gradient') + expect(gradientLayer).toHaveClass('bg-grid-mask-background') }) }) }) diff --git a/web/app/components/base/icons/src/vender/line/development/index.ts b/web/app/components/base/icons/src/vender/line/development/index.ts index 7c3c48aa5e..4278370eec 100644 --- a/web/app/components/base/icons/src/vender/line/development/index.ts +++ b/web/app/components/base/icons/src/vender/line/development/index.ts @@ -1,2 +1 @@ export { default as BracketsX } from './BracketsX' -export { default as CodeBrowser } from './CodeBrowser' diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index ec1b1248f1..8a1a082b0f 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -31,7 +31,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentFile = files[index]! + const currentFile = files[index] const newFiles = [...files.slice(0, index), { ...currentFile, deleted: true }, ...files.slice(index + 1)] setFiles(newFiles) filesRef.current = newFiles @@ -41,7 +41,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentFile = files[index]! + const currentFile = files[index] const newFiles = [...files.slice(0, index), { ...currentFile, progress: -1 }, ...files.slice(index + 1)] filesRef.current = newFiles setFiles(newFiles) @@ -51,7 +51,7 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentImageFile = files[index]! + const currentImageFile = files[index] const newFiles = [...files.slice(0, index), { ...currentImageFile, progress: 100 }, ...files.slice(index + 1)] filesRef.current = newFiles setFiles(newFiles) @@ -61,9 +61,9 @@ export const useImageFiles = () => { const files = filesRef.current const index = files.findIndex(file => file._id === imageFileId) if (index > -1) { - const currentImageFile = files[index]! + const currentImageFile = files[index] imageUpload({ - file: currentImageFile!.file!, + file: currentImageFile.file!, onProgressCallback: (progress) => { const newFiles = [...files.slice(0, index), { ...currentImageFile, progress }, ...files.slice(index + 1)] filesRef.current = newFiles @@ -114,7 +114,7 @@ export const useLocalFileUploader = ({ limit, disabled = false, onUpload }: useL // TODO: leave some warnings? return } - if (!ALLOW_FILE_EXTENSIONS.includes(file.type.split('/')[1]!)) + if (!ALLOW_FILE_EXTENSIONS.includes(file.type.split('/')[1])) return if (limit && file.size > limit * 1024 * 1024) { toast.error(t('imageUploader.uploadFromComputerLimit', { ns: 'common', size: limit })) diff --git a/web/app/components/base/image-uploader/image-list.stories.tsx b/web/app/components/base/image-uploader/image-list.stories.tsx index 0c27211d16..cfea4a0da0 100644 --- a/web/app/components/base/image-uploader/image-list.stories.tsx +++ b/web/app/components/base/image-uploader/image-list.stories.tsx @@ -132,7 +132,7 @@ const ImageUploaderPlayground = ({ readonly }: Story['args']) => { return (
- Add images + Add images
)} diff --git a/web/app/components/base/notion-icon/index.tsx b/web/app/components/base/notion-icon/index.tsx index f2b5146d73..62fcef1dc1 100644 --- a/web/app/components/base/notion-icon/index.tsx +++ b/web/app/components/base/notion-icon/index.tsx @@ -31,7 +31,7 @@ const NotionIcon = ({ ) } return ( -
{name?.[0]!.toLocaleUpperCase()}
+
{name?.[0].toLocaleUpperCase()}
) } diff --git a/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx b/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx index 21a7a08d63..d4b559452e 100644 --- a/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/__tests__/index.spec.tsx @@ -24,11 +24,11 @@ const mockList: DataSourceNotionPage[] = [ ] const mockPagesMap: DataSourceNotionPageMap = { - 'root-1': { ...mockList[0]!, workspace_id: 'workspace-1' }, - 'child-1': { ...mockList[1]!, workspace_id: 'workspace-1' }, - 'grandchild-1': { ...mockList[2]!, workspace_id: 'workspace-1' }, - 'child-2': { ...mockList[3]!, workspace_id: 'workspace-1' }, - 'root-2': { ...mockList[4]!, workspace_id: 'workspace-1' }, + 'root-1': { ...mockList[0], workspace_id: 'workspace-1' }, + 'child-1': { ...mockList[1], workspace_id: 'workspace-1' }, + 'grandchild-1': { ...mockList[2], workspace_id: 'workspace-1' }, + 'child-2': { ...mockList[3], workspace_id: 'workspace-1' }, + 'root-2': { ...mockList[4], workspace_id: 'workspace-1' }, } describe('PageSelector', () => { @@ -39,7 +39,7 @@ describe('PageSelector', () => { it('should render root level pages initially', () => { render() - expect(screen.getByText('Root 1'))!.toBeInTheDocument() + expect(screen.getByText('Root 1')).toBeInTheDocument() expect(screen.queryByText('Child 1')).not.toBeInTheDocument() }) @@ -50,13 +50,13 @@ describe('PageSelector', () => { const toggle = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggle) - expect(screen.getByText('Child 1'))!.toBeInTheDocument() + expect(screen.getByText('Child 1')).toBeInTheDocument() }) it('should call onSelect with descendants when parent is selected', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-root-1') await user.click(checkbox) @@ -78,7 +78,7 @@ describe('PageSelector', () => { it('should show breadcrumbs when searching', () => { render() - expect(screen.getByText('Root 1 / Child 1 / Grandchild 1'))!.toBeInTheDocument() + expect(screen.getByText('Root 1 / Child 1 / Grandchild 1')).toBeInTheDocument() }) it('should call onPreview when preview button is clicked', async () => { @@ -95,7 +95,7 @@ describe('PageSelector', () => { it('should show no result message when search returns nothing', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() }) it('should handle selection when searchValue is present', async () => { @@ -124,7 +124,7 @@ describe('PageSelector', () => { const toggleBtn = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggleBtn) // Expand - await waitFor(() => expect(screen.queryByText('Child 1'))!.toBeInTheDocument()) + await waitFor(() => expect(screen.queryByText('Child 1')).toBeInTheDocument()) await user.click(toggleBtn) // Collapse await waitFor(() => expect(screen.queryByText('Child 1')).not.toBeInTheDocument()) @@ -149,14 +149,14 @@ describe('PageSelector', () => { it('should render preview button when canPreview is true', () => { render() - expect(screen.getByTestId('notion-page-preview-root-1'))!.toBeInTheDocument() + expect(screen.getByTestId('notion-page-preview-root-1')).toBeInTheDocument() }) it('should use previewPageId prop when provided', () => { const { rerender } = render() let row = screen.getByTestId('notion-page-row-root-1') - expect(row)!.toHaveClass('bg-state-base-hover') + expect(row).toHaveClass('bg-state-base-hover') rerender() @@ -190,9 +190,8 @@ describe('PageSelector', () => { await user.click(toggle) // Both children should be visible - // Both children should be visible - expect(screen.getByText('Child 1'))!.toBeInTheDocument() - expect(screen.getByText('Child 2'))!.toBeInTheDocument() + expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 2')).toBeInTheDocument() }) it('should expand nested children when toggling parent', async () => { @@ -202,12 +201,12 @@ describe('PageSelector', () => { // Expand root-1 let toggle = screen.getByTestId('notion-page-toggle-root-1') await user.click(toggle) - expect(screen.getByText('Child 1'))!.toBeInTheDocument() + expect(screen.getByText('Child 1')).toBeInTheDocument() // Expand child-1 toggle = screen.getByTestId('notion-page-toggle-child-1') await user.click(toggle) - expect(screen.getByText('Grandchild 1'))!.toBeInTheDocument() + expect(screen.getByText('Grandchild 1')).toBeInTheDocument() // Collapse child-1 await user.click(toggle) @@ -228,7 +227,7 @@ describe('PageSelector', () => { it('should only select the item when searching (no descendants)', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-child-1') await user.click(checkbox) @@ -240,7 +239,7 @@ describe('PageSelector', () => { it('should deselect only the item when searching (no descendants)', async () => { const handleSelect = vi.fn() const user = userEvent.setup() - render() + render() const checkbox = screen.getByTestId('checkbox-notion-page-checkbox-child-1') await user.click(checkbox) @@ -251,8 +250,8 @@ describe('PageSelector', () => { it('should handle multiple root pages', async () => { render() - expect(screen.getByText('Root 1'))!.toBeInTheDocument() - expect(screen.getByText('Root 2'))!.toBeInTheDocument() + expect(screen.getByText('Root 1')).toBeInTheDocument() + expect(screen.getByText('Root 2')).toBeInTheDocument() }) it('should update preview when clicking preview button with onPreview provided', async () => { @@ -277,61 +276,29 @@ describe('PageSelector', () => { rerender() const row = screen.getByTestId('notion-page-row-root-1') - expect(row)!.toHaveClass('bg-state-base-hover') + expect(row).toHaveClass('bg-state-base-hover') }) it('should render page name with correct title attribute', () => { render() const pageName = screen.getByTestId('notion-page-name-root-1') - expect(pageName)!.toHaveAttribute('title', 'Root 1') + expect(pageName).toHaveAttribute('title', 'Root 1') }) it('should handle empty list gracefully', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() }) it('should filter search results correctly with partial matches', () => { render() // Should show Root 1, Child 1, and Grandchild 1 - // Should show Root 1, Child 1, and Grandchild 1 - expect(screen.getByTestId('notion-page-name-root-1'))!.toBeInTheDocument() - expect(screen.getByTestId('notion-page-name-child-1'))!.toBeInTheDocument() - expect(screen.getByTestId('notion-page-name-grandchild-1'))!.toBeInTheDocument() - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 - // Should not show Root 2, Child 2 + expect(screen.getByTestId('notion-page-name-root-1')).toBeInTheDocument() + expect(screen.getByTestId('notion-page-name-child-1')).toBeInTheDocument() + expect(screen.getByTestId('notion-page-name-grandchild-1')).toBeInTheDocument() // Should not show Root 2, Child 2 expect(screen.queryByTestId('notion-page-name-root-2')).not.toBeInTheDocument() expect(screen.queryByTestId('notion-page-name-child-2')).not.toBeInTheDocument() @@ -346,7 +313,6 @@ describe('PageSelector', () => { await user.click(toggle) // Should expand even though parent is disabled - // Should expand even though parent is disabled - expect(screen.getByText('Child 1'))!.toBeInTheDocument() + expect(screen.getByText('Child 1')).toBeInTheDocument() }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx index 1499fc1d7f..a36403b898 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/prompt-option.tsx @@ -23,7 +23,7 @@ export const PromptMenuItem = memo(({ className={` flex h-6 cursor-pointer items-center rounded-md px-3 hover:bg-state-base-hover ${isSelected && !disabled && 'bg-state-base-hover!'} - ${disabled ? 'cursor-not-allowed opacity-30' : ''} + ${disabled ? 'cursor-not-allowed opacity-30' : 'cursor-pointer hover:bg-state-base-hover'} `} tabIndex={-1} ref={setRefElement} diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx index ee82595d1c..f219f2f805 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/component.spec.tsx @@ -100,8 +100,8 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-same-name' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0]![0]).toHaveLength(1) - expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('user_name') + expect(onChange.mock.calls[0][0]).toHaveLength(1) + expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('user_name') }) it('should replace payload when variable name is renamed', async () => { @@ -124,7 +124,7 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-rename' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('renamed_name') + expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('renamed_name') }) it('should update existing payload when variable name stays the same', async () => { @@ -157,9 +157,9 @@ describe('HITLInputComponent', () => { await user.click(screen.getByRole('button', { name: 'emit-update' })) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0]![0][0].default.value).toBe('updated') - expect(onChange.mock.calls[0]![0][0].output_variable_name).toBe('user_name') - expect(onChange.mock.calls[0]![0][1].output_variable_name).toBe('other_name') - expect(onChange.mock.calls[0]![0][1].default.value).toBe('other') + expect(onChange.mock.calls[0][0][0].default.value).toBe('updated') + expect(onChange.mock.calls[0][0][0].output_variable_name).toBe('user_name') + expect(onChange.mock.calls[0][0][1].output_variable_name).toBe('other_name') + expect(onChange.mock.calls[0][0][1].default.value).toBe('other') }) }) diff --git a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx index 990a7ced4a..f5efc52c23 100644 --- a/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/hitl-input-block/__tests__/pre-populate.spec.tsx @@ -82,12 +82,12 @@ describe('PrePopulate', () => { />, ) - expect(screen.getByText('Static Content'))!.toBeInTheDocument() + expect(screen.getByText('Static Content')).toBeInTheDocument() await user.keyboard('{Tab}') expect(screen.queryByText('Static Content')).not.toBeInTheDocument() - expect(screen.getByRole('textbox'))!.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() }) it('should update constant value and toggle to variable mode when type switch is clicked', async () => { @@ -154,7 +154,7 @@ describe('PrePopulate', () => { />, ) - const pickerProps = mockVarReferencePicker.mock.calls[0]![0] as VarReferencePickerProps + const pickerProps = mockVarReferencePicker.mock.calls[0][0] as VarReferencePickerProps const allowString = pickerProps.filterVar({ type: 'string' } as Var) const allowNumber = pickerProps.filterVar({ type: 'number' } as Var) diff --git a/web/app/components/base/prompt-editor/plugins/query-block/component.tsx b/web/app/components/base/prompt-editor/plugins/query-block/component.tsx index cd5b60bc9b..a5b5969904 100644 --- a/web/app/components/base/prompt-editor/plugins/query-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/query-block/component.tsx @@ -17,7 +17,7 @@ const QueryBlockComponent: FC = ({ return (
{ />, ) - expect(screen.getByRole('button', { name: 'label' }))!.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'label' })).toBeInTheDocument() expect(mockHasNodes).toHaveBeenCalledWith([WorkflowVariableBlockNode]) expect(mockRegisterCommand).toHaveBeenCalledWith( UPDATE_WORKFLOW_NODES_MAP, @@ -188,7 +188,7 @@ describe('WorkflowVariableBlockComponent', () => { />, ) - expect(screen.getByRole('button', { name: 'label' }))!.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'label' })).toBeInTheDocument() }) it('should pass computed varType when getVarType is provided', () => { @@ -489,7 +489,7 @@ describe('WorkflowVariableBlockComponent', () => { />, ) - const updateHandler = mockRegisterCommand.mock.calls[0]![1] as (payload: UpdateWorkflowNodesMapPayload) => boolean + const updateHandler = mockRegisterCommand.mock.calls[0][1] as (payload: UpdateWorkflowNodesMapPayload) => boolean let result = false act(() => { result = updateHandler({ diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 20fc7c6e79..2d13627b20 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -19,11 +19,11 @@ export class WorkflowVariableBlockNode extends DecoratorNode __getVarType?: GetVarType __availableVariables?: NodeOutPutVar[] - static override getType(): string { + static getType(): string { return 'workflow-variable-block' } - static override clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { + static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { return new WorkflowVariableBlockNode( node.__variables, node.__workflowNodesMap, @@ -33,7 +33,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode ) } - override isInline(): boolean { + isInline(): boolean { return true } @@ -52,17 +52,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode this.__availableVariables = availableVariables } - override createDOM(): HTMLElement { + createDOM(): HTMLElement { const div = document.createElement('div') div.classList.add('inline-flex', 'items-center', 'align-middle') return div } - override updateDOM(): false { + updateDOM(): false { return false } - override decorate(): React.JSX.Element { + decorate(): React.JSX.Element { return ( ) } - static override importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { + static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { const node = $createWorkflowVariableBlockNode( serializedNode.variables, serializedNode.workflowNodesMap, @@ -85,7 +85,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode return node } - override exportJSON(): SerializedNode { + exportJSON(): SerializedNode { const json: SerializedNode = { type: 'workflow-variable-block', version: 1, @@ -119,7 +119,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__availableVariables } - override getTextContent(): string { + getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } diff --git a/web/app/components/base/prompt-log-modal/index.tsx b/web/app/components/base/prompt-log-modal/index.tsx index 08200623ae..6a79dfffeb 100644 --- a/web/app/components/base/prompt-log-modal/index.tsx +++ b/web/app/components/base/prompt-log-modal/index.tsx @@ -42,13 +42,13 @@ const PromptLogModal: FC = ({ }} ref={ref} > -
+
PROMPT LOG
{ currentLogItem.log?.length === 1 && ( <> - +
) diff --git a/web/app/components/base/select/locale-signin.tsx b/web/app/components/base/select/locale-signin.tsx index 046a76a5d4..3c5dd999f6 100644 --- a/web/app/components/base/select/locale-signin.tsx +++ b/web/app/components/base/select/locale-signin.tsx @@ -36,7 +36,7 @@ export default function LocaleSigninSelect({ leaveTo="transform opacity-0 scale-95" > -
+
{items.map((item) => { return ( diff --git a/web/app/components/base/spinner/index.tsx b/web/app/components/base/spinner/index.tsx index 48ee65b99f..65fea46a91 100644 --- a/web/app/components/base/spinner/index.tsx +++ b/web/app/components/base/spinner/index.tsx @@ -14,7 +14,7 @@ const Spinner: FC = ({ loading = false, children, className }) => { role="status" > Loading... diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx index 8e693fb9f1..79c557a8b9 100644 --- a/web/app/components/base/tag-management/index.tsx +++ b/web/app/components/base/tag-management/index.tsx @@ -48,8 +48,8 @@ const TagManagementModal = ({ show, type }: TagManagementModalProps) => { }, [type]) return ( setShowTagManagementModal(false)}> -
{t('tag.manageTags', { ns: 'common' })}
-
setShowTagManagementModal(false)}> +
{t('tag.manageTags', { ns: 'common' })}
+
setShowTagManagementModal(false)}>
diff --git a/web/app/components/base/tag-management/panel.tsx b/web/app/components/base/tag-management/panel.tsx index db0aae1b05..cceb09b4d7 100644 --- a/web/app/components/base/tag-management/panel.tsx +++ b/web/app/components/base/tag-management/panel.tsx @@ -114,7 +114,7 @@ const Panel = (props: PanelProps) => {
-
+
{`${t('tag.create', { ns: 'common' })} `} {`'${keywords}'`}
@@ -127,7 +127,7 @@ const Panel = (props: PanelProps) => { {filteredSelectedTagList.map(tag => (
selectTag(tag)} data-testid="tag-row"> -
+
{tag.name}
@@ -135,7 +135,7 @@ const Panel = (props: PanelProps) => { {filteredTagList.map(tag => (
selectTag(tag)} data-testid="tag-row"> -
+
{tag.name}
@@ -146,7 +146,7 @@ const Panel = (props: PanelProps) => {
-
{t('tag.noTag', { ns: 'common' })}
+
{t('tag.noTag', { ns: 'common' })}
)} @@ -154,7 +154,7 @@ const Panel = (props: PanelProps) => {
setShowTagManagementModal(true)}> -
+
{t('tag.manageTags', { ns: 'common' })}
diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx index a4b8888b27..0ae553ec01 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/__tests__/index.spec.tsx @@ -98,10 +98,10 @@ describe('CloudPlanItem', () => { />, ) - expect(screen.getByText('billing.plans.sandbox.name'))!.toBeInTheDocument() - expect(screen.getByText('billing.plans.sandbox.description'))!.toBeInTheDocument() - expect(screen.getByText('billing.plansCommon.free'))!.toBeInTheDocument() - expect(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))!.toBeInTheDocument() + expect(screen.getByText('billing.plans.sandbox.name')).toBeInTheDocument() + expect(screen.getByText('billing.plans.sandbox.description')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.free')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' })).toBeInTheDocument() }) it('should display yearly pricing with discount when planRange is yearly', () => { @@ -115,9 +115,9 @@ describe('CloudPlanItem', () => { ) const professionalPlan = ALL_PLANS[Plan.professional] - expect(screen.getByText(`$${professionalPlan.price * 12}`))!.toBeInTheDocument() - expect(screen.getByText(`$${professionalPlan.price * 10}`))!.toBeInTheDocument() - expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.year/))!.toBeInTheDocument() + expect(screen.getByText(`$${professionalPlan.price * 12}`)).toBeInTheDocument() + expect(screen.getByText(`$${professionalPlan.price * 10}`)).toBeInTheDocument() + expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.year/)).toBeInTheDocument() }) it('should show "most popular" badge for professional plan', () => { @@ -130,7 +130,7 @@ describe('CloudPlanItem', () => { />, ) - expect(screen.getByText('billing.plansCommon.mostPopular'))!.toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.mostPopular')).toBeInTheDocument() }) it('should not show "most popular" badge for non-professional plans', () => { @@ -157,7 +157,7 @@ describe('CloudPlanItem', () => { ) const button = screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }) - expect(button)!.toBeDisabled() + expect(button).toBeDisabled() }) }) @@ -176,7 +176,7 @@ describe('CloudPlanItem', () => { ) fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' })) - expect(screen.getByText('billing.buyPermissionDeniedTip'))!.toBeInTheDocument() + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() expect(mockBillingInvoices).not.toHaveBeenCalled() }) @@ -320,7 +320,7 @@ describe('CloudPlanItem', () => { expect(openWindow).toHaveBeenCalledTimes(1) // The onError callback should have been passed to openAsyncWindow const callArgs = openWindow.mock.calls[0] - expect(callArgs![1]).toHaveProperty('onError') + expect(callArgs[1]).toHaveProperty('onError') }) }) @@ -336,39 +336,8 @@ describe('CloudPlanItem', () => { ) const teamPlan = ALL_PLANS[Plan.team] - expect(screen.getByText(`$${teamPlan.price}`))!.toBeInTheDocument() - expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.month/))!.toBeInTheDocument() - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price - // Should NOT show crossed-out yearly price + expect(screen.getByText(`$${teamPlan.price}`)).toBeInTheDocument() + expect(screen.getByText(/billing\.plansCommon\.priceTip.*billing\.plansCommon\.month/)).toBeInTheDocument() // Should NOT show crossed-out yearly price expect(screen.queryByText(`$${teamPlan.price * 12}`)).not.toBeInTheDocument() }) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index 99d956ba90..b85f1d8631 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -103,38 +103,38 @@ const CloudPlanItem: FC = ({ {ICON_MAP[plan]}
-
{t(`${i18nPrefix}.name`, { ns: 'billing' })}
+
{t(`${i18nPrefix}.name`, { ns: 'billing' })}
{ isMostPopularPlan && (
- + {t('plansCommon.mostPopular', { ns: 'billing' })}
) }
-
{t(`${i18nPrefix}.description`, { ns: 'billing' })}
+
{t(`${i18nPrefix}.description`, { ns: 'billing' })}
{/* Price */} -
+
{isFreePlan && ( - {t('plansCommon.free', { ns: 'billing' })} + {t('plansCommon.free', { ns: 'billing' })} )} {!isFreePlan && ( <> {isYear && ( - + $ {planInfo.price * 12} )} - + $ {isYear ? planInfo.price * 10 : planInfo.price} - + {t('plansCommon.priceTip', { ns: 'billing' })} {t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })} diff --git a/web/app/components/billing/snippet-and-evaluation-plan-guard.tsx b/web/app/components/billing/snippet-and-evaluation-plan-guard.tsx new file mode 100644 index 0000000000..39ed41a348 --- /dev/null +++ b/web/app/components/billing/snippet-and-evaluation-plan-guard.tsx @@ -0,0 +1,40 @@ +'use client' + +import type { ReactNode } from 'react' +import { useEffect } from 'react' +import Loading from '@/app/components/base/loading' +import { useSnippetAndEvaluationPlanAccess } from '@/hooks/use-snippet-and-evaluation-plan-access' +import { useRouter } from '@/next/navigation' + +type SnippetAndEvaluationPlanGuardProps = { + children: ReactNode + fallbackHref: string +} + +const SnippetAndEvaluationPlanGuard = ({ + children, + fallbackHref, +}: SnippetAndEvaluationPlanGuardProps) => { + const router = useRouter() + const { canAccess, isReady } = useSnippetAndEvaluationPlanAccess() + + useEffect(() => { + if (isReady && !canAccess) + router.replace(fallbackHref) + }, [canAccess, fallbackHref, isReady, router]) + + if (!isReady) { + return ( +
+ +
+ ) + } + + if (!canAccess) + return null + + return <>{children} +} + +export default SnippetAndEvaluationPlanGuard diff --git a/web/app/components/billing/utils/index.ts b/web/app/components/billing/utils/index.ts index 2d37eecbd5..c61ed2b54d 100644 --- a/web/app/components/billing/utils/index.ts +++ b/web/app/components/billing/utils/index.ts @@ -1,6 +1,7 @@ import type { BasicPlan, BillingQuota, CurrentPlanInfoBackend } from '../type' import dayjs from 'dayjs' import { ALL_PLANS, NUM_INFINITE } from '@/app/components/billing/config' +import { Plan } from '../type' /** * Parse vectorSpace string from ALL_PLANS config and convert to MB @@ -116,3 +117,21 @@ export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => { }, } } + +export const canAccessSnippetsAndEvaluation = ({ + enableBilling, + isFetchedPlan, + planType, +}: { + enableBilling: boolean + isFetchedPlan: boolean + planType: Plan +}) => { + if (!isFetchedPlan) + return !enableBilling + + if (!enableBilling) + return true + + return planType === Plan.professional || planType === Plan.team || planType === Plan.enterprise +} diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx index 2cc51a3be5..f899dcb33d 100644 --- a/web/app/components/datasets/common/image-previewer/index.tsx +++ b/web/app/components/datasets/common/image-previewer/index.tsx @@ -137,7 +137,7 @@ const ImagePreviewer = ({ return { ...prev, [image.url]: { - ...prev[image.url]!, + ...prev[image.url], status: 'loading', }, } @@ -168,15 +168,15 @@ const ImagePreviewer = ({ Esc
- {cachedImages[currentImage!.url]!.status === 'loading' && ( + {cachedImages[currentImage.url].status === 'loading' && ( )} - {cachedImages[currentImage!.url]!.status === 'error' && ( + {cachedImages[currentImage.url].status === 'error' && (
- {`Failed to load image: ${currentImage!.url}. Please try again.`} + {`Failed to load image: ${currentImage.url}. Please try again.`}
)} - {cachedImages[currentImage!.url]!.status === 'loaded' && ( + {cachedImages[currentImage.url].status === 'loaded' && (
{currentImage!.name}
- {currentImage!.name} + {currentImage.name} · - {`${cachedImages[currentImage!.url]!.width} ×  ${cachedImages[currentImage!.url]!.height}`} + {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} · - {formatFileSize(currentImage!.size)} + {formatFileSize(currentImage.size)}
)} diff --git a/web/app/components/datasets/create/step-one/upgrade-card.tsx b/web/app/components/datasets/create/step-one/upgrade-card.tsx index e7016206ea..356e15ed43 100644 --- a/web/app/components/datasets/create/step-one/upgrade-card.tsx +++ b/web/app/components/datasets/create/step-one/upgrade-card.tsx @@ -15,9 +15,9 @@ const UpgradeCard: FC = () => { }, [setShowPricingModal]) return ( -
+
-
{t('upgrade.uploadMultipleFiles.title', { ns: 'billing' })}
+
{t('upgrade.uploadMultipleFiles.title', { ns: 'billing' })}
{t('upgrade.uploadMultipleFiles.description', { ns: 'billing' })}
{ render() // Should render Previous and Next buttons with correct text - // Should render Previous and Next buttons with correct text - expect(screen.getByText(/previousStep/i))!.toBeInTheDocument() - expect(screen.getByText(/nextStep/i))!.toBeInTheDocument() + expect(screen.getByText(/previousStep/i)).toBeInTheDocument() + expect(screen.getByText(/nextStep/i)).toBeInTheDocument() }) it('should render Previous and Next buttons when not in setting mode', () => { render() - expect(screen.getByText(/previousStep/i))!.toBeInTheDocument() - expect(screen.getByText(/nextStep/i))!.toBeInTheDocument() + expect(screen.getByText(/previousStep/i)).toBeInTheDocument() + expect(screen.getByText(/nextStep/i)).toBeInTheDocument() }) it('should render Save and Cancel buttons when in setting mode', () => { render() - expect(screen.getByText(/save/i))!.toBeInTheDocument() - expect(screen.getByText(/cancel/i))!.toBeInTheDocument() + expect(screen.getByText(/save/i)).toBeInTheDocument() + expect(screen.getByText(/cancel/i)).toBeInTheDocument() }) }) @@ -1773,14 +1772,14 @@ describe('StepTwoFooter', () => { render() const nextButton = screen.getByText(/nextStep/i).closest('button') - expect(nextButton)!.toBeDisabled() + expect(nextButton).toBeDisabled() }) it('should show loading state on Save button when creating in setting mode', () => { render() const saveButton = screen.getByText(/save/i).closest('button') - expect(saveButton)!.toBeDisabled() + expect(saveButton).toBeDisabled() }) }) }) @@ -1812,50 +1811,18 @@ describe('PreviewPanel', () => { render() // Check for the preview header title text - // Check for the preview header title text - expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() }) it('should render idle state when isIdle is true', () => { render() - expect(screen.getByText(/previewChunkTip/i))!.toBeInTheDocument() + expect(screen.getByText(/previewChunkTip/i)).toBeInTheDocument() }) it('should render loading skeleton when isPending is true', () => { render() - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers - // Should show skeleton containers // Should show skeleton containers expect(screen.queryByText(/previewChunkTip/i)).not.toBeInTheDocument() }) @@ -1874,7 +1841,7 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Chunk 1 content'))!.toBeInTheDocument() + expect(screen.getByText('Chunk 1 content')).toBeInTheDocument() }) it('should render QA preview when docForm is qa', () => { @@ -1888,8 +1855,8 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Q1'))!.toBeInTheDocument() - expect(screen.getByText('A1'))!.toBeInTheDocument() + expect(screen.getByText('Q1')).toBeInTheDocument() + expect(screen.getByText('A1')).toBeInTheDocument() }) it('should show chunk count badge for non-QA doc form', () => { @@ -1903,7 +1870,7 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText(/25/))!.toBeInTheDocument() + expect(screen.getByText(/25/)).toBeInTheDocument() }) it('should render parent-child preview when docForm is parentChild', () => { @@ -1926,13 +1893,11 @@ describe('PreviewPanel', () => { ) // Should render parent chunk label - // Should render parent chunk label - expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() + expect(screen.getByText('Chunk-1')).toBeInTheDocument() // Should render child chunks - // Should render child chunks - expect(screen.getByText('Child 1'))!.toBeInTheDocument() - expect(screen.getByText('Child 2'))!.toBeInTheDocument() - expect(screen.getByText('Child 3'))!.toBeInTheDocument() + expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 2')).toBeInTheDocument() + expect(screen.getByText('Child 3')).toBeInTheDocument() }) it('should limit child chunks when chunkForContext is full-doc', () => { @@ -1955,43 +1920,10 @@ describe('PreviewPanel', () => { ) // Should render parent chunk - // Should render parent chunk - expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() + expect(screen.getByText('Chunk-1')).toBeInTheDocument() // full-doc mode limits to FULL_DOC_PREVIEW_LENGTH (50) - // full-doc mode limits to FULL_DOC_PREVIEW_LENGTH (50) - expect(screen.getByText('ChildChunk1'))!.toBeInTheDocument() - expect(screen.getByText('ChildChunk50'))!.toBeInTheDocument() - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit - // Should not render beyond the limit + expect(screen.getByText('ChildChunk1')).toBeInTheDocument() + expect(screen.getByText('ChildChunk50')).toBeInTheDocument() // Should not render beyond the limit expect(screen.queryByText('ChildChunk51')).not.toBeInTheDocument() }) @@ -2012,10 +1944,10 @@ describe('PreviewPanel', () => { />, ) - expect(screen.getByText('Chunk-1'))!.toBeInTheDocument() - expect(screen.getByText('Chunk-2'))!.toBeInTheDocument() - expect(screen.getByText('P1-C1'))!.toBeInTheDocument() - expect(screen.getByText('P2-C1'))!.toBeInTheDocument() + expect(screen.getByText('Chunk-1')).toBeInTheDocument() + expect(screen.getByText('Chunk-2')).toBeInTheDocument() + expect(screen.getByText('P1-C1')).toBeInTheDocument() + expect(screen.getByText('P2-C1')).toBeInTheDocument() }) }) @@ -2290,20 +2222,19 @@ describe('StepTwo Component', () => { describe('Rendering', () => { it('should render without crashing', () => { render() - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) it('should show general chunking options when not in upload', () => { render() // Should render the segmentation section - // Should render the segmentation section - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) it('should show footer with Previous and Next buttons', () => { render() - expect(screen.getByText(/stepTwo\.previousStep/i))!.toBeInTheDocument() - expect(screen.getByText(/stepTwo\.nextStep/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.previousStep/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.nextStep/i)).toBeInTheDocument() }) }) @@ -2351,7 +2282,7 @@ describe('StepTwo Component', () => { render() // GeneralChunkingOptions renders a "Preview Chunk" button const previewButtons = screen.getAllByText(/stepTwo\.previewChunk/i) - fireEvent.click(previewButtons[0]!) + fireEvent.click(previewButtons[0]) // updatePreview calls estimateHook.fetchEstimate() // No error means the handler executed successfully }) @@ -2361,7 +2292,7 @@ describe('StepTwo Component', () => { // ParentChildOptions renders an OptionCard; find the title element and click its parent card const parentChildTitles = screen.getAllByText(/stepTwo\.parentChild/i) // The first match is the title; click it to trigger onDocFormChange - fireEvent.click(parentChildTitles[0]!) + fireEvent.click(parentChildTitles[0]) // handleDocFormChange sets docForm, segmentationType, and resets estimate }) }) @@ -2376,8 +2307,7 @@ describe('StepTwo Component', () => { />, ) // When currentDataset has parentChild doc_form, should show parent-child option - // When currentDataset has parentChild doc_form, should show parent-child option - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) it('should render setting mode with Save/Cancel buttons', () => { @@ -2390,8 +2320,8 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.save/i))!.toBeInTheDocument() - expect(screen.getByText(/stepTwo\.cancel/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.save/i)).toBeInTheDocument() + expect(screen.getByText(/stepTwo\.cancel/i)).toBeInTheDocument() }) it('should call onCancel when Cancel button is clicked in setting mode', () => { @@ -2432,9 +2362,8 @@ describe('StepTwo Component', () => { it('should show both general and parent-child options in create page', () => { render() // When isInInit (no datasetId, no isSetting), both options should show - // When isInInit (no datasetId, no isSetting), both options should show - expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() - expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() }) it('should only show parent-child option when dataset has parentChild doc_form', () => { @@ -2447,9 +2376,7 @@ describe('StepTwo Component', () => { ) // showGeneralOption should be false (parentChild not in [text, qa]) // showParentChildOption should be true - // showGeneralOption should be false (parentChild not in [text, qa]) - // showParentChildOption should be true - expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() }) it('should show general option only when dataset has text doc_form', () => { @@ -2461,8 +2388,7 @@ describe('StepTwo Component', () => { />, ) // showGeneralOption should be true (text is in [text, qa]) - // showGeneralOption should be true (text is in [text, qa]) - expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() }) }) @@ -2475,7 +2401,7 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) it('should show general option for empty dataset (no doc_form)', () => { @@ -2487,7 +2413,7 @@ describe('StepTwo Component', () => { datasetId="test-id" />, ) - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) it('should show both options in empty dataset upload', () => { @@ -2500,9 +2426,8 @@ describe('StepTwo Component', () => { />, ) // isUploadInEmptyDataset=true shows both options - // isUploadInEmptyDataset=true shows both options - expect(screen.getByText('datasetCreation.stepTwo.general'))!.toBeInTheDocument() - expect(screen.getByText('datasetCreation.stepTwo.parentChild'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.general')).toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.parentChild')).toBeInTheDocument() }) }) @@ -2510,22 +2435,19 @@ describe('StepTwo Component', () => { it('should render indexing mode section', () => { render() // IndexingModeSection renders the index mode title - // IndexingModeSection renders the index mode title - expect(screen.getByText(/stepTwo\.indexMode/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() }) it('should render embedding model selector when QUALIFIED', () => { render() // ModelSelector is mocked and rendered with data-testid - // ModelSelector is mocked and rendered with data-testid - expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() + expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) it('should render retrieval method config', () => { render() // RetrievalMethodConfig is mocked with data-testid - // RetrievalMethodConfig is mocked with data-testid - expect(screen.getByTestId('retrieval-method-config'))!.toBeInTheDocument() + expect(screen.getByTestId('retrieval-method-config')).toBeInTheDocument() }) it('should disable model and retrieval config when datasetId has existing data source', () => { @@ -2538,14 +2460,14 @@ describe('StepTwo Component', () => { ) // isModelAndRetrievalConfigDisabled should be true const modelSelector = screen.getByTestId('model-selector') - expect(modelSelector)!.toHaveAttribute('data-readonly', 'true') + expect(modelSelector).toHaveAttribute('data-readonly', 'true') }) }) describe('Preview Panel', () => { it('should render preview panel', () => { render() - expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() }) it('should hide document picker in setting mode', () => { @@ -2559,8 +2481,7 @@ describe('StepTwo Component', () => { />, ) // Preview panel should still render - // Preview panel should still render - expect(screen.getByText('datasetCreation.stepTwo.preview'))!.toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepTwo.preview')).toBeInTheDocument() }) }) @@ -2577,35 +2498,35 @@ describe('StepTwo Component', () => { it('should switch to QUALIFIED when selecting parentChild in ECONOMICAL mode', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) const parentChildTitles = screen.getAllByText(/stepTwo\.parentChild/i) - fireEvent.click(parentChildTitles[0]!) + fireEvent.click(parentChildTitles[0]) }) it('should open QA confirm dialog and confirm switch when QA selected in ECONOMICAL mode', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) const qaCheckbox = screen.getByText(/stepTwo\.useQALanguage/i) fireEvent.click(qaCheckbox) // Dialog should open → click Switch to confirm (triggers handleQAConfirm) const switchButton = await screen.findByText(/stepTwo\.switch/i) - expect(switchButton)!.toBeInTheDocument() + expect(switchButton).toBeInTheDocument() fireEvent.click(switchButton) }) it('should close QA confirm dialog when cancel is clicked', async () => { render() await vi.waitFor(() => { - expect(screen.getByText(/stepTwo\.segmentation/i))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.segmentation/i)).toBeInTheDocument() }) // Open QA confirm dialog const qaCheckbox = screen.getByText(/stepTwo\.useQALanguage/i) fireEvent.click(qaCheckbox) const dialogCancelButtons = await screen.findAllByText(/stepTwo\.cancel/i) - fireEvent.click(dialogCancelButtons[0]!) + fireEvent.click(dialogCancelButtons[0]) }) it('should handle picker change when selecting a different file', () => { @@ -2624,7 +2545,7 @@ describe('StepTwo Component', () => { render() // The default maxChunkLength (1024) now exceeds the limit (100) const previewButtons = screen.getAllByText(/stepTwo\.previewChunk/i) - fireEvent.click(previewButtons[0]!) + fireEvent.click(previewButtons[0]) // Restore document.body.removeAttribute('data-public-indexing-max-segmentation-tokens-length') }) diff --git a/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx b/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx index f1ab5392ce..2c0480e508 100644 --- a/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx +++ b/web/app/components/datasets/create/step-two/components/__tests__/inputs.spec.tsx @@ -12,27 +12,26 @@ describe('DelimiterInput', () => { it('should render separator label', () => { render() - expect(screen.getByText(`${ns}.stepTwo.separator`))!.toBeInTheDocument() + expect(screen.getByText(`${ns}.stepTwo.separator`)).toBeInTheDocument() }) it('should render text input with placeholder', () => { render() const input = screen.getByPlaceholderText(`${ns}.stepTwo.separatorPlaceholder`) - expect(input)!.toBeInTheDocument() - expect(input)!.toHaveAttribute('type', 'text') + expect(input).toBeInTheDocument() + expect(input).toHaveAttribute('type', 'text') }) it('should pass through value and onChange props', () => { const onChange = vi.fn() render() - expect(screen.getByDisplayValue('test-val'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('test-val')).toBeInTheDocument() }) it('should render tooltip content', () => { render() // Tooltip triggers render; component mounts without error - // Tooltip triggers render; component mounts without error - expect(screen.getByText(`${ns}.stepTwo.separator`))!.toBeInTheDocument() + expect(screen.getByText(`${ns}.stepTwo.separator`)).toBeInTheDocument() }) it('should suppress onChange during IME composition', () => { @@ -48,7 +47,7 @@ describe('DelimiterInput', () => { fireEvent.compositionEnd(input) expect(onChange).toHaveBeenCalledTimes(1) - expect(onChange.mock.calls[0]![0].target.value).toBe(finalValue) + expect(onChange.mock.calls[0][0].target.value).toBe(finalValue) }) }) @@ -59,24 +58,24 @@ describe('MaxLengthInput', () => { it('should render max length label', () => { render() - expect(screen.getByText(`${ns}.stepTwo.maxLength`))!.toBeInTheDocument() + expect(screen.getByText(`${ns}.stepTwo.maxLength`)).toBeInTheDocument() }) it('should render number input', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toBeInTheDocument() + expect(input).toBeInTheDocument() }) it('should accept value prop', () => { render() - expect(screen.getByRole('textbox'))!.toHaveValue('500') + expect(screen.getByRole('textbox')).toHaveValue('500') }) it('should have min of 1', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toBeInTheDocument() + expect(input).toBeInTheDocument() }) it('should reset to the minimum when users clear the value', () => { @@ -108,18 +107,18 @@ describe('OverlapInput', () => { it('should render number input', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toBeInTheDocument() + expect(input).toBeInTheDocument() }) it('should accept value prop', () => { render() - expect(screen.getByRole('textbox'))!.toHaveValue('50') + expect(screen.getByRole('textbox')).toHaveValue('50') }) it('should have min of 1', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toBeInTheDocument() + expect(input).toBeInTheDocument() }) it('should reset to the minimum when users clear the value', () => { diff --git a/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx b/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx index 946c04aa93..313ad9c051 100644 --- a/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx +++ b/web/app/components/datasets/create/website/firecrawl/__tests__/options.spec.tsx @@ -35,10 +35,9 @@ describe('Options', () => { render() // Check that key elements are rendered - // Check that key elements are rendered - expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() - expect(screen.getByText(/limit/i))!.toBeInTheDocument() - expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() + expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() + expect(screen.getByText(/limit/i)).toBeInTheDocument() + expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() }) it('should render all form fields', () => { @@ -46,16 +45,14 @@ describe('Options', () => { render() // Checkboxes - // Checkboxes - expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() - expect(screen.getByText(/extractOnlyMainContent/i))!.toBeInTheDocument() + expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() + expect(screen.getByText(/extractOnlyMainContent/i)).toBeInTheDocument() // Text/Number fields - // Text/Number fields - expect(screen.getByText(/limit/i))!.toBeInTheDocument() - expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() - expect(screen.getByText(/excludePaths/i))!.toBeInTheDocument() - expect(screen.getByText(/includeOnlyPaths/i))!.toBeInTheDocument() + expect(screen.getByText(/limit/i)).toBeInTheDocument() + expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() + expect(screen.getByText(/excludePaths/i)).toBeInTheDocument() + expect(screen.getByText(/includeOnlyPaths/i)).toBeInTheDocument() }) it('should render with custom className', () => { @@ -65,7 +62,7 @@ describe('Options', () => { ) const rootElement = container.firstChild as HTMLElement - expect(rootElement)!.toHaveClass('custom-class') + expect(rootElement).toHaveClass('custom-class') }) it('should render limit field with required indicator', () => { @@ -74,7 +71,7 @@ describe('Options', () => { // Limit field should have required indicator (*) const requiredIndicator = screen.getByText('*') - expect(requiredIndicator)!.toBeInTheDocument() + expect(requiredIndicator).toBeInTheDocument() }) it('should render placeholder for excludes field', () => { @@ -82,7 +79,7 @@ describe('Options', () => { render() const excludesInput = screen.getByPlaceholderText('blog/*, /about/*') - expect(excludesInput)!.toBeInTheDocument() + expect(excludesInput).toBeInTheDocument() }) it('should render placeholder for includes field', () => { @@ -90,7 +87,7 @@ describe('Options', () => { render() const includesInput = screen.getByPlaceholderText('articles/*') - expect(includesInput)!.toBeInTheDocument() + expect(includesInput).toBeInTheDocument() }) it('should render two checkboxes', () => { @@ -109,8 +106,7 @@ describe('Options', () => { render() // First checkbox should have check icon when checked - // First checkbox should have check icon when checked - expect(screen.queryByTestId('check-icon-crawl-sub-page'))!.toBeInTheDocument() + expect(screen.queryByTestId('check-icon-crawl-sub-page')).toBeInTheDocument() }) it('should display crawl_sub_pages checkbox without check icon when false', () => { @@ -122,7 +118,7 @@ describe('Options', () => { it('should display only_main_content checkbox with check icon when true', () => { const payload = createMockCrawlOptions({ only_main_content: true }) render() - expect(screen.getByTestId('check-icon-only-main-content'))!.toBeInTheDocument() + expect(screen.getByTestId('check-icon-only-main-content')).toBeInTheDocument() }) it('should display only_main_content checkbox without check icon when false', () => { @@ -136,7 +132,7 @@ describe('Options', () => { render() const limitInput = screen.getByDisplayValue('25') - expect(limitInput)!.toBeInTheDocument() + expect(limitInput).toBeInTheDocument() }) it('should display max_depth value in input', () => { @@ -144,7 +140,7 @@ describe('Options', () => { render() const maxDepthInput = screen.getByDisplayValue('5') - expect(maxDepthInput)!.toBeInTheDocument() + expect(maxDepthInput).toBeInTheDocument() }) it('should display excludes value in input', () => { @@ -152,7 +148,7 @@ describe('Options', () => { render() const excludesInput = screen.getByDisplayValue('test/*') - expect(excludesInput)!.toBeInTheDocument() + expect(excludesInput).toBeInTheDocument() }) it('should display includes value in input', () => { @@ -160,7 +156,7 @@ describe('Options', () => { render() const includesInput = screen.getByDisplayValue('docs/*') - expect(includesInput)!.toBeInTheDocument() + expect(includesInput).toBeInTheDocument() }) }) @@ -170,7 +166,7 @@ describe('Options', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[0]!) + fireEvent.click(checkboxes[0]) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -183,7 +179,7 @@ describe('Options', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[1]!) + fireEvent.click(checkboxes[1]) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -255,8 +251,7 @@ describe('Options', () => { render() // Component should render without crashing - // Component should render without crashing - expect(screen.getByText(/limit/i))!.toBeInTheDocument() + expect(screen.getByText(/limit/i)).toBeInTheDocument() }) it('should handle zero values', () => { @@ -278,8 +273,8 @@ describe('Options', () => { }) render() - expect(screen.getByDisplayValue('9999'))!.toBeInTheDocument() - expect(screen.getByDisplayValue('100'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('9999')).toBeInTheDocument() + expect(screen.getByDisplayValue('100')).toBeInTheDocument() }) it('should handle special characters in text fields', () => { @@ -289,8 +284,8 @@ describe('Options', () => { }) render() - expect(screen.getByDisplayValue('path/*/file?query=1¶m=2'))!.toBeInTheDocument() - expect(screen.getByDisplayValue('docs/**/*.md'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('path/*/file?query=1¶m=2')).toBeInTheDocument() + expect(screen.getByDisplayValue('docs/**/*.md')).toBeInTheDocument() }) it('should preserve other payload fields when updating one field', () => { @@ -362,7 +357,7 @@ describe('Options', () => { rerender() - expect(screen.getByText(/limit/i))!.toBeInTheDocument() + expect(screen.getByText(/limit/i)).toBeInTheDocument() }) it('should re-render when payload changes', () => { @@ -370,10 +365,10 @@ describe('Options', () => { const payload2 = createMockCrawlOptions({ limit: 20 }) const { rerender } = render() - expect(screen.getByDisplayValue('10'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('10')).toBeInTheDocument() rerender() - expect(screen.getByDisplayValue('20'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('20')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx b/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx index b65df109ad..bda01dc152 100644 --- a/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx +++ b/web/app/components/datasets/create/website/watercrawl/__tests__/options.spec.tsx @@ -34,12 +34,12 @@ describe('Options (watercrawl)', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByText(/crawlSubPage/i))!.toBeInTheDocument() - expect(screen.getByText(/extractOnlyMainContent/i))!.toBeInTheDocument() - expect(screen.getByText(/limit/i))!.toBeInTheDocument() - expect(screen.getByText(/maxDepth/i))!.toBeInTheDocument() - expect(screen.getByText(/excludePaths/i))!.toBeInTheDocument() - expect(screen.getByText(/includeOnlyPaths/i))!.toBeInTheDocument() + expect(screen.getByText(/crawlSubPage/i)).toBeInTheDocument() + expect(screen.getByText(/extractOnlyMainContent/i)).toBeInTheDocument() + expect(screen.getByText(/limit/i)).toBeInTheDocument() + expect(screen.getByText(/maxDepth/i)).toBeInTheDocument() + expect(screen.getByText(/excludePaths/i)).toBeInTheDocument() + expect(screen.getByText(/includeOnlyPaths/i)).toBeInTheDocument() }) it('should render two checkboxes', () => { @@ -55,21 +55,21 @@ describe('Options (watercrawl)', () => { render() const requiredIndicator = screen.getByText('*') - expect(requiredIndicator)!.toBeInTheDocument() + expect(requiredIndicator).toBeInTheDocument() }) it('should render placeholder for excludes field', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByPlaceholderText('blog/*, /about/*'))!.toBeInTheDocument() + expect(screen.getByPlaceholderText('blog/*, /about/*')).toBeInTheDocument() }) it('should render placeholder for includes field', () => { const payload = createMockCrawlOptions() render() - expect(screen.getByPlaceholderText('articles/*'))!.toBeInTheDocument() + expect(screen.getByPlaceholderText('articles/*')).toBeInTheDocument() }) it('should render with custom className', () => { @@ -79,7 +79,7 @@ describe('Options (watercrawl)', () => { ) const rootElement = container.firstChild as HTMLElement - expect(rootElement)!.toHaveClass('custom-class') + expect(rootElement).toHaveClass('custom-class') }) }) @@ -89,7 +89,7 @@ describe('Options (watercrawl)', () => { const payload = createMockCrawlOptions({ crawl_sub_pages: true }) render() - expect(screen.getByTestId('check-icon-crawl-sub-pages'))!.toBeInTheDocument() + expect(screen.getByTestId('check-icon-crawl-sub-pages')).toBeInTheDocument() }) it('should display crawl_sub_pages checkbox without check icon when false', () => { @@ -97,13 +97,13 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - expect(checkboxes[0]!.querySelector('svg')).not.toBeInTheDocument() + expect(checkboxes[0].querySelector('svg')).not.toBeInTheDocument() }) it('should display only_main_content checkbox with check icon when true', () => { const payload = createMockCrawlOptions({ only_main_content: true }) render() - expect(screen.getByTestId('check-icon-only-main-content'))!.toBeInTheDocument() + expect(screen.getByTestId('check-icon-only-main-content')).toBeInTheDocument() }) it('should display only_main_content checkbox without check icon when false', () => { @@ -111,35 +111,35 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - expect(checkboxes[1]!.querySelector('svg')).not.toBeInTheDocument() + expect(checkboxes[1].querySelector('svg')).not.toBeInTheDocument() }) it('should display limit value in input', () => { const payload = createMockCrawlOptions({ limit: 25 }) render() - expect(screen.getByDisplayValue('25'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('25')).toBeInTheDocument() }) it('should display max_depth value in input', () => { const payload = createMockCrawlOptions({ max_depth: 5 }) render() - expect(screen.getByDisplayValue('5'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('5')).toBeInTheDocument() }) it('should display excludes value in input', () => { const payload = createMockCrawlOptions({ excludes: 'test/*' }) render() - expect(screen.getByDisplayValue('test/*'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('test/*')).toBeInTheDocument() }) it('should display includes value in input', () => { const payload = createMockCrawlOptions({ includes: 'docs/*' }) render() - expect(screen.getByDisplayValue('docs/*'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('docs/*')).toBeInTheDocument() }) }) @@ -149,7 +149,7 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[0]!) + fireEvent.click(checkboxes[0]) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -162,7 +162,7 @@ describe('Options (watercrawl)', () => { const { container } = render() const checkboxes = getCheckboxes(container) - fireEvent.click(checkboxes[1]!) + fireEvent.click(checkboxes[1]) expect(mockOnChange).toHaveBeenCalledWith({ ...payload, @@ -264,10 +264,10 @@ describe('Options (watercrawl)', () => { const payload2 = createMockCrawlOptions({ limit: 20 }) const { rerender } = render() - expect(screen.getByDisplayValue('10'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('10')).toBeInTheDocument() rerender() - expect(screen.getByDisplayValue('20'))!.toBeInTheDocument() + expect(screen.getByDisplayValue('20')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx index c89059d185..0d60ef86db 100644 --- a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx @@ -105,7 +105,7 @@ describe('Operations', () => { describe('rendering', () => { it('should render without crashing', () => { render() - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) it('should render buttons when embeddingAvailable', () => { @@ -122,7 +122,7 @@ describe('Operations', () => { it('should render disabled switch when embeddingAvailable is false in list scene', () => { render() const disabledSwitch = screen.getByRole('switch') - expect(disabledSwitch)!.toHaveAttribute('aria-disabled', 'true') + expect(disabledSwitch).toHaveAttribute('aria-disabled', 'true') }) }) @@ -209,7 +209,7 @@ describe('Operations', () => { const buttons = screen.getAllByRole('button') const settingsButton = buttons[0] await act(async () => { - fireEvent.click(settingsButton!) + fireEvent.click(settingsButton) }) expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1/settings') }) @@ -219,7 +219,7 @@ describe('Operations', () => { it('should render differently in detail scene', () => { render() const container = document.querySelector('.flex.items-center') - expect(container)!.toBeInTheDocument() + expect(container).toBeInTheDocument() }) it('should not render switch in detail scene', () => { @@ -239,7 +239,7 @@ describe('Operations', () => { onSelectedIdChange={mockOnSelectedIdChange} />, ) - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) }) @@ -257,8 +257,7 @@ describe('Operations', () => { render() await openPopover() // Check if popover content is visible - // Check if popover content is visible - expect(screen.getByText('datasetDocuments.list.table.rename'))!.toBeInTheDocument() + expect(screen.getByText('datasetDocuments.list.table.rename')).toBeInTheDocument() }) it('should call archive when archive action is clicked', async () => { @@ -298,8 +297,7 @@ describe('Operations', () => { fireEvent.click(deleteButton) }) // Check if confirmation modal is shown - // Check if confirmation modal is shown - expect(screen.getByText('datasetDocuments.list.delete.title'))!.toBeInTheDocument() + expect(screen.getByText('datasetDocuments.list.delete.title')).toBeInTheDocument() }) it('should call delete when confirm is clicked in delete modal', async () => { @@ -326,8 +324,7 @@ describe('Operations', () => { fireEvent.click(deleteButton) }) // Verify modal is shown - // Verify modal is shown - expect(screen.getByText('datasetDocuments.list.delete.title'))!.toBeInTheDocument() + expect(screen.getByText('datasetDocuments.list.delete.title')).toBeInTheDocument() // Find and click the cancel button const cancelButton = screen.getByText('common.operation.cancel') await act(async () => { @@ -369,7 +366,7 @@ describe('Operations', () => { await user.click(renameAction) const renameInput = await screen.findByRole('textbox') - expect(renameInput)!.toHaveValue('Test Document') + expect(renameInput).toHaveValue('Test Document') }) it('should call sync for notion data source', async () => { @@ -461,7 +458,7 @@ describe('Operations', () => { />, ) await openPopover() - expect(screen.getByText('datasetDocuments.list.action.download'))!.toBeInTheDocument() + expect(screen.getByText('datasetDocuments.list.action.download')).toBeInTheDocument() }) it('should download archived file when download is clicked', async () => { @@ -546,7 +543,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, display_status: 'indexing' }} />, ) - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) it('should render resume action when status is paused', () => { @@ -556,7 +553,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, display_status: 'paused' }} />, ) - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) it('should not show pause/resume for available status', async () => { @@ -585,7 +582,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, data_source_type: 'notion_import' }} />, ) - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) it('should handle web data source type', () => { @@ -595,7 +592,7 @@ describe('Operations', () => { detail={{ ...defaultDetail, data_source_type: 'website_crawl' }} />, ) - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) it('should not show download for non-file data source', async () => { @@ -625,7 +622,7 @@ describe('Operations', () => { it('should accept custom className prop', () => { // The className is passed to CustomPopover, verify component renders without errors render() - expect(document.querySelector('.flex.items-center'))!.toBeInTheDocument() + expect(document.querySelector('.flex.items-center')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx index 01d6299492..97ae1c92a1 100644 --- a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx @@ -112,26 +112,25 @@ describe('DocumentList', () => { describe('Rendering', () => { it('should render without crashing', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render all documents', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('Document 1.txt'))!.toBeInTheDocument() - expect(screen.getByText('Document 2.txt'))!.toBeInTheDocument() - expect(screen.getByText('Document 3.txt'))!.toBeInTheDocument() + expect(screen.getByText('Document 1.txt')).toBeInTheDocument() + expect(screen.getByText('Document 2.txt')).toBeInTheDocument() + expect(screen.getByText('Document 3.txt')).toBeInTheDocument() }) it('should render table headers', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('#'))!.toBeInTheDocument() + expect(screen.getByText('#')).toBeInTheDocument() }) it('should render pagination when total is provided', () => { render(, { wrapper: createWrapper() }) // Pagination component should be present - // Pagination component should be present - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should not render pagination when total is 0', () => { @@ -140,13 +139,13 @@ describe('DocumentList', () => { pagination: { ...defaultPagination, total: 0 }, } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render empty table when no documents', () => { const props = { ...defaultProps, documents: [] } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) }) @@ -166,8 +165,7 @@ describe('DocumentList', () => { const props = { ...defaultProps, embeddingAvailable: false } render(, { wrapper: createWrapper() }) // Row checkboxes should still be there, but header checkbox should be hidden - // Row checkboxes should still be there, but header checkbox should be hidden - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should call onSelectedIdChange when select all is clicked', () => { @@ -177,7 +175,7 @@ describe('DocumentList', () => { const checkboxes = findCheckboxes(container) if (checkboxes.length > 0) { - fireEvent.click(checkboxes[0]!) + fireEvent.click(checkboxes[0]) expect(onSelectedIdChange).toHaveBeenCalled() } }) @@ -192,7 +190,7 @@ describe('DocumentList', () => { // When checked, checkbox should have a check icon (svg) inside props.selectedIds.forEach((id) => { const checkIcon = screen.getByTestId(`check-icon-doc-row-${id}`) - expect(checkIcon)!.toBeInTheDocument() + expect(checkIcon).toBeInTheDocument() }) }) @@ -208,9 +206,7 @@ describe('DocumentList', () => { expect(checkboxes.length).toBeGreaterThan(0) // Header checkbox should show indeterminate icon, not check icon // Just verify it's rendered - // Header checkbox should show indeterminate icon, not check icon - // Just verify it's rendered - expect(checkboxes[0])!.toBeInTheDocument() + expect(checkboxes[0]).toBeInTheDocument() }) it('should call onSelectedIdChange with single document when row checkbox is clicked', () => { @@ -220,7 +216,7 @@ describe('DocumentList', () => { const checkboxes = findCheckboxes(container) if (checkboxes.length > 1) { - fireEvent.click(checkboxes[1]!) + fireEvent.click(checkboxes[1]) expect(onSelectedIdChange).toHaveBeenCalled() } }) @@ -240,7 +236,7 @@ describe('DocumentList', () => { const sortableHeaders = container.querySelectorAll('thead button') if (sortableHeaders.length > 0) - fireEvent.click(sortableHeaders[0]!) + fireEvent.click(sortableHeaders[0]) expect(onSortChange).toHaveBeenCalled() }) @@ -255,16 +251,14 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction component should be visible - // BatchAction component should be visible - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should not show batch action bar when no documents selected', () => { render(, { wrapper: createWrapper() }) // BatchAction should not be present - // BatchAction should not be present - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render batch action bar with archive option', () => { @@ -275,8 +269,7 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction component should be visible when documents are selected - // BatchAction component should be visible when documents are selected - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render batch action bar with enable option', () => { @@ -286,7 +279,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render batch action bar with disable option', () => { @@ -296,7 +289,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render batch action bar with delete option', () => { @@ -306,7 +299,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should clear selection when cancel is clicked', () => { @@ -336,8 +329,7 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction should be visible - // BatchAction should be visible - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should show re-index option for error documents', () => { @@ -351,8 +343,7 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // BatchAction with re-index should be present for error documents - // BatchAction with re-index should be present for error documents - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) }) @@ -363,7 +354,7 @@ describe('DocumentList', () => { const rows = screen.getAllByRole('row') // First row is header, second row is first document if (rows.length > 1) { - fireEvent.click(rows[1]!) + fireEvent.click(rows[1]) expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') } }) @@ -385,11 +376,11 @@ describe('DocumentList', () => { const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') if (renameButtons.length > 0) { await act(async () => { - fireEvent.click(renameButtons[0]!) + fireEvent.click(renameButtons[0]) }) } - expect(screen.getByRole('dialog', { name: 'datasetDocuments.list.table.rename' }))!.toBeInTheDocument() + expect(screen.getByRole('dialog', { name: 'datasetDocuments.list.table.rename' })).toBeInTheDocument() }) it('should call onUpdate when document is renamed', () => { @@ -398,8 +389,7 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // The handleRenamed callback wraps onUpdate - // The handleRenamed callback wraps onUpdate - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) }) @@ -418,7 +408,7 @@ describe('DocumentList', () => { }) } - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should call onManageMetadata when manage metadata is triggered', () => { @@ -431,27 +421,26 @@ describe('DocumentList', () => { render(, { wrapper: createWrapper() }) // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata - // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) }) describe('Chunking Mode', () => { it('should render with general mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render with QA mode', () => { // This test uses the default mock which returns ChunkingMode.text // The component will compute isQAMode based on doc_form render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should render with parent-child mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) }) @@ -460,7 +449,7 @@ describe('DocumentList', () => { const props = { ...defaultProps, documents: [] } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should handle documents with missing optional fields', () => { @@ -474,7 +463,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should handle remote sort value', () => { @@ -484,7 +473,7 @@ describe('DocumentList', () => { } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }) it('should handle large number of documents', () => { @@ -493,7 +482,7 @@ describe('DocumentList', () => { const props = { ...defaultProps, documents: manyDocs } render(, { wrapper: createWrapper() }) - expect(screen.getByRole('table'))!.toBeInTheDocument() + expect(screen.getByRole('table')).toBeInTheDocument() }, 10000) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx index b6b02ed829..d5e4f480be 100644 --- a/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/__tests__/document-table-row.spec.tsx @@ -103,23 +103,23 @@ describe('DocumentTableRow', () => { describe('Rendering', () => { it('should render without crashing', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('test-document.txt'))!.toBeInTheDocument() + expect(screen.getByText('test-document.txt')).toBeInTheDocument() }) it('should render index number correctly', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('6'))!.toBeInTheDocument() + expect(screen.getByText('6')).toBeInTheDocument() }) it('should render document name with tooltip', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByText('test-document.txt'))!.toBeInTheDocument() + expect(screen.getByText('test-document.txt')).toBeInTheDocument() }) it('should render checkbox element', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox)!.toBeInTheDocument() + expect(checkbox).toBeInTheDocument() }) }) @@ -127,14 +127,14 @@ describe('DocumentTableRow', () => { it('should show check icon when isSelected is true', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox)!.toBeInTheDocument() - expect(screen.getByTestId('check-icon-doc-row-doc-1'))!.toBeInTheDocument() + expect(checkbox).toBeInTheDocument() + expect(screen.getByTestId('check-icon-doc-row-doc-1')).toBeInTheDocument() }) it('should not show check icon when isSelected is false', () => { const { container } = render(, { wrapper: createWrapper() }) const checkbox = findCheckbox(container) - expect(checkbox)!.toBeInTheDocument() + expect(checkbox).toBeInTheDocument() expect(screen.queryByTestId('check-icon-doc-row-doc-1')).not.toBeInTheDocument() }) @@ -200,13 +200,13 @@ describe('DocumentTableRow', () => { it('should display word count less than 1000 as is', () => { const doc = createMockDoc({ word_count: 500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('500'))!.toBeInTheDocument() + expect(screen.getByText('500')).toBeInTheDocument() }) it('should display word count 1000 or more in k format', () => { const doc = createMockDoc({ word_count: 1500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('1.5k'))!.toBeInTheDocument() + expect(screen.getByText('1.5k')).toBeInTheDocument() }) it('should display 0 with empty style when word_count is 0', () => { @@ -219,7 +219,7 @@ describe('DocumentTableRow', () => { it('should handle undefined word_count', () => { const doc = createMockDoc({ word_count: undefined as unknown as number }) const { container } = render(, { wrapper: createWrapper() }) - expect(container)!.toBeInTheDocument() + expect(container).toBeInTheDocument() }) }) @@ -227,13 +227,13 @@ describe('DocumentTableRow', () => { it('should display hit count less than 1000 as is', () => { const doc = createMockDoc({ hit_count: 100 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('100'))!.toBeInTheDocument() + expect(screen.getByText('100')).toBeInTheDocument() }) it('should display hit count 1000 or more in k format', () => { const doc = createMockDoc({ hit_count: 2500 }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('2.5k'))!.toBeInTheDocument() + expect(screen.getByText('2.5k')).toBeInTheDocument() }) it('should display 0 with empty style when hit_count is 0', () => { @@ -248,13 +248,12 @@ describe('DocumentTableRow', () => { it('should render ChunkingModeLabel with general mode', () => { render(, { wrapper: createWrapper() }) // ChunkingModeLabel should be rendered - // ChunkingModeLabel should be rendered - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should render ChunkingModeLabel with QA mode', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) }) @@ -262,13 +261,13 @@ describe('DocumentTableRow', () => { it('should render SummaryStatus when summary_index_status is present', () => { const doc = createMockDoc({ summary_index_status: 'completed' }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should not render SummaryStatus when summary_index_status is absent', () => { const doc = createMockDoc({ summary_index_status: undefined }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) }) @@ -283,7 +282,7 @@ describe('DocumentTableRow', () => { // Find the rename button by finding the RiEditLine icon's parent const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') if (renameButtons.length > 0) { - fireEvent.click(renameButtons[0]!) + fireEvent.click(renameButtons[0]) expect(onShowRenameModal).toHaveBeenCalledWith(defaultProps.doc) expect(mockPush).not.toHaveBeenCalled() } @@ -293,13 +292,13 @@ describe('DocumentTableRow', () => { describe('Operations', () => { it('should pass selectedIds to Operations component', () => { render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should pass onSelectedIdChange to Operations component', () => { const onSelectedIdChange = vi.fn() render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) }) @@ -307,7 +306,7 @@ describe('DocumentTableRow', () => { it('should render with FILE data source type', () => { const doc = createMockDoc({ data_source_type: DataSourceType.FILE }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should render with NOTION data source type', () => { @@ -316,13 +315,13 @@ describe('DocumentTableRow', () => { data_source_info: { notion_page_icon: 'icon.png' }, }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should render with WEB data source type', () => { const doc = createMockDoc({ data_source_type: DataSourceType.WEB }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) }) @@ -330,13 +329,13 @@ describe('DocumentTableRow', () => { it('should handle document with very long name', () => { const doc = createMockDoc({ name: `${'a'.repeat(500)}.txt` }) render(, { wrapper: createWrapper() }) - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) it('should handle document with special characters in name', () => { const doc = createMockDoc({ name: '.txt' }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('.txt'))!.toBeInTheDocument() + expect(screen.getByText('.txt')).toBeInTheDocument() }) it('should memoize the component', () => { @@ -344,7 +343,7 @@ describe('DocumentTableRow', () => { const { rerender } = render(, { wrapper }) rerender() - expect(screen.getByRole('row'))!.toBeInTheDocument() + expect(screen.getByRole('row')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx index c8a06ea807..899c70e216 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx @@ -43,7 +43,7 @@ const Options = ({ if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` + const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx index 7fde02adcd..33703d56b2 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/form.tsx @@ -33,7 +33,7 @@ const Form = ({ if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` + const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx b/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx index d9427f5117..25b7abe7ea 100644 --- a/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/components/query-input/__tests__/index.spec.tsx @@ -79,17 +79,17 @@ describe('QueryInput', () => { it('should render title', () => { render() - expect(screen.getByText('datasetHitTesting.input.title'))!.toBeInTheDocument() + expect(screen.getByText('datasetHitTesting.input.title')).toBeInTheDocument() }) it('should render textarea with query text', () => { render() - expect(screen.getByTestId('textarea'))!.toBeInTheDocument() + expect(screen.getByTestId('textarea')).toBeInTheDocument() }) it('should render submit button', () => { render() - expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeInTheDocument() + expect(screen.getByRole('button', { name: /input\.testing/ })).toBeInTheDocument() }) it('should disable submit button when text is empty', () => { @@ -98,17 +98,17 @@ describe('QueryInput', () => { queries: [{ content: '', content_type: 'text_query', file_info: null }] satisfies Query[], } render() - expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeDisabled() + expect(screen.getByRole('button', { name: /input\.testing/ })).toBeDisabled() }) it('should render retrieval method for non-external mode', () => { render() - expect(screen.getByText('dataset.retrieval.semantic_search.title'))!.toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() }) it('should render settings button for external mode', () => { render() - expect(screen.getByText('datasetHitTesting.settingTitle'))!.toBeInTheDocument() + expect(screen.getByText('datasetHitTesting.settingTitle')).toBeInTheDocument() }) it('should disable submit button when text exceeds 200 characters', () => { @@ -117,15 +117,15 @@ describe('QueryInput', () => { queries: [{ content: 'a'.repeat(201), content_type: 'text_query', file_info: null }] satisfies Query[], } render() - expect(screen.getByRole('button', { name: /input\.testing/ }))!.toBeDisabled() + expect(screen.getByRole('button', { name: /input\.testing/ })).toBeDisabled() }) it('should show loading state on submit button when loading', () => { render() const submitButton = screen.getByRole('button', { name: /input\.testing/ }) - expect(submitButton)!.toBeDisabled() - expect(submitButton)!.toHaveAttribute('aria-busy', 'true') - expect(submitButton.querySelector('.animate-spin'))!.toBeInTheDocument() + expect(submitButton).toBeDisabled() + expect(submitButton).toHaveAttribute('aria-busy', 'true') + expect(submitButton.querySelector('.animate-spin')).toBeInTheDocument() }) // Cover line 83: images useMemo with image_query data @@ -141,37 +141,6 @@ describe('QueryInput', () => { ] render() - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image - // Submit should be enabled since we have text + uploaded image // Submit should be enabled since we have text + uploaded image expect(screen.getByRole('button', { name: /input\.testing/ })).not.toBeDisabled() }) @@ -184,7 +153,7 @@ describe('QueryInput', () => { // Click settings button to open modal fireEvent.click(screen.getByRole('button', { name: /settingTitle/ })) - expect(screen.getByTestId('external-retrieval-modal'))!.toBeInTheDocument() + expect(screen.getByTestId('external-retrieval-modal')).toBeInTheDocument() // Close modal fireEvent.click(screen.getByTestId('modal-close')) @@ -196,7 +165,7 @@ describe('QueryInput', () => { // Open modal fireEvent.click(screen.getByRole('button', { name: /settingTitle/ })) - expect(screen.getByTestId('external-retrieval-modal'))!.toBeInTheDocument() + expect(screen.getByTestId('external-retrieval-modal')).toBeInTheDocument() // Save settings fireEvent.click(screen.getByTestId('modal-save')) @@ -305,7 +274,7 @@ describe('QueryInput', () => { ]), ) // Should not contain image_query - const calledWith = defaultProps.setQueries.mock.calls[0]![0] as Query[] + const calledWith = defaultProps.setQueries.mock.calls[0][0] as Query[] expect(calledWith.filter(q => q.content_type === 'image_query')).toHaveLength(0) }) }) @@ -443,7 +412,7 @@ describe('QueryInput', () => { it('should show keyword_search when isEconomy is true', () => { render() - expect(screen.getByText('dataset.retrieval.keyword_search.title'))!.toBeInTheDocument() + expect(screen.getByText('dataset.retrieval.keyword_search.title')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx b/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx index 40c925222c..d9b88e20bb 100644 --- a/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx +++ b/web/app/components/datasets/metadata/edit-metadata-batch/__tests__/modal.spec.tsx @@ -120,14 +120,14 @@ describe('EditMetadataBatchModal', () => { it('should render without crashing', async () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) it('should render document count', async () => { render() await waitFor(() => { - expect(screen.getByText(/5/))!.toBeInTheDocument() + expect(screen.getByText(/5/)).toBeInTheDocument() }) }) @@ -142,8 +142,8 @@ describe('EditMetadataBatchModal', () => { it('should render field names for existing items', async () => { render() await waitFor(() => { - expect(screen.getByText('field_one'))!.toBeInTheDocument() - expect(screen.getByText('field_two'))!.toBeInTheDocument() + expect(screen.getByText('field_one')).toBeInTheDocument() + expect(screen.getByText('field_two')).toBeInTheDocument() }) }) @@ -158,7 +158,7 @@ describe('EditMetadataBatchModal', () => { it('should render select metadata modal', async () => { render() await waitFor(() => { - expect(screen.getByTestId('select-modal'))!.toBeInTheDocument() + expect(screen.getByTestId('select-modal')).toBeInTheDocument() }) }) }) @@ -169,7 +169,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) const cancelButton = screen.getByText(/cancel/i) @@ -183,7 +183,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // Find the primary save button (not the one in SelectMetadataModal) @@ -196,17 +196,17 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) const checkboxContainer = document.querySelector('[data-testid*="checkbox"]') - expect(checkboxContainer)!.toBeInTheDocument() + expect(checkboxContainer).toBeInTheDocument() if (checkboxContainer) { fireEvent.click(checkboxContainer) await waitFor(() => { const checkIcon = screen.getByTestId('check-icon-apply-to-all') - expect(checkIcon)!.toBeInTheDocument() + expect(checkIcon).toBeInTheDocument() }) } }) @@ -216,7 +216,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) }) @@ -226,7 +226,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('change-1')) @@ -239,7 +239,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('remove-1')) @@ -252,7 +252,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // First change the item @@ -269,14 +269,14 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('select-metadata')) // Should now have add-row for the new item await waitFor(() => { - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) }) @@ -284,14 +284,14 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // First add an item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) // Then remove it @@ -306,20 +306,20 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // First add an item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) // Then change it fireEvent.click(screen.getByTestId('add-change-new-1')) - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) it('should call doAddMetaData when saving new metadata with valid name', async () => { @@ -328,7 +328,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -344,7 +344,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -368,7 +368,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('save-metadata')) @@ -388,7 +388,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByTestId('manage-metadata')) @@ -401,14 +401,14 @@ describe('EditMetadataBatchModal', () => { it('should pass correct datasetId', async () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) it('should display correct document number', async () => { render() await waitFor(() => { - expect(screen.getByText(/10/))!.toBeInTheDocument() + expect(screen.getByText(/10/)).toBeInTheDocument() }) }) @@ -427,7 +427,7 @@ describe('EditMetadataBatchModal', () => { ] render() await waitFor(() => { - expect(screen.getByTestId('edit-row'))!.toBeInTheDocument() + expect(screen.getByTestId('edit-row')).toBeInTheDocument() }) }) @@ -436,7 +436,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // Find the primary save button @@ -453,7 +453,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) @@ -470,7 +470,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) const checkboxContainer = document.querySelector('[data-testid*="checkbox"]') @@ -493,7 +493,7 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // Remove an item @@ -503,7 +503,7 @@ describe('EditMetadataBatchModal', () => { expect(onSave).toHaveBeenCalled() // The first argument should not contain the deleted item (id '1') - const savedList = onSave.mock.calls[0]![0] as MetadataItemInBatchEdit[] + const savedList = onSave.mock.calls[0][0] as MetadataItemInBatchEdit[] const hasDeletedItem = savedList.some(item => item.id === '1') expect(hasDeletedItem).toBe(false) }) @@ -512,13 +512,13 @@ describe('EditMetadataBatchModal', () => { render() await waitFor(() => { - expect(screen.getByRole('dialog'))!.toBeInTheDocument() + expect(screen.getByRole('dialog')).toBeInTheDocument() }) // Add first item fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) // Remove it @@ -531,7 +531,7 @@ describe('EditMetadataBatchModal', () => { // Add again fireEvent.click(screen.getByTestId('select-metadata')) await waitFor(() => { - expect(screen.getByTestId('add-row'))!.toBeInTheDocument() + expect(screen.getByTestId('add-row')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx index 71f23324f6..ddd624a076 100644 --- a/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx +++ b/web/app/components/datasets/metadata/metadata-document/__tests__/index.spec.tsx @@ -99,7 +99,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild)!.toBeInTheDocument() + expect(container.firstChild).toBeInTheDocument() }) it('should render metadata fields when hasData is true', () => { @@ -110,8 +110,8 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(screen.getByText('field_one'))!.toBeInTheDocument() - expect(screen.getByText('field_two'))!.toBeInTheDocument() + expect(screen.getByText('field_one')).toBeInTheDocument() + expect(screen.getByText('field_two')).toBeInTheDocument() }) it('should render no-data state when hasData is false and not in edit mode', () => { @@ -147,8 +147,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/save/i))!.toBeInTheDocument() - expect(screen.getByText(/cancel/i))!.toBeInTheDocument() + expect(screen.getByText(/save/i)).toBeInTheDocument() + expect(screen.getByText(/cancel/i)).toBeInTheDocument() }) it('should render built-in section when builtInEnabled is true', () => { @@ -166,7 +166,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('created_at'))!.toBeInTheDocument() + expect(screen.getByText('created_at')).toBeInTheDocument() }) it('should render divider when builtInEnabled is true', () => { @@ -185,7 +185,7 @@ describe('MetadataDocument', () => { ) const divider = container.querySelector('.bg-linear-to-r') - expect(divider)!.toBeInTheDocument() + expect(divider).toBeInTheDocument() }) it('should render origin info section', () => { @@ -202,7 +202,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('source'))!.toBeInTheDocument() + expect(screen.getByText('source')).toBeInTheDocument() }) it('should render technical parameters section', () => { @@ -219,7 +219,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('word_count'))!.toBeInTheDocument() + expect(screen.getByText('word_count')).toBeInTheDocument() }) it('should render all sections together', () => { @@ -239,10 +239,10 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('field_one'))!.toBeInTheDocument() - expect(screen.getByText('created_at'))!.toBeInTheDocument() - expect(screen.getByText('source'))!.toBeInTheDocument() - expect(screen.getByText('word_count'))!.toBeInTheDocument() + expect(screen.getByText('field_one')).toBeInTheDocument() + expect(screen.getByText('created_at')).toBeInTheDocument() + expect(screen.getByText('source')).toBeInTheDocument() + expect(screen.getByText('word_count')).toBeInTheDocument() }) }) @@ -255,7 +255,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(screen.getByText(/edit/i))!.toBeInTheDocument() + expect(screen.getByText(/edit/i)).toBeInTheDocument() }) it('should call startToEdit when edit button is clicked', () => { @@ -362,9 +362,8 @@ describe('MetadataDocument', () => { ) // Should show save/cancel buttons - // Should show save/cancel buttons - expect(screen.getByText(/save/i))!.toBeInTheDocument() - expect(screen.getByText(/cancel/i))!.toBeInTheDocument() + expect(screen.getByText(/save/i)).toBeInTheDocument() + expect(screen.getByText(/cancel/i)).toBeInTheDocument() }) }) @@ -387,7 +386,7 @@ describe('MetadataDocument', () => { const inputs = container.querySelectorAll('input') if (inputs.length > 0) { - fireEvent.change(inputs[0]!, { target: { value: 'new value' } }) + fireEvent.change(inputs[0], { target: { value: 'new value' } }) await waitFor(() => { expect(setTempList).toHaveBeenCalled() @@ -455,7 +454,7 @@ describe('MetadataDocument', () => { const inputs = container.querySelectorAll('input') if (inputs.length > 0) { - fireEvent.change(inputs[0]!, { target: { value: 'updated' } }) + fireEvent.change(inputs[0], { target: { value: 'updated' } }) await waitFor(() => { expect(setTempList).toHaveBeenCalled() }) @@ -484,7 +483,7 @@ describe('MetadataDocument', () => { expect(deleteContainers.length).toBeGreaterThan(0) if (deleteContainers.length > 0) { - const deleteIcon = deleteContainers[0]!.querySelector('svg') + const deleteIcon = deleteContainers[0].querySelector('svg') if (deleteIcon) fireEvent.click(deleteIcon) @@ -505,7 +504,7 @@ describe('MetadataDocument', () => { className="custom-class" />, ) - expect(container.firstChild)!.toHaveClass('custom-class') + expect(container.firstChild).toHaveClass('custom-class') }) it('should use tempList when in edit mode', () => { @@ -525,7 +524,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('temp_field'))!.toBeInTheDocument() + expect(screen.getByText('temp_field')).toBeInTheDocument() }) it('should use list when not in edit mode', () => { @@ -537,8 +536,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('field_one'))!.toBeInTheDocument() - expect(screen.getByText('field_two'))!.toBeInTheDocument() + expect(screen.getByText('field_one')).toBeInTheDocument() + expect(screen.getByText('field_two')).toBeInTheDocument() }) it('should pass datasetId to child components', () => { @@ -550,8 +549,7 @@ describe('MetadataDocument', () => { />, ) // Component should render without errors - // Component should render without errors - expect(screen.getByText('field_one'))!.toBeInTheDocument() + expect(screen.getByText('field_one')).toBeInTheDocument() }) }) @@ -590,37 +588,6 @@ describe('MetadataDocument', () => { />, ) - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered - // NoData component should not be rendered // NoData component should not be rendered expect(screen.queryByText(/start/i)).not.toBeInTheDocument() }) @@ -640,37 +607,6 @@ describe('MetadataDocument', () => { />, ) - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined - // headerRight should be null/undefined // headerRight should be null/undefined expect(screen.queryByText(/^edit$/i)).not.toBeInTheDocument() }) @@ -692,7 +628,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild)!.toBeInTheDocument() + expect(container.firstChild).toBeInTheDocument() }) it('should render correctly with minimal props', () => { @@ -703,7 +639,7 @@ describe('MetadataDocument', () => { docDetail={mockDocDetail as Parameters[0]['docDetail']} />, ) - expect(container.firstChild)!.toBeInTheDocument() + expect(container.firstChild).toBeInTheDocument() }) it('should handle switching between view and edit mode', () => { @@ -715,7 +651,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/edit/i))!.toBeInTheDocument() + expect(screen.getByText(/edit/i)).toBeInTheDocument() unmount() @@ -732,8 +668,8 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText(/save/i))!.toBeInTheDocument() - expect(screen.getByText(/cancel/i))!.toBeInTheDocument() + expect(screen.getByText(/save/i)).toBeInTheDocument() + expect(screen.getByText(/cancel/i)).toBeInTheDocument() }) it('should handle multiple items in all sections', () => { @@ -766,11 +702,11 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('user_field_1'))!.toBeInTheDocument() - expect(screen.getByText('user_field_2'))!.toBeInTheDocument() - expect(screen.getByText('created_at'))!.toBeInTheDocument() - expect(screen.getByText('source'))!.toBeInTheDocument() - expect(screen.getByText('word_count'))!.toBeInTheDocument() + expect(screen.getByText('user_field_1')).toBeInTheDocument() + expect(screen.getByText('user_field_2')).toBeInTheDocument() + expect(screen.getByText('created_at')).toBeInTheDocument() + expect(screen.getByText('source')).toBeInTheDocument() + expect(screen.getByText('word_count')).toBeInTheDocument() }) it('should handle null values in metadata', () => { @@ -789,7 +725,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('null_field'))!.toBeInTheDocument() + expect(screen.getByText('null_field')).toBeInTheDocument() }) it('should handle undefined values in metadata', () => { @@ -808,7 +744,7 @@ describe('MetadataDocument', () => { />, ) - expect(screen.getByText('undefined_field'))!.toBeInTheDocument() + expect(screen.getByText('undefined_field')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx index 5e81611fc4..7441274155 100644 --- a/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx +++ b/web/app/components/datasets/settings/index-method/__tests__/index.spec.tsx @@ -19,12 +19,12 @@ describe('IndexMethod', () => { describe('Rendering', () => { it('should render without crashing', () => { render() - expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() }) it('should render High Quality option', () => { render() - expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() }) it('should render Economy option', () => { @@ -34,17 +34,17 @@ describe('IndexMethod', () => { it('should render High Quality description', () => { render() - expect(screen.getByText(/form\.indexMethodHighQualityTip/))!.toBeInTheDocument() + expect(screen.getByText(/form\.indexMethodHighQualityTip/)).toBeInTheDocument() }) it('should render Economy description', () => { render() - expect(screen.getByText(/form\.indexMethodEconomyTip/))!.toBeInTheDocument() + expect(screen.getByText(/form\.indexMethodEconomyTip/)).toBeInTheDocument() }) it('should render recommended badge on High Quality', () => { render() - expect(screen.getByText(/stepTwo\.recommend/))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.recommend/)).toBeInTheDocument() }) }) @@ -82,7 +82,7 @@ describe('IndexMethod', () => { // Find and click Economy option - use getAllByText and get the first one (title) const economyTitles = screen.getAllByText(/form\.indexMethodEconomy/) const economyTitle = economyTitles[0] - const card = economyTitle!.closest('div')?.parentElement?.parentElement?.parentElement + const card = economyTitle.closest('div')?.parentElement?.parentElement?.parentElement fireEvent.click(card!) expect(handleChange).toHaveBeenCalledWith(IndexingType.ECONOMICAL) @@ -114,7 +114,7 @@ describe('IndexMethod', () => { // Try to click Economy option - use getAllByText and get the first one (title) const economyTitles = screen.getAllByText(/form\.indexMethodEconomy/) const economyTitle = economyTitles[0] - const card = economyTitle!.closest('div')?.parentElement?.parentElement?.parentElement + const card = economyTitle.closest('div')?.parentElement?.parentElement?.parentElement fireEvent.click(card!) // Should not call onChange because Economy is disabled when current is QUALIFIED @@ -125,13 +125,13 @@ describe('IndexMethod', () => { describe('KeywordNumber', () => { it('should render KeywordNumber component inside Economy option', () => { render() - expect(getKeywordSlider())!.toBeInTheDocument() + expect(getKeywordSlider()).toBeInTheDocument() }) it('should pass keywordNumber to KeywordNumber component', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toHaveValue('25') + expect(input).toHaveValue('25') }) it('should call onKeywordNumberChange when KeywordNumber changes', () => { @@ -160,13 +160,13 @@ describe('IndexMethod', () => { it('should show orange effect color for High Quality option', () => { const { container } = render() const orangeEffect = container.querySelector('.bg-util-colors-orange-orange-500') - expect(orangeEffect)!.toBeInTheDocument() + expect(orangeEffect).toBeInTheDocument() }) it('should show indigo effect color for Economy option', () => { const { container } = render() const indigoEffect = container.querySelector('.bg-util-colors-indigo-indigo-600') - expect(indigoEffect)!.toBeInTheDocument() + expect(indigoEffect).toBeInTheDocument() }) }) @@ -188,20 +188,19 @@ describe('IndexMethod', () => { it('should handle undefined currentValue', () => { render() // Should render without error - // Should render without error - expect(screen.getByText(/stepTwo\.qualified/))!.toBeInTheDocument() + expect(screen.getByText(/stepTwo\.qualified/)).toBeInTheDocument() }) it('should handle minimum keywordNumber', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toHaveValue('0') + expect(input).toHaveValue('0') }) it('should handle max keywordNumber', () => { render() const input = screen.getByRole('textbox') - expect(input)!.toHaveValue('50') + expect(input).toHaveValue('50') }) }) }) diff --git a/web/app/components/datasets/settings/index-method/keyword-number.tsx b/web/app/components/datasets/settings/index-method/keyword-number.tsx index 03992fb027..feb63c1d65 100644 --- a/web/app/components/datasets/settings/index-method/keyword-number.tsx +++ b/web/app/components/datasets/settings/index-method/keyword-number.tsx @@ -33,7 +33,7 @@ const KeyWordNumber = ({ return (
-
+
{t('form.numberOfKeywords', { ns: 'datasetSettings' })}
- - + + - - + + diff --git a/web/app/components/datasets/settings/permission-selector/index.tsx b/web/app/components/datasets/settings/permission-selector/index.tsx index 8c31799add..a7182d8f79 100644 --- a/web/app/components/datasets/settings/permission-selector/index.tsx +++ b/web/app/components/datasets/settings/permission-selector/index.tsx @@ -133,8 +133,8 @@ const PermissionSelector = ({ { selectedMembers.length === 1 && ( ) @@ -143,14 +143,14 @@ const PermissionSelector = ({ selectedMembers.length >= 2 && ( <> diff --git a/web/app/components/evaluation/__tests__/default-metric-descriptions.spec.ts b/web/app/components/evaluation/__tests__/default-metric-descriptions.spec.ts new file mode 100644 index 0000000000..9671220334 --- /dev/null +++ b/web/app/components/evaluation/__tests__/default-metric-descriptions.spec.ts @@ -0,0 +1,34 @@ +import { getDefaultMetricDescription, getDefaultMetricDescriptionI18nKey, getTranslatedMetricDescription } from '../default-metric-descriptions' + +describe('default metric descriptions', () => { + it('should resolve descriptions for kebab-case metric ids', () => { + expect(getDefaultMetricDescription('context-precision')).toContain('retrieval pipeline returns little noise') + expect(getDefaultMetricDescription('answer-correctness')).toContain('factual accuracy and completeness') + }) + + it('should normalize snake_case metric ids from backend payloads', () => { + expect(getDefaultMetricDescription('CONTEXT_RECALL')).toContain('does not miss important supporting evidence') + expect(getDefaultMetricDescription('TOOL_CORRECTNESS')).toContain('tool-use strategy matches the expected behavior') + }) + + it('should support the legacy relevance alias', () => { + expect(getDefaultMetricDescription('relevance')).toContain('addresses the user\'s question') + }) + + it('should resolve i18n keys for builtin metrics', () => { + expect(getDefaultMetricDescriptionI18nKey('context-precision')).toBe('metrics.builtin.description.contextPrecision') + expect(getDefaultMetricDescriptionI18nKey('ANSWER_RELEVANCY')).toBe('metrics.builtin.description.answerRelevancy') + }) + + it('should use translated content when translation key exists', () => { + const t = vi.fn((key: string, options?: { defaultValue?: string }) => { + if (key === 'metrics.builtin.description.faithfulness') + return '忠实性中文文案' + + return options?.defaultValue ?? key + }) + + expect(getTranslatedMetricDescription(t as never, 'faithfulness')).toBe('忠实性中文文案') + expect(getTranslatedMetricDescription(t as never, 'latency', 'Latency fallback')).toBe('Latency fallback') + }) +}) diff --git a/web/app/components/evaluation/__tests__/index.spec.tsx b/web/app/components/evaluation/__tests__/index.spec.tsx new file mode 100644 index 0000000000..4f8db5a32d --- /dev/null +++ b/web/app/components/evaluation/__tests__/index.spec.tsx @@ -0,0 +1,519 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import Evaluation from '..' +import ConditionsSection from '../components/conditions-section' +import { useEvaluationStore } from '../store' + +const mockUpload = vi.hoisted(() => vi.fn()) +const mockUseAvailableEvaluationMetrics = vi.hoisted(() => vi.fn()) +const mockUseEvaluationConfig = vi.hoisted(() => vi.fn()) +const mockUseEvaluationNodeInfoMutation = vi.hoisted(() => vi.fn()) +const mockUseSaveEvaluationConfigMutation = vi.hoisted(() => vi.fn()) +const mockUseStartEvaluationRunMutation = vi.hoisted(() => vi.fn()) +const mockUsePublishedPipelineInfo = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: () => ({ + data: [{ + provider: 'openai', + models: [{ model: 'gpt-4o-mini' }], + }], + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + default: ({ + defaultModel, + onSelect, + }: { + defaultModel?: { provider: string, model: string } + onSelect: (model: { provider: string, model: string }) => void + }) => ( +
+
+ {defaultModel ? `${defaultModel.provider}:${defaultModel.model}` : 'empty'} +
+ +
+ ), +})) + +vi.mock('@/service/base', () => ({ + upload: (...args: unknown[]) => mockUpload(...args), +})) + +vi.mock('@/service/use-evaluation', () => ({ + useEvaluationConfig: (...args: unknown[]) => mockUseEvaluationConfig(...args), + useAvailableEvaluationMetrics: (...args: unknown[]) => mockUseAvailableEvaluationMetrics(...args), + useEvaluationNodeInfoMutation: (...args: unknown[]) => mockUseEvaluationNodeInfoMutation(...args), + useSaveEvaluationConfigMutation: (...args: unknown[]) => mockUseSaveEvaluationConfigMutation(...args), + useStartEvaluationRunMutation: (...args: unknown[]) => mockUseStartEvaluationRunMutation(...args), +})) + +vi.mock('@/service/use-pipeline', () => ({ + usePublishedPipelineInfo: (...args: unknown[]) => mockUsePublishedPipelineInfo(...args), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: { pipeline_id: string } }) => unknown) => + selector({ dataset: { pipeline_id: 'pipeline-1' } }), +})) + +vi.mock('@/service/use-workflow', () => ({ + useAppWorkflow: () => ({ + data: { + graph: { + nodes: [{ + id: 'start', + data: { + type: 'start', + variables: [{ + variable: 'query', + type: 'text-input', + }], + }, + }], + }, + }, + isLoading: false, + }), +})) + +vi.mock('@/service/use-snippet-workflows', () => ({ + useSnippetPublishedWorkflow: () => ({ + data: { + graph: { + nodes: [{ + id: 'start', + data: { + type: 'start', + variables: [{ + variable: 'query', + type: 'text-input', + }], + }, + }], + }, + }, + isLoading: false, + }), +})) + +const renderWithQueryClient = (ui: ReactNode) => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + mutations: { + retry: false, + }, + }, + }) + + return render(ui, { + wrapper: ({ children }: { children: ReactNode }) => ( + + {children} + + ), + }) +} + +describe('Evaluation', () => { + beforeEach(() => { + useEvaluationStore.setState({ resources: {} }) + vi.clearAllMocks() + mockUseEvaluationConfig.mockReturnValue({ + data: null, + }) + + mockUseAvailableEvaluationMetrics.mockReturnValue({ + data: { + metrics: ['answer-correctness', 'faithfulness', 'context-precision', 'context-recall', 'context-relevance'], + }, + isLoading: false, + }) + + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'answer-correctness': [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ], + 'faithfulness': [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ], + }) + }, + }) + mockUseSaveEvaluationConfigMutation.mockReturnValue({ + isPending: false, + mutate: vi.fn(), + }) + mockUseStartEvaluationRunMutation.mockReturnValue({ + isPending: false, + mutate: vi.fn(), + }) + mockUsePublishedPipelineInfo.mockReturnValue({ + data: { + graph: { + nodes: [{ + id: 'knowledge-node', + data: { + type: 'knowledge-index', + title: 'Knowledge Base', + }, + }], + edges: [], + }, + }, + }) + mockUpload.mockResolvedValue({ + id: 'uploaded-file-id', + name: 'evaluation.csv', + }) + }) + + it('should search, select metric nodes, and save evaluation config', () => { + const saveConfig = vi.fn() + mockUseSaveEvaluationConfigMutation.mockReturnValue({ + isPending: false, + mutate: saveConfig, + }) + + renderWithQueryClient() + + expect(screen.getByTestId('evaluation-model-selector')).toHaveTextContent('openai:gpt-4o-mini') + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.add' })) + + fireEvent.change(screen.getByPlaceholderText('evaluation.metrics.searchNodeOrMetrics'), { + target: { value: 'does-not-exist' }, + }) + + expect(screen.getByText('evaluation.metrics.noResults')).toBeInTheDocument() + + fireEvent.change(screen.getByPlaceholderText('evaluation.metrics.searchNodeOrMetrics'), { + target: { value: 'faith' }, + }) + + fireEvent.click(screen.getByTestId('evaluation-metric-node-faithfulness-node-faithfulness')) + expect(screen.getAllByText('Faithfulness').length).toBeGreaterThan(0) + expect(screen.getAllByText('Retriever Node').length).toBeGreaterThan(0) + + fireEvent.change(screen.getByPlaceholderText('evaluation.metrics.searchNodeOrMetrics'), { + target: { value: '' }, + }) + + fireEvent.click(screen.getByTestId('evaluation-metric-node-answer-correctness-node-answer')) + expect(screen.getAllByText('Answer Correctness').length).toBeGreaterThan(0) + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + + expect(saveConfig).toHaveBeenCalledWith({ + params: { + targetType: 'apps', + targetId: 'app-1', + }, + body: { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [ + { + metric: 'faithfulness', + value_type: 'number', + node_info_list: [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ], + }, + { + metric: 'answer-correctness', + value_type: 'number', + node_info_list: [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ], + }, + ], + customized_metrics: null, + judgment_config: null, + }, + }, { + onSuccess: expect.any(Function), + onError: expect.any(Function), + }) + }) + + it('should hide the value row for empty operators', () => { + const resourceType = 'apps' + const resourceId = 'app-2' + const store = useEvaluationStore.getState() + let conditionId = '' + + act(() => { + store.ensureResource(resourceType, resourceId) + store.setJudgeModel(resourceType, resourceId, 'openai::gpt-4o-mini') + store.addBuiltinMetric(resourceType, resourceId, 'faithfulness', [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ]) + store.addCondition(resourceType, resourceId) + + const condition = useEvaluationStore.getState().resources['apps:app-2'].judgmentConfig.conditions[0] + conditionId = condition.id + store.updateConditionOperator(resourceType, resourceId, conditionId, '=') + }) + + let rerender: ReturnType['rerender'] + act(() => { + ({ rerender } = renderWithQueryClient()) + }) + + expect(screen.getByPlaceholderText('evaluation.conditions.valuePlaceholder')).toBeInTheDocument() + + act(() => { + store.updateConditionOperator(resourceType, resourceId, conditionId, 'is null') + rerender() + }) + + expect(screen.queryByPlaceholderText('evaluation.conditions.valuePlaceholder')).not.toBeInTheDocument() + }) + + it('should add a condition from grouped metric dropdown items', () => { + const resourceType = 'apps' + const resourceId = 'app-conditions-dropdown' + const store = useEvaluationStore.getState() + + act(() => { + store.ensureResource(resourceType, resourceId) + store.setJudgeModel(resourceType, resourceId, 'openai::gpt-4o-mini') + store.addBuiltinMetric(resourceType, resourceId, 'faithfulness', [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ]) + store.addCustomMetric(resourceType, resourceId) + + const customMetric = useEvaluationStore.getState().resources['apps:app-conditions-dropdown'].metrics.find(metric => metric.kind === 'custom-workflow')! + store.setCustomMetricWorkflow(resourceType, resourceId, customMetric.id, { + workflowId: 'workflow-1', + workflowAppId: 'workflow-app-1', + workflowName: 'Review Workflow', + }) + store.syncCustomMetricOutputs(resourceType, resourceId, customMetric.id, [{ + id: 'reason', + valueType: 'string', + }]) + }) + + render() + + fireEvent.click(screen.getByRole('combobox', { name: 'evaluation.conditions.addCondition' })) + + expect(screen.getByText('Faithfulness')).toBeInTheDocument() + expect(screen.getByText('Review Workflow')).toBeInTheDocument() + expect(screen.getByText('Retriever Node')).toBeInTheDocument() + expect(screen.getByText('reason')).toBeInTheDocument() + expect(screen.getByText('evaluation.conditions.valueTypes.number')).toBeInTheDocument() + expect(screen.getByText('evaluation.conditions.valueTypes.string')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('option', { name: /reason/i })) + + const condition = useEvaluationStore.getState().resources['apps:app-conditions-dropdown'].judgmentConfig.conditions[0] + + expect(condition.variableSelector).toEqual(['workflow-1', 'reason']) + expect(screen.getAllByText('Review Workflow').length).toBeGreaterThan(0) + }) + + it('should render the metric no-node empty state', () => { + mockUseAvailableEvaluationMetrics.mockReturnValue({ + data: { + metrics: ['context-precision'], + }, + isLoading: false, + }) + + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'context-precision': [], + }) + }, + }) + + renderWithQueryClient() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.add' })) + + expect(screen.getByText('evaluation.metrics.noNodesInWorkflow')).toBeInTheDocument() + }) + + it('should render the global empty state when no metrics are available', () => { + mockUseAvailableEvaluationMetrics.mockReturnValue({ + data: { + metrics: [], + }, + isLoading: false, + }) + + renderWithQueryClient() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.add' })) + + expect(screen.getByText('evaluation.metrics.noResults')).toBeInTheDocument() + }) + + it('should show more nodes when a metric has more than three nodes', () => { + mockUseAvailableEvaluationMetrics.mockReturnValue({ + data: { + metrics: ['answer-correctness'], + }, + isLoading: false, + }) + + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'answer-correctness': [ + { node_id: 'node-1', title: 'LLM 1', type: 'llm' }, + { node_id: 'node-2', title: 'LLM 2', type: 'llm' }, + { node_id: 'node-3', title: 'LLM 3', type: 'llm' }, + { node_id: 'node-4', title: 'LLM 4', type: 'llm' }, + ], + }) + }, + }) + + renderWithQueryClient() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.add' })) + + expect(screen.getByText('LLM 3')).toBeInTheDocument() + expect(screen.queryByText('LLM 4')).not.toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.showMore' })) + + expect(screen.getByText('LLM 4')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'evaluation.metrics.showLess' })).toBeInTheDocument() + }) + + it('should render the pipeline-specific layout without auto-selecting a judge model', () => { + renderWithQueryClient() + + expect(screen.getByTestId('evaluation-model-selector')).toHaveTextContent('empty') + expect(screen.getByText('evaluation.history.columns.time')).toBeInTheDocument() + expect(screen.getByText('Context Precision')).toBeInTheDocument() + expect(screen.getByText('Context Recall')).toBeInTheDocument() + expect(screen.getByText('Context Relevance')).toBeInTheDocument() + expect(screen.getByText('evaluation.results.empty')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'evaluation.pipeline.uploadAndRun' })).toBeDisabled() + }) + + it('should render selected pipeline metrics from config with the default threshold input', () => { + mockUseEvaluationConfig.mockReturnValue({ + data: { + evaluation_model: null, + evaluation_model_provider: null, + default_metrics: [{ + metric: 'context-precision', + }], + customized_metrics: null, + judgment_config: null, + }, + }) + + renderWithQueryClient() + + expect(screen.getByText('Context Precision')).toBeInTheDocument() + expect(screen.getByDisplayValue('0.85')).toBeInTheDocument() + }) + + it('should enable pipeline batch actions after selecting a judge model and metric', () => { + renderWithQueryClient() + + fireEvent.click(screen.getByRole('button', { name: 'select-model' })) + fireEvent.click(screen.getByRole('button', { name: /Context Precision/i })) + + expect(screen.getByDisplayValue('0.85')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'evaluation.batch.downloadTemplate' })).toBeEnabled() + expect(screen.getByRole('button', { name: 'evaluation.pipeline.uploadAndRun' })).toBeEnabled() + }) + + it('should upload and start a pipeline evaluation run', async () => { + const startRun = vi.fn() + mockUseStartEvaluationRunMutation.mockReturnValue({ + isPending: false, + mutate: startRun, + }) + mockUpload.mockResolvedValue({ + id: 'file-1', + name: 'pipeline-evaluation.csv', + }) + + renderWithQueryClient() + + fireEvent.click(screen.getByRole('button', { name: 'select-model' })) + fireEvent.click(screen.getByRole('button', { name: /Context Precision/i })) + fireEvent.click(screen.getByRole('button', { name: 'evaluation.pipeline.uploadAndRun' })) + + expect(screen.getAllByText('query').length).toBeGreaterThan(0) + expect(screen.getAllByText('Expect Results').length).toBeGreaterThan(0) + + const fileInput = document.querySelector('input[type="file"][accept=".csv,.xlsx"]') + expect(fileInput).toBeInTheDocument() + + fireEvent.change(fileInput!, { + target: { + files: [new File(['query,Expect Results'], 'pipeline-evaluation.csv', { type: 'text/csv' })], + }, + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalledWith({ + xhr: expect.any(XMLHttpRequest), + data: expect.any(FormData), + }) + }) + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.batch.run' })) + + await waitFor(() => { + expect(startRun).toHaveBeenCalledWith({ + params: { + targetType: 'datasets', + targetId: 'dataset-run', + }, + body: { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [{ + metric: 'context-precision', + value_type: 'number', + node_info_list: [ + { node_id: 'knowledge-node', title: 'Knowledge Base', type: 'knowledge-index' }, + ], + }], + customized_metrics: null, + judgment_config: { + logical_operator: 'and', + conditions: [{ + variable_selector: ['knowledge-node', 'context-precision'], + comparison_operator: '≥', + value: '0.85', + }], + }, + file_id: 'file-1', + }, + }, { + onSuccess: expect.any(Function), + onError: expect.any(Function), + }) + }) + }) +}) diff --git a/web/app/components/evaluation/__tests__/store.spec.ts b/web/app/components/evaluation/__tests__/store.spec.ts new file mode 100644 index 0000000000..150f285b52 --- /dev/null +++ b/web/app/components/evaluation/__tests__/store.spec.ts @@ -0,0 +1,456 @@ +import type { EvaluationConfig } from '@/types/evaluation' +import { getEvaluationMockConfig } from '../mock' +import { + getAllowedOperators, + isCustomMetricConfigured, + requiresConditionValue, + useEvaluationStore, +} from '../store' +import { buildEvaluationConfigPayload, buildEvaluationRunRequest } from '../store-utils' + +describe('evaluation store', () => { + beforeEach(() => { + useEvaluationStore.setState({ resources: {} }) + }) + + it('should configure a custom metric mapping to a valid state', () => { + const resourceType = 'apps' + const resourceId = 'app-1' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + + store.ensureResource(resourceType, resourceId) + store.addCustomMetric(resourceType, resourceId) + + const initialMetric = useEvaluationStore.getState().resources['apps:app-1'].metrics.find(metric => metric.kind === 'custom-workflow') + expect(initialMetric).toBeDefined() + expect(isCustomMetricConfigured(initialMetric!)).toBe(false) + + store.setCustomMetricWorkflow(resourceType, resourceId, initialMetric!.id, { + workflowId: config.workflowOptions[0].id, + workflowAppId: 'custom-workflow-app-id', + workflowName: config.workflowOptions[0].label, + }) + store.syncCustomMetricMappings(resourceType, resourceId, initialMetric!.id, ['query']) + store.syncCustomMetricOutputs(resourceType, resourceId, initialMetric!.id, [{ + id: 'score', + valueType: 'number', + }]) + + const syncedMetric = useEvaluationStore.getState().resources['apps:app-1'].metrics.find(metric => metric.id === initialMetric!.id) + store.updateCustomMetricMapping(resourceType, resourceId, initialMetric!.id, syncedMetric!.customConfig!.mappings[0].id, { + outputVariableId: 'answer', + }) + + const configuredMetric = useEvaluationStore.getState().resources['apps:app-1'].metrics.find(metric => metric.id === initialMetric!.id) + expect(isCustomMetricConfigured(configuredMetric!)).toBe(true) + expect(configuredMetric!.customConfig!.workflowAppId).toBe('custom-workflow-app-id') + expect(configuredMetric!.customConfig!.workflowName).toBe(config.workflowOptions[0].label) + expect(configuredMetric!.customConfig!.outputs).toEqual([{ id: 'score', valueType: 'number' }]) + }) + + it('should only add one custom metric', () => { + const resourceType = 'apps' + const resourceId = 'app-custom-limit' + const store = useEvaluationStore.getState() + + store.ensureResource(resourceType, resourceId) + store.addCustomMetric(resourceType, resourceId) + store.addCustomMetric(resourceType, resourceId) + + expect( + useEvaluationStore + .getState() + .resources['apps:app-custom-limit'] + .metrics + .filter(metric => metric.kind === 'custom-workflow'), + ).toHaveLength(1) + }) + + it('should add and remove builtin metrics', () => { + const resourceType = 'apps' + const resourceId = 'app-2' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + + store.ensureResource(resourceType, resourceId) + store.addBuiltinMetric(resourceType, resourceId, config.builtinMetrics[1].id) + + const addedMetric = useEvaluationStore.getState().resources['apps:app-2'].metrics.find(metric => metric.optionId === config.builtinMetrics[1].id) + expect(addedMetric).toBeDefined() + + store.removeMetric(resourceType, resourceId, addedMetric!.id) + + expect(useEvaluationStore.getState().resources['apps:app-2'].metrics.some(metric => metric.id === addedMetric!.id)).toBe(false) + }) + + it('should upsert builtin metric node selections and prune stale conditions', () => { + const resourceType = 'apps' + const resourceId = 'app-4' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + const metricId = config.builtinMetrics[0].id + + store.ensureResource(resourceType, resourceId) + store.addBuiltinMetric(resourceType, resourceId, metricId, [ + { node_id: 'node-1', title: 'Answer Node', type: 'answer' }, + ]) + store.addCondition(resourceType, resourceId) + + store.addBuiltinMetric(resourceType, resourceId, metricId, [ + { node_id: 'node-2', title: 'Retriever Node', type: 'retriever' }, + ]) + + const state = useEvaluationStore.getState().resources['apps:app-4'] + const metric = state.metrics.find(item => item.optionId === metricId) + + expect(metric?.nodeInfoList).toEqual([ + { node_id: 'node-2', title: 'Retriever Node', type: 'retriever' }, + ]) + expect(state.metrics.filter(item => item.optionId === metricId)).toHaveLength(1) + expect(state.judgmentConfig.conditions).toHaveLength(0) + }) + + it('should build numeric conditions from selected metrics', () => { + const resourceType = 'apps' + const resourceId = 'app-conditions' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + + store.ensureResource(resourceType, resourceId) + store.addBuiltinMetric(resourceType, resourceId, config.builtinMetrics[0].id, [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ]) + store.setConditionLogicalOperator(resourceType, resourceId, 'or') + store.addCondition(resourceType, resourceId) + + const state = useEvaluationStore.getState().resources['apps:app-conditions'] + const condition = state.judgmentConfig.conditions[0] + + expect(state.judgmentConfig.logicalOperator).toBe('or') + expect(condition.variableSelector).toEqual(['node-answer', 'answer-correctness']) + expect(condition.comparisonOperator).toBe('=') + expect(getAllowedOperators(state.metrics, condition.variableSelector)).toEqual(['=', '≠', '>', '<', '≥', '≤', 'is null', 'is not null']) + }) + + it('should add a condition from the selected custom metric output', () => { + const resourceType = 'apps' + const resourceId = 'app-condition-selector' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + + store.ensureResource(resourceType, resourceId) + store.addCustomMetric(resourceType, resourceId) + + const customMetric = useEvaluationStore.getState().resources['apps:app-condition-selector'].metrics.find(metric => metric.kind === 'custom-workflow')! + store.setCustomMetricWorkflow(resourceType, resourceId, customMetric.id, { + workflowId: config.workflowOptions[0].id, + workflowAppId: 'custom-workflow-app-id', + workflowName: config.workflowOptions[0].label, + }) + store.syncCustomMetricOutputs(resourceType, resourceId, customMetric.id, [{ + id: 'reason', + valueType: 'string', + }]) + + store.addCondition(resourceType, resourceId, [config.workflowOptions[0].id, 'reason']) + + const condition = useEvaluationStore.getState().resources['apps:app-condition-selector'].judgmentConfig.conditions[0] + + expect(condition.variableSelector).toEqual([config.workflowOptions[0].id, 'reason']) + expect(condition.comparisonOperator).toBe('contains') + expect(condition.value).toBeNull() + }) + + it('should clear values for operators without values', () => { + const resourceType = 'apps' + const resourceId = 'app-3' + const store = useEvaluationStore.getState() + const config = getEvaluationMockConfig(resourceType) + + store.ensureResource(resourceType, resourceId) + store.addCustomMetric(resourceType, resourceId) + + const customMetric = useEvaluationStore.getState().resources['apps:app-3'].metrics.find(metric => metric.kind === 'custom-workflow')! + store.setCustomMetricWorkflow(resourceType, resourceId, customMetric.id, { + workflowId: config.workflowOptions[0].id, + workflowAppId: 'custom-workflow-app-id', + workflowName: config.workflowOptions[0].label, + }) + store.syncCustomMetricOutputs(resourceType, resourceId, customMetric.id, [{ + id: 'reason', + valueType: 'string', + }]) + store.addCondition(resourceType, resourceId) + + const condition = useEvaluationStore.getState().resources['apps:app-3'].judgmentConfig.conditions[0] + + store.updateConditionMetric(resourceType, resourceId, condition.id, [config.workflowOptions[0].id, 'reason']) + store.updateConditionValue(resourceType, resourceId, condition.id, 'needs follow-up') + store.updateConditionOperator(resourceType, resourceId, condition.id, 'empty') + + const updatedCondition = useEvaluationStore.getState().resources['apps:app-3'].judgmentConfig.conditions[0] + + expect(requiresConditionValue('empty')).toBe(false) + expect(updatedCondition.value).toBeNull() + }) + + it('should hydrate resource state from judgment_config', () => { + const resourceType = 'apps' + const resourceId = 'app-5' + const store = useEvaluationStore.getState() + const config: EvaluationConfig = { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [{ + metric: 'faithfulness', + node_info_list: [ + { node_id: 'node-1', title: 'Retriever', type: 'retriever' }, + ], + }], + customized_metrics: { + evaluation_workflow_id: 'workflow-precision-review', + input_fields: { + query: 'answer', + }, + output_fields: [{ + variable: 'reason', + value_type: 'string', + }], + }, + judgment_config: { + logical_operator: 'or', + conditions: [{ + variable_selector: ['node-1', 'faithfulness'], + comparison_operator: '≥', + value: '0.9', + }], + }, + } + + store.ensureResource(resourceType, resourceId) + store.setBatchTab(resourceType, resourceId, 'history') + store.setUploadedFileName(resourceType, resourceId, 'batch.csv') + useEvaluationStore.setState(state => ({ + resources: { + ...state.resources, + 'apps:app-5': { + ...state.resources['apps:app-5'], + batchRecords: [{ + id: 'batch-1', + fileName: 'batch.csv', + status: 'success', + startedAt: '10:00:00', + summary: 'App evaluation batch', + }], + }, + }, + })) + store.hydrateResource(resourceType, resourceId, config) + + const hydratedState = useEvaluationStore.getState().resources['apps:app-5'] + + expect(hydratedState.judgeModelId).toBe('openai::gpt-4o-mini') + expect(hydratedState.metrics).toHaveLength(2) + expect(hydratedState.metrics[0].optionId).toBe('faithfulness') + expect(hydratedState.metrics[0].threshold).toBe(0.85) + expect(hydratedState.metrics[0].nodeInfoList).toEqual([ + { node_id: 'node-1', title: 'Retriever', type: 'retriever' }, + ]) + expect(hydratedState.metrics[1].kind).toBe('custom-workflow') + expect(hydratedState.metrics[1].customConfig?.workflowId).toBe('workflow-precision-review') + expect(hydratedState.metrics[1].customConfig?.mappings[0].inputVariableId).toBe('query') + expect(hydratedState.metrics[1].customConfig?.mappings[0].outputVariableId).toBe('answer') + expect(hydratedState.metrics[1].customConfig?.outputs).toEqual([{ id: 'reason', valueType: 'string' }]) + expect(hydratedState.judgmentConfig.logicalOperator).toBe('or') + expect(hydratedState.judgmentConfig.conditions[0]).toMatchObject({ + variableSelector: ['node-1', 'faithfulness'], + comparisonOperator: '≥', + value: '0.9', + }) + expect(hydratedState.activeBatchTab).toBe('history') + expect(hydratedState.uploadedFileName).toBe('batch.csv') + expect(hydratedState.batchRecords).toHaveLength(1) + }) + + it('should build an evaluation config save payload from resource state', () => { + const resourceType = 'apps' + const resourceId = 'app-save-config' + const store = useEvaluationStore.getState() + + store.ensureResource(resourceType, resourceId) + store.setJudgeModel(resourceType, resourceId, 'openai::gpt-4o-mini') + store.addBuiltinMetric(resourceType, resourceId, 'faithfulness', [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ]) + store.addCustomMetric(resourceType, resourceId) + + const customMetric = useEvaluationStore.getState().resources['apps:app-save-config'].metrics.find(metric => metric.kind === 'custom-workflow')! + store.setCustomMetricWorkflow(resourceType, resourceId, customMetric.id, { + workflowId: 'workflow-precision-review', + workflowAppId: 'evaluation-workflow-app-id', + workflowName: 'Precision Review', + }) + store.syncCustomMetricMappings(resourceType, resourceId, customMetric.id, ['query']) + store.syncCustomMetricOutputs(resourceType, resourceId, customMetric.id, [{ + id: 'score', + valueType: 'number', + }]) + + const syncedMetric = useEvaluationStore.getState().resources['apps:app-save-config'].metrics.find(metric => metric.id === customMetric.id)! + store.updateCustomMetricMapping(resourceType, resourceId, customMetric.id, syncedMetric.customConfig!.mappings[0].id, { + outputVariableId: '{{#node-answer.output#}}', + }) + store.addCondition(resourceType, resourceId, ['workflow-precision-review', 'score']) + + const condition = useEvaluationStore.getState().resources['apps:app-save-config'].judgmentConfig.conditions[0] + store.updateConditionOperator(resourceType, resourceId, condition.id, '≥') + store.updateConditionValue(resourceType, resourceId, condition.id, '0.8') + + const resource = useEvaluationStore.getState().resources['apps:app-save-config'] + const expectedPayload = { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [{ + metric: 'faithfulness', + value_type: 'number', + node_info_list: [ + { node_id: 'node-faithfulness', title: 'Retriever Node', type: 'retriever' }, + ], + }], + customized_metrics: { + evaluation_workflow_id: 'evaluation-workflow-app-id', + input_fields: { + query: '{{#node-answer.output#}}', + }, + output_fields: [{ + variable: 'score', + value_type: 'number', + }], + }, + judgment_config: { + logical_operator: 'and', + conditions: [{ + variable_selector: ['evaluation-workflow-app-id', 'score'], + comparison_operator: '≥', + value: '0.8', + }], + }, + } + + expect(buildEvaluationConfigPayload(resource, resourceType)).toEqual(expectedPayload) + expect(buildEvaluationRunRequest(resource, 'file-1', resourceType)).toEqual({ + ...expectedPayload, + file_id: 'file-1', + }) + }) + + it('should hydrate pipeline metrics from fixed knowledge-index conditions', () => { + const resourceType = 'datasets' + const resourceId = 'dataset-hydrate' + const store = useEvaluationStore.getState() + const config: EvaluationConfig = { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [{ + metric: 'context-precision', + node_info_list: [ + { node_id: 'knowledge-node', title: 'Knowledge Base', type: 'knowledge-index' }, + ], + }], + customized_metrics: { + evaluation_workflow_id: 'should-be-ignored', + input_fields: { + query: 'answer', + }, + output_fields: [{ + variable: 'score', + value_type: 'number', + }], + }, + judgment_config: { + logical_operator: 'or', + conditions: [{ + variable_selector: ['knowledge-node', 'context-precision'], + comparison_operator: '≥', + value: '0.92', + }], + }, + } + + store.hydrateResource(resourceType, resourceId, config) + + const hydratedState = useEvaluationStore.getState().resources['datasets:dataset-hydrate'] + + expect(hydratedState.judgeModelId).toBe('openai::gpt-4o-mini') + expect(hydratedState.metrics).toHaveLength(1) + expect(hydratedState.metrics[0]).toMatchObject({ + optionId: 'context-precision', + kind: 'builtin', + valueType: 'number', + threshold: 0.92, + nodeInfoList: [ + { node_id: 'knowledge-node', title: 'Knowledge Base', type: 'knowledge-index' }, + ], + }) + }) + + it('should build pipeline judgment payload from metric thresholds', () => { + const resourceType = 'datasets' + const resourceId = 'dataset-save-config' + const store = useEvaluationStore.getState() + const knowledgeNodeInfo = [{ node_id: 'knowledge-node', title: 'Knowledge Base', type: 'knowledge-index' }] + + store.ensureResource(resourceType, resourceId) + store.setJudgeModel(resourceType, resourceId, 'openai::gpt-4o-mini') + store.addBuiltinMetric(resourceType, resourceId, 'context-precision', knowledgeNodeInfo) + store.addBuiltinMetric(resourceType, resourceId, 'context-recall', knowledgeNodeInfo) + + const resourceWithMetrics = useEvaluationStore.getState().resources['datasets:dataset-save-config'] + const contextPrecisionMetric = resourceWithMetrics.metrics.find(metric => metric.optionId === 'context-precision')! + const contextRecallMetric = resourceWithMetrics.metrics.find(metric => metric.optionId === 'context-recall')! + + store.updateMetricThreshold(resourceType, resourceId, contextPrecisionMetric.id, 0.91) + store.updateMetricThreshold(resourceType, resourceId, contextRecallMetric.id, 0.88) + + const resource = useEvaluationStore.getState().resources['datasets:dataset-save-config'] + const expectedPayload = { + evaluation_model: 'gpt-4o-mini', + evaluation_model_provider: 'openai', + default_metrics: [ + { + metric: 'context-precision', + value_type: 'number', + node_info_list: knowledgeNodeInfo, + }, + { + metric: 'context-recall', + value_type: 'number', + node_info_list: knowledgeNodeInfo, + }, + ], + customized_metrics: null, + judgment_config: { + logical_operator: 'and', + conditions: [ + { + variable_selector: ['knowledge-node', 'context-precision'], + comparison_operator: '≥', + value: '0.91', + }, + { + variable_selector: ['knowledge-node', 'context-recall'], + comparison_operator: '≥', + value: '0.88', + }, + ], + }, + } + + expect(buildEvaluationConfigPayload(resource, resourceType)).toEqual(expectedPayload) + expect(buildEvaluationRunRequest(resource, 'file-1', resourceType)).toEqual({ + ...expectedPayload, + file_id: 'file-1', + }) + }) +}) diff --git a/web/app/components/evaluation/components/batch-test-panel/history-tab.tsx b/web/app/components/evaluation/components/batch-test-panel/history-tab.tsx new file mode 100644 index 0000000000..4da6f6fe48 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/history-tab.tsx @@ -0,0 +1,202 @@ +import type { EvaluationResourceProps } from '../../types' +import type { EvaluationLog, EvaluationLogFile } from '@/types/evaluation' +import { cn } from '@langgenius/dify-ui/cn' +import { keepPreviousData, useMutation, useQuery } from '@tanstack/react-query' +import { useEffect, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Pagination from '@/app/components/base/pagination' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' +import { consoleClient, consoleQuery } from '@/service/client' +import { downloadUrl } from '@/utils/download' +import { useEvaluationResource, useEvaluationStore } from '../../store' + +const PAGE_SIZE = 16 +const LOADING_ROW_IDS = ['1', '2', '3', '4', '5', '6'] + +const formatCreatedAt = (createdAt: string) => { + if (!createdAt) + return '-' + + return createdAt.includes('T') ? createdAt.slice(0, 10) : createdAt +} + +const getLogRunId = (record: EvaluationLog) => { + return record.run_id ?? record.evaluation_run_id ?? record.id ?? null +} + +const HistoryTab = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const [page, setPage] = useState(0) + const resource = useEvaluationResource(resourceType, resourceId) + const setSelectedRunId = useEvaluationStore(state => state.setSelectedRunId) + const logsQuery = useQuery({ + ...consoleQuery.evaluation.logs.queryOptions({ + input: { + params: { + targetType: resourceType, + targetId: resourceId, + }, + query: { + page: page + 1, + page_size: PAGE_SIZE, + }, + }, + refetchOnWindowFocus: false, + }), + placeholderData: keepPreviousData, + }) + const fileDownloadMutation = useMutation({ + mutationFn: async (file: EvaluationLogFile) => { + const fileInfo = await consoleClient.evaluation.file({ + params: { + targetType: resourceType, + targetId: resourceId, + fileId: file.id, + }, + }) + + downloadUrl({ url: fileInfo.download_url, fileName: file.name }) + }, + }) + const records = useMemo(() => logsQuery.data?.data ?? [], [logsQuery.data?.data]) + const total = logsQuery.data?.total ?? 0 + const isInitialLoading = logsQuery.isLoading && !logsQuery.data + + useEffect(() => { + if (resource.selectedRunId) + return + + const firstRunId = records.map(getLogRunId).find((runId): runId is string => !!runId) + if (firstRunId) + setSelectedRunId(resourceType, resourceId, firstRunId) + }, [records, resource.selectedRunId, resourceId, resourceType, setSelectedRunId]) + + return ( +
+
+ + + + + + + + + + + + + + + + + + + {isInitialLoading && LOADING_ROW_IDS.map(rowId => ( + + + + ))} + {!isInitialLoading && records.map(record => ( + { + const runId = getLogRunId(record) + if (runId) + setSelectedRunId(resourceType, resourceId, runId) + }} + > + + + + + + + ))} + +
+ + {t('history.columns.time')} + + {t('history.columns.creator')}{t('history.columns.version')}{t('history.columns.status')} +
+
+
{formatCreatedAt(record.created_at)}{record.created_by}{record.version || '-'} + {record.result_file + ? + : } + + + event.stopPropagation()} + /> + )} + > + + + { + event.stopPropagation() + fileDownloadMutation.mutate(record.test_file) + }} + > + + { + event.stopPropagation() + if (record.result_file) + fileDownloadMutation.mutate(record.result_file) + }} + > + + + +
+ {!isInitialLoading && records.length === 0 && ( +
+ {t('history.empty')} +
+ )} +
+ {total > PAGE_SIZE && ( + + )} +
+ ) +} + +export default HistoryTab diff --git a/web/app/components/evaluation/components/batch-test-panel/index.tsx b/web/app/components/evaluation/components/batch-test-panel/index.tsx new file mode 100644 index 0000000000..99ee09af11 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/index.tsx @@ -0,0 +1,120 @@ +'use client' + +import type { BatchTestTab, EvaluationResourceProps } from '../../types' +import { cn } from '@langgenius/dify-ui/cn' +import { useTranslation } from 'react-i18next' +import { Button } from '@/app/components/base/ui/button' +import { toast } from '@/app/components/base/ui/toast' +import { useSaveEvaluationConfigMutation } from '@/service/use-evaluation' +import { isEvaluationRunnable, useEvaluationResource, useEvaluationStore } from '../../store' +import { buildEvaluationConfigPayload } from '../../store-utils' +import { TAB_CLASS_NAME } from '../../utils' +import HistoryTab from './history-tab' +import InputFieldsTab from './input-fields-tab' + +const BATCH_TABS: BatchTestTab[] = ['input-fields', 'history'] + +const BatchTestPanel = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const { t: tCommon } = useTranslation('common') + const tabLabels: Record = { + 'input-fields': t('batch.tabs.input-fields'), + 'history': t('batch.tabs.history'), + } + const resource = useEvaluationResource(resourceType, resourceId) + const setBatchTab = useEvaluationStore(state => state.setBatchTab) + const saveConfigMutation = useSaveEvaluationConfigMutation() + const isRunnable = isEvaluationRunnable(resource) + const isPanelReady = !!resource.judgeModelId && resource.metrics.length > 0 + + const handleSave = () => { + if (!isRunnable) { + toast.warning(t('batch.validation')) + return + } + + const body = buildEvaluationConfigPayload(resource, resourceType) + + if (!body) { + toast.warning(t('batch.validation')) + return + } + + saveConfigMutation.mutate({ + params: { + targetType: resourceType, + targetId: resourceId, + }, + body, + }, { + onSuccess: () => { + toast.success(tCommon('api.saved')) + }, + onError: () => { + toast.error(t('config.saveFailed')) + }, + }) + } + + return ( +
+
+
+
+
{t('batch.title')}
+
{t('batch.description')}
+
+ +
+
+
+
+
+
+
+
+ {BATCH_TABS.map(tab => ( + + ))} +
+
+
+ {resource.activeBatchTab === 'input-fields' && ( + + )} + {resource.activeBatchTab === 'history' && } +
+
+ ) +} + +export default BatchTestPanel diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields-tab.tsx b/web/app/components/evaluation/components/batch-test-panel/input-fields-tab.tsx new file mode 100644 index 0000000000..27b1b062cb --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields-tab.tsx @@ -0,0 +1,71 @@ +import type { EvaluationResourceProps } from '../../types' +import { useTranslation } from 'react-i18next' +import { Button } from '@/app/components/base/ui/button' +import { getEvaluationMockConfig } from '../../mock' +import InputFieldsRequirements from './input-fields/input-fields-requirements' +import UploadRunPopover from './input-fields/upload-run-popover' +import { useInputFieldsActions } from './input-fields/use-input-fields-actions' +import { usePublishedInputFields } from './input-fields/use-published-input-fields' + +type InputFieldsTabProps = EvaluationResourceProps & { + isPanelReady: boolean + isRunnable: boolean +} + +const InputFieldsTab = ({ + resourceType, + resourceId, + isPanelReady, + isRunnable, +}: InputFieldsTabProps) => { + const { t } = useTranslation('evaluation') + const config = getEvaluationMockConfig(resourceType) + const { inputFields, isInputFieldsLoading } = usePublishedInputFields(resourceType, resourceId) + const actions = useInputFieldsActions({ + resourceType, + resourceId, + inputFields, + isInputFieldsLoading, + isPanelReady, + isRunnable, + templateFileName: config.templateFileName, + }) + + return ( +
+ +
+ + +
+ {!isRunnable && ( +
+ {t('batch.validation')} +
+ )} +
+ ) +} + +export default InputFieldsTab diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-requirements.tsx b/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-requirements.tsx new file mode 100644 index 0000000000..83201ea5a7 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-requirements.tsx @@ -0,0 +1,45 @@ +import type { InputField } from './input-fields-utils' +import { useTranslation } from 'react-i18next' + +type InputFieldsRequirementsProps = { + inputFields: InputField[] + isLoading: boolean +} + +const InputFieldsRequirements = ({ + inputFields, + isLoading, +}: InputFieldsRequirementsProps) => { + const { t } = useTranslation('evaluation') + + return ( +
+
{t('batch.requirementsTitle')}
+
{t('batch.requirementsDescription')}
+
+ {isLoading && ( +
+ {t('batch.loadingInputFields')} +
+ )} + {!isLoading && inputFields.length === 0 && ( +
+ {t('batch.noInputFields')} +
+ )} + {!isLoading && inputFields.map(field => ( +
+
+ {field.name} +
+
+ {field.type} +
+
+ ))} +
+
+ ) +} + +export default InputFieldsRequirements diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-utils.ts b/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-utils.ts new file mode 100644 index 0000000000..5a71b81d06 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields/input-fields-utils.ts @@ -0,0 +1,54 @@ +import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' +import type { InputVar, Node } from '@/app/components/workflow/types' +import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' + +export type InputField = { + name: string + type: string +} + +export const getGraphNodes = (graph?: Record) => { + return Array.isArray(graph?.nodes) ? graph.nodes as Node[] : [] +} + +export const getStartNodeInputFields = (nodes?: Node[]): InputField[] => { + const startNode = nodes?.find(node => node.data.type === BlockEnum.Start) as Node | undefined + const variables = startNode?.data.variables + + if (!Array.isArray(variables)) + return [] + + return variables + .filter((variable): variable is InputVar => typeof variable.variable === 'string' && !!variable.variable) + .map(variable => ({ + name: variable.variable, + type: inputVarTypeToVarType(variable.type ?? InputVarType.textInput), + })) +} + +const escapeCsvCell = (value: string) => { + if (!/[",\n\r]/.test(value)) + return value + + return `"${value.replace(/"/g, '""')}"` +} + +export const buildTemplateCsvContent = (inputFields: InputField[]) => { + return `${inputFields.map(field => escapeCsvCell(field.name)).join(',')}\n` +} + +export const getFileExtension = (fileName: string) => { + const extension = fileName.split('.').pop() + return extension && extension !== fileName ? extension.toUpperCase() : '' +} + +export const getExampleValue = (field: InputField, booleanLabel: string) => { + if (field.type === 'number') + return '0.7' + + if (field.type === 'boolean') + return booleanLabel + + return field.name +} diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields/upload-run-popover.tsx b/web/app/components/evaluation/components/batch-test-panel/input-fields/upload-run-popover.tsx new file mode 100644 index 0000000000..e19209ffb5 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields/upload-run-popover.tsx @@ -0,0 +1,189 @@ +import type { ChangeEvent, DragEvent } from 'react' +import type { InputField } from './input-fields-utils' +import { cn } from '@langgenius/dify-ui/cn' +import { useRef } from 'react' +import { useTranslation } from 'react-i18next' +import { Button } from '@/app/components/base/ui/button' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' +import { getExampleValue } from './input-fields-utils' + +type UploadRunPopoverProps = { + open: boolean + onOpenChange: (open: boolean) => void + triggerDisabled: boolean + triggerLabel?: string + inputFields: InputField[] + currentFileName: string | null | undefined + currentFileExtension: string + currentFileSize: string | number + isFileUploading: boolean + isRunDisabled: boolean + isRunning: boolean + onUploadFile: (file: File | undefined) => void + onClearUploadedFile: () => void + onDownloadTemplate: () => void + onRun: () => void +} + +const UploadRunPopover = ({ + open, + onOpenChange, + triggerDisabled, + triggerLabel, + inputFields, + currentFileName, + currentFileExtension, + currentFileSize, + isFileUploading, + isRunDisabled, + isRunning, + onUploadFile, + onClearUploadedFile, + onDownloadTemplate, + onRun, +}: UploadRunPopoverProps) => { + const { t } = useTranslation('evaluation') + const { t: tCommon } = useTranslation('common') + const fileInputRef = useRef(null) + const previewFields = inputFields.slice(0, 3) + const booleanExampleValue = t('conditions.boolean.true') + + const handleFileChange = (event: ChangeEvent) => { + onUploadFile(event.target.files?.[0]) + event.target.value = '' + } + + const handleDropFile = (event: DragEvent) => { + event.preventDefault() + onUploadFile(event.dataTransfer.files?.[0]) + } + + return ( + + + {triggerLabel ?? t('batch.uploadAndRun')} + + )} + /> + +
+
+ + {currentFileName + ? ( +
+
+
+
+
+ {currentFileName} +
+
+ {!!currentFileExtension && {currentFileExtension}} + {!!currentFileExtension && !!currentFileSize && ·} + {!!currentFileSize && {currentFileSize}} +
+
+
+ {isFileUploading && ( +
+
+ ) + : ( +
event.preventDefault()} + onDrop={handleDropFile} + > + +
+
+ {t('batch.uploadDropzonePrefix')} + {' '} + {t('batch.uploadDropzoneEmphasis')} + {' '} + {t('batch.uploadDropzoneSuffix')} +
+
+ {t('batch.uploadDropzoneDownloadPrefix')} + {' '} + +
+
+
+ )} + + {!!previewFields.length && ( +
+
{t('batch.example')}
+
+ {previewFields.map((field, index) => ( +
+
+ {field.name} +
+
+ {getExampleValue(field, booleanExampleValue)} +
+
+ ))} +
+
+ )} +
+
+ + +
+
+
+
+ ) +} + +export default UploadRunPopover diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields/use-input-fields-actions.ts b/web/app/components/evaluation/components/batch-test-panel/input-fields/use-input-fields-actions.ts new file mode 100644 index 0000000000..8db1b0fbdd --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields/use-input-fields-actions.ts @@ -0,0 +1,167 @@ +import type { EvaluationResourceProps } from '../../../types' +import type { InputField } from './input-fields-utils' +import { useMutation } from '@tanstack/react-query' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { toast } from '@/app/components/base/ui/toast' +import { upload } from '@/service/base' +import { useStartEvaluationRunMutation } from '@/service/use-evaluation' +import { formatFileSize } from '@/utils/format' +import { useEvaluationResource, useEvaluationStore } from '../../../store' +import { buildEvaluationRunRequest } from '../../../store-utils' +import { buildTemplateCsvContent, getFileExtension } from './input-fields-utils' + +type UploadedFileMeta = { + name: string + size: number +} + +type UseInputFieldsActionsParams = EvaluationResourceProps & { + inputFields: InputField[] + isInputFieldsLoading: boolean + isPanelReady: boolean + isRunnable: boolean + templateFileName: string +} + +export const useInputFieldsActions = ({ + resourceType, + resourceId, + inputFields, + isInputFieldsLoading, + isPanelReady, + isRunnable, + templateFileName, +}: UseInputFieldsActionsParams) => { + const { t } = useTranslation('evaluation') + const resource = useEvaluationResource(resourceType, resourceId) + const setBatchTab = useEvaluationStore(state => state.setBatchTab) + const setSelectedRunId = useEvaluationStore(state => state.setSelectedRunId) + const setUploadedFile = useEvaluationStore(state => state.setUploadedFile) + const setUploadedFileName = useEvaluationStore(state => state.setUploadedFileName) + const startRunMutation = useStartEvaluationRunMutation() + const [isUploadPopoverOpen, setIsUploadPopoverOpen] = useState(false) + const [uploadedFileMeta, setUploadedFileMeta] = useState(null) + const uploadMutation = useMutation({ + mutationFn: (file: File) => { + const formData = new FormData() + formData.append('file', file) + + return upload({ + xhr: new XMLHttpRequest(), + data: formData, + }) + }, + onSuccess: (uploadedFile, file) => { + setUploadedFile(resourceType, resourceId, { + id: uploadedFile.id, + name: typeof uploadedFile.name === 'string' ? uploadedFile.name : file.name, + }) + }, + onError: () => { + setUploadedFileMeta(null) + setUploadedFile(resourceType, resourceId, null) + toast.error(t('batch.uploadError')) + }, + }) + + const isFileUploading = uploadMutation.isPending + const isRunning = startRunMutation.isPending + const uploadedFileId = resource.uploadedFileId + const currentFileName = uploadedFileMeta?.name ?? resource.uploadedFileName + const canDownloadTemplate = isPanelReady && !isInputFieldsLoading && inputFields.length > 0 + const isRunDisabled = !isRunnable || !uploadedFileId || isFileUploading || isRunning + const uploadButtonDisabled = !isPanelReady || isInputFieldsLoading || isRunning + + const handleDownloadTemplate = () => { + if (!inputFields.length) { + toast.warning(t('batch.noInputFields')) + return + } + + const content = buildTemplateCsvContent(inputFields) + const link = document.createElement('a') + link.href = `data:text/csv;charset=utf-8,${encodeURIComponent(content)}` + link.download = templateFileName + link.click() + } + + const handleRun = () => { + if (!isRunnable) { + toast.warning(t('batch.validation')) + return + } + + if (isFileUploading) { + toast.warning(t('batch.uploading')) + return + } + + if (!uploadedFileId) { + toast.warning(t('batch.fileRequired')) + return + } + + const body = buildEvaluationRunRequest(resource, uploadedFileId, resourceType) + + if (!body) { + toast.warning(t('batch.validation')) + return + } + + startRunMutation.mutate({ + params: { + targetType: resourceType, + targetId: resourceId, + }, + body, + }, { + onSuccess: (run) => { + toast.success(t('batch.runStarted')) + setSelectedRunId(resourceType, resourceId, run.id) + setIsUploadPopoverOpen(false) + setBatchTab(resourceType, resourceId, 'history') + }, + onError: () => { + toast.error(t('batch.runFailed')) + }, + }) + } + + const handleUploadFile = (file: File | undefined) => { + if (!file) { + setUploadedFileMeta(null) + setUploadedFile(resourceType, resourceId, null) + return + } + + setUploadedFileMeta({ + name: file.name, + size: file.size, + }) + setUploadedFileName(resourceType, resourceId, file.name) + uploadMutation.mutate(file) + } + + const handleClearUploadedFile = () => { + setUploadedFileMeta(null) + setUploadedFile(resourceType, resourceId, null) + } + + return { + canDownloadTemplate, + currentFileExtension: currentFileName ? getFileExtension(currentFileName) : '', + currentFileName, + currentFileSize: uploadedFileMeta ? formatFileSize(uploadedFileMeta.size) : '', + handleClearUploadedFile, + handleDownloadTemplate, + handleRun, + handleUploadFile, + isFileUploading, + isRunning, + isRunDisabled, + isUploadPopoverOpen, + setIsUploadPopoverOpen, + uploadButtonDisabled, + } +} diff --git a/web/app/components/evaluation/components/batch-test-panel/input-fields/use-published-input-fields.ts b/web/app/components/evaluation/components/batch-test-panel/input-fields/use-published-input-fields.ts new file mode 100644 index 0000000000..a319603026 --- /dev/null +++ b/web/app/components/evaluation/components/batch-test-panel/input-fields/use-published-input-fields.ts @@ -0,0 +1,29 @@ +import type { EvaluationResourceType } from '../../../types' +import { useMemo } from 'react' +import { useSnippetPublishedWorkflow } from '@/service/use-snippet-workflows' +import { useAppWorkflow } from '@/service/use-workflow' +import { getGraphNodes, getStartNodeInputFields } from './input-fields-utils' + +export const usePublishedInputFields = ( + resourceType: EvaluationResourceType, + resourceId: string, +) => { + const { data: currentAppWorkflow, isLoading: isAppWorkflowLoading } = useAppWorkflow(resourceType === 'apps' ? resourceId : '') + const { data: currentSnippetWorkflow, isLoading: isSnippetWorkflowLoading } = useSnippetPublishedWorkflow(resourceType === 'snippets' ? resourceId : '') + + const inputFields = useMemo(() => { + if (resourceType === 'apps') + return getStartNodeInputFields(currentAppWorkflow?.graph.nodes) + + if (resourceType === 'snippets') + return getStartNodeInputFields(getGraphNodes(currentSnippetWorkflow?.graph)) + + return [] + }, [currentAppWorkflow?.graph.nodes, currentSnippetWorkflow?.graph, resourceType]) + + return { + inputFields, + isInputFieldsLoading: (resourceType === 'apps' && isAppWorkflowLoading) + || (resourceType === 'snippets' && isSnippetWorkflowLoading), + } +} diff --git a/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx b/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx new file mode 100644 index 0000000000..cbe44c05b2 --- /dev/null +++ b/web/app/components/evaluation/components/conditions-section/add-condition-select.tsx @@ -0,0 +1,75 @@ +'use client' + +import type { ConditionMetricOptionGroup, EvaluationResourceProps } from '../../types' +import { cn } from '@langgenius/dify-ui/cn' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + Select, + SelectContent, + SelectGroup, + SelectGroupLabel, + SelectItem, + SelectTrigger, +} from '@/app/components/base/ui/select' +import { useEvaluationStore } from '../../store' +import { getConditionMetricValueTypeTranslationKey } from '../../utils' + +type AddConditionSelectProps = EvaluationResourceProps & { + metricOptionGroups: ConditionMetricOptionGroup[] + disabled: boolean +} + +const AddConditionSelect = ({ + resourceType, + resourceId, + metricOptionGroups, + disabled, +}: AddConditionSelectProps) => { + const { t } = useTranslation('evaluation') + const addCondition = useEvaluationStore(state => state.addCondition) + const [selectKey, setSelectKey] = useState(0) + + return ( + + ) +} + +export default AddConditionSelect diff --git a/web/app/components/evaluation/components/conditions-section/condition-group.tsx b/web/app/components/evaluation/components/conditions-section/condition-group.tsx new file mode 100644 index 0000000000..c37fb615dc --- /dev/null +++ b/web/app/components/evaluation/components/conditions-section/condition-group.tsx @@ -0,0 +1,302 @@ +'use client' + +import type { + ComparisonOperator, + ConditionMetricOption, + EvaluationResourceProps, + JudgmentConditionItem, +} from '../../types' +import { cn } from '@langgenius/dify-ui/cn' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import Input from '@/app/components/base/input' +import { Button } from '@/app/components/base/ui/button' +import { + Select, + SelectContent, + SelectGroup, + SelectGroupLabel, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/app/components/base/ui/select' +import { getAllowedOperators, requiresConditionValue, useEvaluationResource, useEvaluationStore } from '../../store' +import { + buildConditionMetricOptions, + getComparisonOperatorLabel, + getConditionMetricValueTypeTranslationKey, + groupConditionMetricOptions, + isSelectorEqual, + serializeVariableSelector, +} from '../../utils' + +type ConditionMetricLabelProps = { + metric?: ConditionMetricOption + placeholder: string +} + +type ConditionMetricSelectProps = { + metric?: ConditionMetricOption + metricOptions: ConditionMetricOption[] + placeholder: string + onChange: (variableSelector: [string, string]) => void +} + +type ConditionOperatorSelectProps = { + operator: ComparisonOperator + operators: ComparisonOperator[] + onChange: (operator: ComparisonOperator) => void +} + +type ConditionValueInputProps = { + metric?: ConditionMetricOption + condition: JudgmentConditionItem + onChange: (value: string | string[] | boolean | null) => void +} + +type ConditionGroupProps = EvaluationResourceProps + +const getMetricValueTypeIconClassName = (valueType: ConditionMetricOption['valueType']) => { + if (valueType === 'number') + return 'i-ri-hashtag' + + if (valueType === 'boolean') + return 'i-ri-checkbox-circle-line' + + return 'i-ri-bar-chart-box-line' +} + +const ConditionMetricLabel = ({ + metric, + placeholder, +}: ConditionMetricLabelProps) => { + if (!metric) + return {placeholder} + + return ( +
+
+ + {metric.itemLabel} +
+ {metric.groupLabel} +
+ ) +} + +const ConditionMetricSelect = ({ + metric, + metricOptions, + placeholder, + onChange, +}: ConditionMetricSelectProps) => { + const { t } = useTranslation('evaluation') + const groupedMetricOptions = useMemo(() => { + return groupConditionMetricOptions(metricOptions) + }, [metricOptions]) + + return ( + + ) +} + +const ConditionOperatorSelect = ({ + operator, + operators, + onChange, +}: ConditionOperatorSelectProps) => { + const { t } = useTranslation() + + return ( + + ) +} + +const ConditionValueInput = ({ + metric, + condition, + onChange, +}: ConditionValueInputProps) => { + const { t } = useTranslation('evaluation') + + if (!metric || !requiresConditionValue(condition.comparisonOperator)) + return null + + if (metric.valueType === 'boolean') { + return ( +
+ +
+ ) + } + + const isMultiValue = condition.comparisonOperator === 'in' || condition.comparisonOperator === 'not in' + const inputValue = Array.isArray(condition.value) + ? condition.value.join(', ') + : typeof condition.value === 'boolean' + ? '' + : condition.value ?? '' + + return ( +
+ { + if (isMultiValue) { + onChange(e.target.value.split(',').map(item => item.trim()).filter(Boolean)) + return + } + + onChange(e.target.value === '' ? null : e.target.value) + }} + /> +
+ ) +} + +const ConditionGroup = ({ + resourceType, + resourceId, +}: ConditionGroupProps) => { + const { t } = useTranslation('evaluation') + const resource = useEvaluationResource(resourceType, resourceId) + const metricOptions = useMemo(() => buildConditionMetricOptions(resource.metrics), [resource.metrics]) + const logicalLabels = { + and: t('conditions.logical.and'), + or: t('conditions.logical.or'), + } + const setConditionLogicalOperator = useEvaluationStore(state => state.setConditionLogicalOperator) + const removeCondition = useEvaluationStore(state => state.removeCondition) + const updateConditionMetric = useEvaluationStore(state => state.updateConditionMetric) + const updateConditionOperator = useEvaluationStore(state => state.updateConditionOperator) + const updateConditionValue = useEvaluationStore(state => state.updateConditionValue) + + return ( +
+
+
+
+ {(['and', 'or'] as const).map(operator => ( + + ))} +
+
+
+ +
+ {resource.judgmentConfig.conditions.map((condition) => { + const metric = metricOptions.find(option => isSelectorEqual(option.variableSelector, condition.variableSelector)) + const allowedOperators = getAllowedOperators(resource.metrics, condition.variableSelector) + const showValue = !!metric && requiresConditionValue(condition.comparisonOperator) + + return ( +
+
+
+
+ updateConditionMetric(resourceType, resourceId, condition.id, value)} + /> +
+
+ updateConditionOperator(resourceType, resourceId, condition.id, value)} + /> +
+ {showValue && ( +
+ updateConditionValue(resourceType, resourceId, condition.id, value)} + /> +
+ )} +
+
+ +
+
+ ) + })} +
+
+ ) +} + +export default ConditionGroup diff --git a/web/app/components/evaluation/components/conditions-section/index.tsx b/web/app/components/evaluation/components/conditions-section/index.tsx new file mode 100644 index 0000000000..fb28a56a38 --- /dev/null +++ b/web/app/components/evaluation/components/conditions-section/index.tsx @@ -0,0 +1,51 @@ +'use client' + +import type { EvaluationResourceProps } from '../../types' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { useEvaluationResource } from '../../store' +import { buildConditionMetricOptions, groupConditionMetricOptions } from '../../utils' +import { InlineSectionHeader } from '../section-header' +import AddConditionSelect from './add-condition-select' +import ConditionGroup from './condition-group' + +const ConditionsSection = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const resource = useEvaluationResource(resourceType, resourceId) + const conditionMetricOptions = useMemo(() => buildConditionMetricOptions(resource.metrics), [resource.metrics]) + const groupedConditionMetricOptions = useMemo(() => groupConditionMetricOptions(conditionMetricOptions), [conditionMetricOptions]) + const canAddCondition = conditionMetricOptions.length > 0 + + return ( +
+ +
+ {resource.judgmentConfig.conditions.length === 0 && ( +
+ {t('conditions.emptyDescription')} +
+ )} + {resource.judgmentConfig.conditions.length > 0 && ( + + )} + +
+
+ ) +} + +export default ConditionsSection diff --git a/web/app/components/evaluation/components/custom-metric-editor/__tests__/index.spec.tsx b/web/app/components/evaluation/components/custom-metric-editor/__tests__/index.spec.tsx new file mode 100644 index 0000000000..5f2c026908 --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/__tests__/index.spec.tsx @@ -0,0 +1,364 @@ +import type { EvaluationMetric } from '../../../types' +import type { CodeNodeType } from '@/app/components/workflow/nodes/code/types' +import type { EndNodeType } from '@/app/components/workflow/nodes/end/types' +import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' +import type { Node } from '@/app/components/workflow/types' +import type { SnippetWorkflow } from '@/types/snippet' +import type { FetchWorkflowDraftResponse } from '@/types/workflow' +import { render, screen } from '@testing-library/react' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { BlockEnum, InputVarType, VarType } from '@/app/components/workflow/types' +import CustomMetricEditorCard from '..' +import { useEvaluationStore } from '../../../store' + +const mockUseAppWorkflow = vi.hoisted(() => vi.fn()) +const mockUseSnippetPublishedWorkflow = vi.hoisted(() => vi.fn()) +const mockUseAvailableEvaluationWorkflows = vi.hoisted(() => vi.fn()) +const mockUseInfiniteScroll = vi.hoisted(() => vi.fn()) +const mockPublishedGraphVariablePicker = vi.hoisted(() => vi.fn()) + +vi.mock('@/service/use-workflow', () => ({ + useAppWorkflow: (...args: unknown[]) => mockUseAppWorkflow(...args), +})) + +vi.mock('@/service/use-snippet-workflows', () => ({ + useSnippetPublishedWorkflow: (...args: unknown[]) => mockUseSnippetPublishedWorkflow(...args), +})) + +vi.mock('@/service/use-evaluation', () => ({ + useAvailableEvaluationWorkflows: (...args: unknown[]) => mockUseAvailableEvaluationWorkflows(...args), +})) + +vi.mock('ahooks', () => ({ + useInfiniteScroll: (...args: unknown[]) => mockUseInfiniteScroll(...args), +})) + +vi.mock('../published-graph-variable-picker', () => ({ + default: (props: Record) => { + mockPublishedGraphVariablePicker(props) + return
+ }, +})) + +const createStartNode = (): Node => ({ + id: 'start-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + variables: [ + { + variable: 'user_question', + label: 'User Question', + type: InputVarType.textInput, + required: true, + }, + { + variable: 'retrieved_context', + label: 'Retrieved Context', + type: InputVarType.textInput, + required: true, + }, + ], + }, +}) + +const createEndNode = ( + outputs: EndNodeType['outputs'], +): Node => ({ + id: 'end-node', + type: 'custom', + position: { x: 100, y: 0 }, + data: { + type: BlockEnum.End, + title: 'End', + desc: '', + outputs, + }, +}) + +const createCodeNode = ( + id: string, + title: string, + outputs: Record, +): Node => ({ + id, + type: 'custom', + position: { x: 100, y: 0 }, + data: { + type: BlockEnum.Code, + title, + desc: '', + code: '', + code_language: CodeLanguage.python3, + outputs: Object.fromEntries( + Object.entries(outputs).map(([key, value]) => [ + key, + { + type: value.type, + children: null, + }, + ]), + ), + variables: [], + }, +}) + +const createWorkflow = ( + nodes: Node[], +): FetchWorkflowDraftResponse => ({ + id: 'workflow-1', + graph: { + nodes, + edges: [], + }, + features: {}, + created_at: 1710000000, + created_by: { + id: 'user-1', + name: 'User One', + email: 'user-one@example.com', + }, + hash: 'hash-1', + updated_at: 1710000001, + updated_by: { + id: 'user-2', + name: 'User Two', + email: 'user-two@example.com', + }, + tool_published: true, + environment_variables: [], + conversation_variables: [], + version: '1', + marked_name: 'Evaluation Workflow', + marked_comment: 'Published', +}) + +const createSnippetWorkflow = ( + nodes: Node[], +): SnippetWorkflow => ({ + id: 'snippet-workflow-1', + graph: { + nodes, + edges: [], + }, + features: {}, + hash: 'snippet-hash-1', + created_at: 1710000000, + updated_at: 1710000001, +}) + +const createMetric = (): EvaluationMetric => ({ + id: 'metric-1', + optionId: 'custom-1', + kind: 'custom-workflow', + label: 'Custom Evaluator', + description: 'Map workflow variables to your evaluation inputs.', + valueType: 'number', + customConfig: { + workflowId: 'workflow-1', + workflowAppId: 'workflow-app-1', + workflowName: 'Evaluation Workflow', + mappings: [{ + id: 'mapping-1', + inputVariableId: 'user_question', + outputVariableId: 'current-node.answer', + }, { + id: 'mapping-2', + inputVariableId: 'retrieved_context', + outputVariableId: 'current-node.score', + }], + outputs: [], + }, +}) + +describe('CustomMetricEditorCard', () => { + beforeEach(() => { + vi.clearAllMocks() + useEvaluationStore.setState({ resources: {} }) + mockPublishedGraphVariablePicker.mockReset() + + mockUseInfiniteScroll.mockImplementation(() => undefined) + mockUseAvailableEvaluationWorkflows.mockReturnValue({ + data: { + pages: [{ + items: [], + page: 1, + limit: 20, + has_more: false, + }], + }, + fetchNextPage: vi.fn(), + hasNextPage: false, + isFetching: false, + isFetchingNextPage: false, + isLoading: false, + }) + mockUseSnippetPublishedWorkflow.mockReturnValue({ data: undefined }) + }) + + // Verify the selected evaluation workflow still drives the output summary section. + describe('Outputs', () => { + it('should render the selected workflow outputs from the end node', () => { + const selectedWorkflow = createWorkflow([ + createStartNode(), + createEndNode([ + { variable: 'answer_score', value_selector: ['end', 'answer_score'], value_type: VarType.number }, + { variable: 'reason', value_selector: ['end', 'reason'], value_type: VarType.string }, + ]), + ]) + const currentAppWorkflow = createWorkflow([ + createCodeNode('current-node', 'Current Node', { + answer: { type: VarType.string }, + score: { type: VarType.number }, + }), + ]) + + mockUseAppWorkflow.mockImplementation((appId: string) => { + if (appId === 'workflow-app-1') + return { data: selectedWorkflow } + if (appId === 'app-under-test') + return { data: currentAppWorkflow } + + return { data: undefined } + }) + + render( + , + ) + + expect(screen.getByText('evaluation.metrics.custom.outputTitle')).toBeInTheDocument() + expect(screen.getAllByText('answer_score').length).toBeGreaterThan(0) + expect(screen.getAllByText('number').length).toBeGreaterThan(0) + expect(screen.getAllByText('reason').length).toBeGreaterThan(0) + expect(screen.getAllByText('string').length).toBeGreaterThan(0) + }) + + it('should hide the output section when the selected workflow has no end outputs', () => { + const selectedWorkflow = createWorkflow([ + createStartNode(), + createEndNode([]), + ]) + const currentAppWorkflow = createWorkflow([ + createCodeNode('current-node', 'Current Node', { + answer: { type: VarType.string }, + }), + ]) + + mockUseAppWorkflow.mockImplementation((appId: string) => { + if (appId === 'workflow-app-1') + return { data: selectedWorkflow } + if (appId === 'app-under-test') + return { data: currentAppWorkflow } + + return { data: undefined } + }) + + render( + , + ) + + expect(screen.queryByText('evaluation.metrics.custom.outputTitle')).not.toBeInTheDocument() + }) + }) + + // Verify mapping rows use workflow start variables on the left and current published graph variables on the right. + describe('Variable Mapping', () => { + it('should pass the current app published graph and saved selector values to the picker', () => { + const selectedWorkflow = createWorkflow([ + createStartNode(), + createEndNode([ + { variable: 'answer_score', value_selector: ['end', 'answer_score'], value_type: VarType.number }, + { variable: 'reason', value_selector: ['end', 'reason'], value_type: VarType.string }, + ]), + ]) + const currentAppWorkflow = createWorkflow([ + createStartNode(), + createCodeNode('current-node', 'Current Node', { + answer: { type: VarType.string }, + score: { type: VarType.number }, + }), + ]) + + mockUseAppWorkflow.mockImplementation((appId: string) => { + if (appId === 'workflow-app-1') + return { data: selectedWorkflow } + if (appId === 'app-under-test') + return { data: currentAppWorkflow } + + return { data: undefined } + }) + + render( + , + ) + + expect(screen.getByText('user_question')).toBeInTheDocument() + expect(screen.getByText('retrieved_context')).toBeInTheDocument() + expect(screen.getAllByText('string')).toHaveLength(3) + expect(mockPublishedGraphVariablePicker).toHaveBeenCalledTimes(2) + expect(mockPublishedGraphVariablePicker.mock.calls[0][0]).toMatchObject({ + nodes: currentAppWorkflow.graph.nodes, + edges: currentAppWorkflow.graph.edges, + value: 'current-node.answer', + }) + expect(mockPublishedGraphVariablePicker.mock.calls[1][0]).toMatchObject({ + nodes: currentAppWorkflow.graph.nodes, + edges: currentAppWorkflow.graph.edges, + value: 'current-node.score', + }) + }) + + it('should use the current snippet published graph when editing a snippet evaluation', () => { + const selectedWorkflow = createWorkflow([ + createStartNode(), + createEndNode([ + { variable: 'reason', value_selector: ['end', 'reason'], value_type: VarType.string }, + ]), + ]) + const currentSnippetWorkflow = createSnippetWorkflow([ + createCodeNode('snippet-node', 'Snippet Node', { + result: { type: VarType.string }, + }), + ]) + + mockUseAppWorkflow.mockImplementation((appId: string) => { + if (appId === 'workflow-app-1') + return { data: selectedWorkflow } + + return { data: undefined } + }) + mockUseSnippetPublishedWorkflow.mockReturnValue({ + data: currentSnippetWorkflow, + }) + + render( + , + ) + + expect(mockPublishedGraphVariablePicker).toHaveBeenCalledTimes(2) + expect(mockPublishedGraphVariablePicker.mock.calls[0][0]).toMatchObject({ + nodes: currentSnippetWorkflow.graph.nodes, + edges: currentSnippetWorkflow.graph.edges, + }) + }) + }) +}) diff --git a/web/app/components/evaluation/components/custom-metric-editor/__tests__/workflow-selector.spec.tsx b/web/app/components/evaluation/components/custom-metric-editor/__tests__/workflow-selector.spec.tsx new file mode 100644 index 0000000000..3cf8f30b99 --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/__tests__/workflow-selector.spec.tsx @@ -0,0 +1,158 @@ +import type { ComponentProps } from 'react' +import type { AvailableEvaluationWorkflow } from '@/types/evaluation' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import WorkflowSelector from '../workflow-selector' + +const mockUseAvailableEvaluationWorkflows = vi.hoisted(() => vi.fn()) +const mockUseInfiniteScroll = vi.hoisted(() => vi.fn()) + +let loadMoreHandler: (() => Promise<{ list: unknown[] }>) | null = null + +vi.mock('@/service/use-evaluation', () => ({ + useAvailableEvaluationWorkflows: (...args: unknown[]) => mockUseAvailableEvaluationWorkflows(...args), +})) + +vi.mock('ahooks', () => ({ + useInfiniteScroll: (...args: unknown[]) => mockUseInfiniteScroll(...args), +})) + +const createWorkflow = ( + overrides: Partial = {}, +): AvailableEvaluationWorkflow => ({ + id: 'workflow-1', + app_id: 'app-1', + app_name: 'Review Workflow App', + type: 'evaluation', + version: '1', + marked_name: 'Review Workflow', + marked_comment: 'Production release', + hash: 'hash-1', + created_by: { + id: 'user-1', + name: 'User One', + email: 'user-one@example.com', + }, + created_at: 1710000000, + updated_by: null, + updated_at: 1710000000, + ...overrides, +}) + +const setupWorkflowQueryMock = (overrides?: { + workflows?: AvailableEvaluationWorkflow[] + hasNextPage?: boolean + isFetchingNextPage?: boolean +}) => { + const fetchNextPage = vi.fn() + + mockUseAvailableEvaluationWorkflows.mockReturnValue({ + data: { + pages: [{ + items: overrides?.workflows ?? [createWorkflow()], + page: 1, + limit: 20, + has_more: overrides?.hasNextPage ?? false, + }], + }, + fetchNextPage, + hasNextPage: overrides?.hasNextPage ?? false, + isFetching: false, + isFetchingNextPage: overrides?.isFetchingNextPage ?? false, + isLoading: false, + }) + + return { fetchNextPage } +} + +const renderWorkflowSelector = (props?: Partial>) => { + return render( + , + ) +} + +describe('WorkflowSelector', () => { + beforeEach(() => { + vi.clearAllMocks() + loadMoreHandler = null + + setupWorkflowQueryMock() + mockUseInfiniteScroll.mockImplementation((handler) => { + loadMoreHandler = handler as () => Promise<{ list: unknown[] }> + }) + }) + + // Cover trigger rendering and selected label fallback. + describe('Rendering', () => { + it('should render the workflow placeholder when value is empty', () => { + renderWorkflowSelector() + + expect(screen.getByRole('button', { name: 'evaluation.metrics.custom.workflowLabel' })).toBeInTheDocument() + expect(screen.getByText('evaluation.metrics.custom.workflowPlaceholder')).toBeInTheDocument() + }) + + it('should render the selected workflow name from props when value is set', () => { + setupWorkflowQueryMock({ workflows: [] }) + + renderWorkflowSelector({ + value: 'workflow-1', + selectedWorkflowName: 'Saved Review Workflow', + }) + + expect(screen.getByText('Saved Review Workflow')).toBeInTheDocument() + }) + }) + + // Cover opening the popover and choosing one workflow option. + describe('Interactions', () => { + it('should call onSelect with the clicked workflow', async () => { + const onSelect = vi.fn() + + renderWorkflowSelector({ onSelect }) + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.custom.workflowLabel' })) + + const option = await screen.findByRole('option', { name: 'Review Workflow' }) + fireEvent.click(option) + + expect(onSelect).toHaveBeenCalledWith(createWorkflow()) + }) + }) + + // Cover the infinite-scroll callback used by the ScrollArea viewport. + describe('Pagination', () => { + it('should fetch the next page when the load-more callback runs and more pages exist', async () => { + const { fetchNextPage } = setupWorkflowQueryMock({ hasNextPage: true }) + + renderWorkflowSelector() + + await waitFor(() => expect(loadMoreHandler).not.toBeNull()) + + await act(async () => { + await loadMoreHandler?.() + }) + + expect(fetchNextPage).toHaveBeenCalledTimes(1) + }) + + it('should not fetch the next page when the current request is already fetching', async () => { + const { fetchNextPage } = setupWorkflowQueryMock({ + hasNextPage: true, + isFetchingNextPage: true, + }) + + renderWorkflowSelector() + + await waitFor(() => expect(loadMoreHandler).not.toBeNull()) + + await act(async () => { + await loadMoreHandler?.() + }) + + expect(fetchNextPage).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/evaluation/components/custom-metric-editor/index.tsx b/web/app/components/evaluation/components/custom-metric-editor/index.tsx new file mode 100644 index 0000000000..694e48d310 --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/index.tsx @@ -0,0 +1,217 @@ +'use client' + +import type { EvaluationMetric, EvaluationResourceProps } from '../../types' +import type { EndNodeType } from '@/app/components/workflow/nodes/end/types' +import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' +import type { Edge, InputVar, Node } from '@/app/components/workflow/types' +import { useEffect, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import { useSnippetPublishedWorkflow } from '@/service/use-snippet-workflows' +import { useAppWorkflow } from '@/service/use-workflow' +import { isCustomMetricConfigured, useEvaluationStore } from '../../store' +import MappingRow from './mapping-row' +import WorkflowSelector from './workflow-selector' + +type CustomMetricEditorCardProps = EvaluationResourceProps & { + metric: EvaluationMetric +} + +const getWorkflowInputVariables = ( + nodes?: Array, +) => { + const startNode = nodes?.find(node => node.data.type === BlockEnum.Start) as Node | undefined + if (!startNode || !Array.isArray(startNode.data.variables)) + return [] + + return startNode.data.variables.map((variable: InputVar) => ({ + id: variable.variable, + valueType: inputVarTypeToVarType(variable.type ?? InputVarType.textInput), + })) +} + +const getWorkflowOutputs = (nodes?: Array) => { + return (nodes ?? []) + .filter(node => node.data.type === BlockEnum.End) + .flatMap((node) => { + const endNode = node as Node + if (!Array.isArray(endNode.data.outputs)) + return [] + + return endNode.data.outputs + .filter(output => typeof output.variable === 'string' && !!output.variable) + .map(output => ({ + id: output.variable, + valueType: typeof output.value_type === 'string' ? output.value_type : null, + nodeTitle: typeof endNode.data.title === 'string' && endNode.data.title ? endNode.data.title : 'End', + })) + }) +} + +const getWorkflowName = (workflow: { + marked_name?: string + app_name?: string + id: string +}) => { + return workflow.marked_name || workflow.app_name || workflow.id +} + +const getGraphNodes = (graph?: Record) => { + return Array.isArray(graph?.nodes) ? graph.nodes as Node[] : [] +} + +const getGraphEdges = (graph?: Record) => { + return Array.isArray(graph?.edges) ? graph.edges as Edge[] : [] +} + +const CustomMetricEditorCard = ({ + resourceType, + resourceId, + metric, +}: CustomMetricEditorCardProps) => { + const { t } = useTranslation('evaluation') + const setCustomMetricWorkflow = useEvaluationStore(state => state.setCustomMetricWorkflow) + const syncCustomMetricMappings = useEvaluationStore(state => state.syncCustomMetricMappings) + const syncCustomMetricOutputs = useEvaluationStore(state => state.syncCustomMetricOutputs) + const updateCustomMetricMapping = useEvaluationStore(state => state.updateCustomMetricMapping) + const { data: selectedWorkflow } = useAppWorkflow(metric.customConfig?.workflowAppId ?? '') + const { data: currentAppWorkflow } = useAppWorkflow(resourceType === 'apps' ? resourceId : '') + const { data: currentSnippetWorkflow } = useSnippetPublishedWorkflow(resourceType === 'snippets' ? resourceId : '') + const inputVariables = useMemo(() => { + return getWorkflowInputVariables(selectedWorkflow?.graph.nodes) + }, [selectedWorkflow?.graph.nodes]) + const workflowOutputs = useMemo(() => { + return getWorkflowOutputs(selectedWorkflow?.graph.nodes) + }, [selectedWorkflow?.graph.nodes]) + const publishedGraph = useMemo(() => { + if (resourceType === 'apps') { + return { + nodes: currentAppWorkflow?.graph.nodes ?? [], + edges: currentAppWorkflow?.graph.edges ?? [], + environmentVariables: currentAppWorkflow?.environment_variables ?? [], + conversationVariables: currentAppWorkflow?.conversation_variables ?? [], + } + } + + return { + nodes: getGraphNodes(currentSnippetWorkflow?.graph), + edges: getGraphEdges(currentSnippetWorkflow?.graph), + environmentVariables: [], + conversationVariables: [], + } + }, [ + currentAppWorkflow?.conversation_variables, + currentAppWorkflow?.environment_variables, + currentAppWorkflow?.graph.edges, + currentAppWorkflow?.graph.nodes, + currentSnippetWorkflow?.graph, + resourceType, + ]) + const inputVariableIds = useMemo(() => inputVariables.map(variable => variable.id), [inputVariables]) + const isConfigured = isCustomMetricConfigured(metric) + + useEffect(() => { + if (!metric.customConfig?.workflowId) + return + + const currentInputVariableIds = metric.customConfig.mappings + .map(mapping => mapping.inputVariableId) + .filter((value): value is string => !!value) + + if (currentInputVariableIds.length === inputVariableIds.length + && currentInputVariableIds.every((value, index) => value === inputVariableIds[index])) { + return + } + + syncCustomMetricMappings(resourceType, resourceId, metric.id, inputVariableIds) + }, [inputVariableIds, metric.customConfig?.mappings, metric.customConfig?.workflowId, metric.id, resourceId, resourceType, syncCustomMetricMappings]) + + useEffect(() => { + if (!metric.customConfig?.workflowId) + return + + const currentOutputs = metric.customConfig.outputs + if ( + currentOutputs.length === workflowOutputs.length + && currentOutputs.every((output, index) => + output.id === workflowOutputs[index]?.id && output.valueType === workflowOutputs[index]?.valueType, + ) + ) { + return + } + + syncCustomMetricOutputs(resourceType, resourceId, metric.id, workflowOutputs) + }, [metric.customConfig?.outputs, metric.customConfig?.workflowId, metric.id, resourceId, resourceType, syncCustomMetricOutputs, workflowOutputs]) + + if (!metric.customConfig) + return null + + return ( +
+ setCustomMetricWorkflow(resourceType, resourceId, metric.id, { + workflowId: workflow.id, + workflowAppId: workflow.app_id, + workflowName: getWorkflowName(workflow), + })} + /> + +
+
+
{t('metrics.custom.mappingTitle')}
+
+
+ {inputVariables.map((inputVariable) => { + const mapping = metric.customConfig?.mappings.find(item => item.inputVariableId === inputVariable.id) + + return ( + { + if (!mapping) + return + + updateCustomMetricMapping(resourceType, resourceId, metric.id, mapping.id, { outputVariableId }) + }} + /> + ) + })} +
+ {!isConfigured && ( +
+ {t('metrics.custom.mappingWarning')} +
+ )} +
+ + {!!workflowOutputs.length && ( +
+
+ {t('metrics.custom.outputTitle')} +
+
+ {workflowOutputs.map((output, index) => ( +
+ {output.id} + {output.valueType && ( + {output.valueType} + )} + {index < workflowOutputs.length - 1 && ( + , + )} +
+ ))} +
+
+ )} +
+ ) +} + +export default CustomMetricEditorCard diff --git a/web/app/components/evaluation/components/custom-metric-editor/mapping-row.tsx b/web/app/components/evaluation/components/custom-metric-editor/mapping-row.tsx new file mode 100644 index 0000000000..1340682674 --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/mapping-row.tsx @@ -0,0 +1,64 @@ +'use client' + +import type { + ConversationVariable, + Edge, + EnvironmentVariable, + Node, +} from '@/app/components/workflow/types' +import { useTranslation } from 'react-i18next' +import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' +import PublishedGraphVariablePicker from './published-graph-variable-picker' + +type MappingRowProps = { + inputVariable: { + id: string + valueType: string + } + publishedGraph: { + nodes: Node[] + edges: Edge[] + environmentVariables: EnvironmentVariable[] + conversationVariables: ConversationVariable[] + } + value: string | null + onUpdate: (outputVariableId: string | null) => void +} + +const MappingRow = ({ + inputVariable, + publishedGraph, + value, + onUpdate, +}: MappingRowProps) => { + const { t } = useTranslation('evaluation') + + return ( +
+
+
+ +
{inputVariable.id}
+
{inputVariable.valueType}
+
+
+ +
+ +
+ + +
+ ) +} + +export default MappingRow diff --git a/web/app/components/evaluation/components/custom-metric-editor/published-graph-variable-picker.tsx b/web/app/components/evaluation/components/custom-metric-editor/published-graph-variable-picker.tsx new file mode 100644 index 0000000000..4aeb6b615d --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/published-graph-variable-picker.tsx @@ -0,0 +1,118 @@ +'use client' + +import type { EndNodeType } from '@/app/components/workflow/nodes/end/types' +import type { + ConversationVariable, + Edge, + EnvironmentVariable, + Node, + ValueSelector, +} from '@/app/components/workflow/types' +import { useMemo } from 'react' +import ReactFlow, { ReactFlowProvider } from 'reactflow' +import { WorkflowContext } from '@/app/components/workflow/context' +import { createHooksStore, HooksStoreContext } from '@/app/components/workflow/hooks-store' +import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' +import { createWorkflowStore } from '@/app/components/workflow/store/workflow' +import { BlockEnum } from '@/app/components/workflow/types' +import { variableTransformer } from '@/app/components/workflow/utils/variable' + +type PublishedGraphVariablePickerProps = { + className?: string + nodes: Node[] + edges: Edge[] + environmentVariables?: EnvironmentVariable[] + conversationVariables?: ConversationVariable[] + placeholder: string + value: string | null + onChange: (value: string | null) => void +} + +const PICKER_NODE_ID = '__evaluation-variable-picker__' + +const createPickerNode = (): Node => ({ + id: PICKER_NODE_ID, + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.End, + title: 'End', + desc: '', + outputs: [], + }, +}) + +const PublishedGraphVariablePicker = ({ + className, + nodes, + edges, + environmentVariables = [], + conversationVariables = [], + placeholder, + value, + onChange, +}: PublishedGraphVariablePickerProps) => { + const workflowStore = useMemo(() => { + const store = createWorkflowStore({}) + store.setState({ + isWorkflowDataLoaded: true, + environmentVariables, + conversationVariables, + ragPipelineVariables: [], + dataSourceList: [], + }) + return store + }, [conversationVariables, environmentVariables]) + + const hooksStore = useMemo(() => createHooksStore({}), []) + + const pickerNodes = useMemo(() => { + return [...nodes, createPickerNode()] + }, [nodes]) + + const pickerValue = useMemo(() => { + if (!value) + return [] + + return variableTransformer(value) as ValueSelector + }, [value]) + + return ( + + +
+ + + + { + if (!Array.isArray(nextValue) || !nextValue.length) { + onChange(null) + return + } + + onChange(nextValue.join('.')) + }} + availableNodes={nodes} + placeholder={placeholder} + /> + +
+
+
+ ) +} + +export default PublishedGraphVariablePicker diff --git a/web/app/components/evaluation/components/custom-metric-editor/workflow-selector.tsx b/web/app/components/evaluation/components/custom-metric-editor/workflow-selector.tsx new file mode 100644 index 0000000000..c85c192aff --- /dev/null +++ b/web/app/components/evaluation/components/custom-metric-editor/workflow-selector.tsx @@ -0,0 +1,214 @@ +'use client' + +import type { AvailableEvaluationWorkflow } from '@/types/evaluation' +import { cn } from '@langgenius/dify-ui/cn' +import { useInfiniteScroll } from 'ahooks' +import * as React from 'react' +import { useDeferredValue, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Input from '@/app/components/base/input' +import Loading from '@/app/components/base/loading' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' +import { + ScrollAreaContent, + ScrollAreaRoot, + ScrollAreaScrollbar, + ScrollAreaThumb, + ScrollAreaViewport, +} from '@/app/components/base/ui/scroll-area' +import { useAvailableEvaluationWorkflows } from '@/service/use-evaluation' + +type WorkflowSelectorProps = { + value: string | null + selectedWorkflowName?: string | null + onSelect: (workflow: AvailableEvaluationWorkflow) => void +} + +const PAGE_SIZE = 20 + +const getWorkflowName = (workflow: AvailableEvaluationWorkflow) => { + return workflow.marked_name || workflow.app_name || workflow.id +} + +const WorkflowSelector = ({ + value, + selectedWorkflowName, + onSelect, +}: WorkflowSelectorProps) => { + const { t } = useTranslation('evaluation') + const [isOpen, setIsOpen] = useState(false) + const [searchText, setSearchText] = useState('') + const deferredSearchText = useDeferredValue(searchText) + const viewportRef = useRef(null) + + const keyword = deferredSearchText.trim() || undefined + + const { + data, + fetchNextPage, + hasNextPage, + isFetching, + isFetchingNextPage, + isLoading, + } = useAvailableEvaluationWorkflows( + { + page: 1, + limit: PAGE_SIZE, + keyword, + }, + { enabled: isOpen }, + ) + + const workflows = useMemo(() => { + return (data?.pages ?? []).flatMap(page => page.items) + }, [data?.pages]) + + const currentWorkflowName = useMemo(() => { + if (!value) + return null + + const selectedWorkflow = workflows.find(workflow => workflow.id === value) + if (selectedWorkflow) + return getWorkflowName(selectedWorkflow) + + return selectedWorkflowName ?? null + }, [selectedWorkflowName, value, workflows]) + + const isNoMore = hasNextPage === false + + useInfiniteScroll( + async () => { + if (!hasNextPage || isFetchingNextPage) + return { list: [] } + + await fetchNextPage() + return { list: [] } + }, + { + target: viewportRef, + isNoMore: () => isNoMore, + reloadDeps: [isFetchingNextPage, isNoMore, keyword], + }, + ) + + const handleOpenChange = (nextOpen: boolean) => { + setIsOpen(nextOpen) + + if (!nextOpen) + setSearchText('') + } + + return ( + + +
+
+
+
+
+
+
+ {currentWorkflowName ?? t('metrics.custom.workflowPlaceholder')} +
+
+
+ + + + )} + /> + + +
+
+ setSearchText(event.target.value)} + onClear={() => setSearchText('')} + /> +
+ + {(isLoading || (isFetching && workflows.length === 0)) + ? ( +
+ +
+ ) + : !workflows.length + ? ( +
+ {t('noData', { ns: 'common' })} +
+ ) + : ( + + + + {workflows.map(workflow => ( + + ))} + + {isFetchingNextPage && ( +
+ +
+ )} +
+
+ + + +
+ )} +
+
+
+ ) +} + +export default React.memo(WorkflowSelector) diff --git a/web/app/components/evaluation/components/judge-model-selector.tsx b/web/app/components/evaluation/components/judge-model-selector.tsx new file mode 100644 index 0000000000..8f9ee4aff6 --- /dev/null +++ b/web/app/components/evaluation/components/judge-model-selector.tsx @@ -0,0 +1,48 @@ +'use client' + +import type { EvaluationResourceProps } from '../types' +import { useEffect } from 'react' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { useEvaluationResource, useEvaluationStore } from '../store' +import { decodeModelSelection, encodeModelSelection } from '../utils' + +type JudgeModelSelectorProps = EvaluationResourceProps & { + autoSelectFirst?: boolean +} + +const JudgeModelSelector = ({ + resourceType, + resourceId, + autoSelectFirst = true, +}: JudgeModelSelectorProps) => { + const { data: modelList } = useModelList(ModelTypeEnum.textGeneration) + const resource = useEvaluationResource(resourceType, resourceId) + const setJudgeModel = useEvaluationStore(state => state.setJudgeModel) + const selectedModel = decodeModelSelection(resource.judgeModelId) + + useEffect(() => { + if (!autoSelectFirst || resource.judgeModelId || !modelList.length) + return + + const firstProvider = modelList[0] + const firstModel = firstProvider.models[0] + if (!firstProvider || !firstModel) + return + + setJudgeModel(resourceType, resourceId, encodeModelSelection(firstProvider.provider, firstModel.model)) + }, [autoSelectFirst, modelList, resource.judgeModelId, resourceId, resourceType, setJudgeModel]) + + return ( + setJudgeModel(resourceType, resourceId, encodeModelSelection(model.provider, model.model))} + showDeprecatedWarnIcon + triggerClassName="h-8 w-full rounded-lg" + /> + ) +} + +export default JudgeModelSelector diff --git a/web/app/components/evaluation/components/layout/non-pipeline-evaluation.tsx b/web/app/components/evaluation/components/layout/non-pipeline-evaluation.tsx new file mode 100644 index 0000000000..5d47a754ff --- /dev/null +++ b/web/app/components/evaluation/components/layout/non-pipeline-evaluation.tsx @@ -0,0 +1,62 @@ +'use client' + +import type { EvaluationResourceProps } from '../../types' +import { useTranslation } from 'react-i18next' +import { useDocLink } from '@/context/i18n' +import BatchTestPanel from '../batch-test-panel' +import ConditionsSection from '../conditions-section' +import JudgeModelSelector from '../judge-model-selector' +import MetricSection from '../metric-section' +import SectionHeader, { InlineSectionHeader } from '../section-header' + +const NonPipelineEvaluation = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const { t: tCommon } = useTranslation('common') + const docLink = useDocLink() + + return ( +
+
+
+ + {t('description')} + {' '} + + {tCommon('operation.learnMore')} + + + )} + descriptionClassName="max-w-[700px]" + /> +
+ +
+ +
+
+
+ +
+ +
+
+ +
+ +
+
+ ) +} + +export default NonPipelineEvaluation diff --git a/web/app/components/evaluation/components/layout/pipeline-evaluation.tsx b/web/app/components/evaluation/components/layout/pipeline-evaluation.tsx new file mode 100644 index 0000000000..9fb6c70a90 --- /dev/null +++ b/web/app/components/evaluation/components/layout/pipeline-evaluation.tsx @@ -0,0 +1,96 @@ +'use client' + +import type { EvaluationResourceProps } from '../../types' +import { useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import { useDocLink } from '@/context/i18n' +import { useEvaluationStore } from '../../store' +import HistoryTab from '../batch-test-panel/history-tab' +import JudgeModelSelector from '../judge-model-selector' +import PipelineBatchActions from '../pipeline/pipeline-batch-actions' +import PipelineMetricsSection from '../pipeline/pipeline-metrics-section' +import PipelineResultsPanel from '../pipeline/pipeline-results-panel' +import SectionHeader, { InlineSectionHeader } from '../section-header' + +const PipelineEvaluation = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const { t: tCommon } = useTranslation('common') + const docLink = useDocLink() + const ensureResource = useEvaluationStore(state => state.ensureResource) + + useEffect(() => { + ensureResource(resourceType, resourceId) + }, [ensureResource, resourceId, resourceType]) + + return ( +
+
+
+ + {t('description')} + {' '} + + {tCommon('operation.learnMore')} + + + )} + /> +
+ +
+
+
+ +
+ +
+
+ + + + +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ ) +} + +export default PipelineEvaluation diff --git a/web/app/components/evaluation/components/metric-section/__tests__/index.spec.tsx b/web/app/components/evaluation/components/metric-section/__tests__/index.spec.tsx new file mode 100644 index 0000000000..a234635c50 --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/__tests__/index.spec.tsx @@ -0,0 +1,229 @@ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, fireEvent, render, screen } from '@testing-library/react' +import MetricSection from '..' +import { useEvaluationStore } from '../../../store' + +const mockUseAvailableEvaluationWorkflows = vi.hoisted(() => vi.fn()) +const mockUseAvailableEvaluationMetrics = vi.hoisted(() => vi.fn()) +const mockUseEvaluationNodeInfoMutation = vi.hoisted(() => vi.fn()) + +vi.mock('@/service/use-evaluation', () => ({ + useAvailableEvaluationWorkflows: (...args: unknown[]) => mockUseAvailableEvaluationWorkflows(...args), + useAvailableEvaluationMetrics: (...args: unknown[]) => mockUseAvailableEvaluationMetrics(...args), + useEvaluationNodeInfoMutation: (...args: unknown[]) => mockUseEvaluationNodeInfoMutation(...args), +})) + +const resourceType = 'apps' as const +const resourceId = 'metric-section-resource' + +const renderMetricSection = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + + return render( + + + , + ) +} + +describe('MetricSection', () => { + beforeEach(() => { + vi.clearAllMocks() + useEvaluationStore.setState({ resources: {} }) + + mockUseAvailableEvaluationMetrics.mockReturnValue({ + data: { + metrics: ['answer-correctness'], + }, + isLoading: false, + }) + + mockUseAvailableEvaluationWorkflows.mockReturnValue({ + data: { + pages: [{ items: [], page: 1, limit: 20, has_more: false }], + }, + fetchNextPage: vi.fn(), + hasNextPage: false, + isFetching: false, + isFetchingNextPage: false, + isLoading: false, + }) + + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'answer-correctness': [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ], + }) + }, + }) + }) + + // Verify the empty state block extracted from MetricSection. + describe('Empty State', () => { + it('should render the metric empty state when no metrics are selected', () => { + renderMetricSection() + + expect(screen.getByText('evaluation.metrics.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'evaluation.metrics.add' })).toBeInTheDocument() + }) + }) + + // Verify the extracted builtin metric card presentation and removal flow. + describe('Builtin Metric Card', () => { + it('should render node badges for a builtin metric and remove it when delete is clicked', () => { + // Arrange + act(() => { + useEvaluationStore.getState().addBuiltinMetric(resourceType, resourceId, 'answer-correctness', [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ]) + }) + + // Act + renderMetricSection() + + // Assert + expect(screen.getByText('Answer Correctness')).toBeInTheDocument() + expect(screen.getByText('Answer Node')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.remove' })) + + expect(screen.queryByText('Answer Correctness')).not.toBeInTheDocument() + expect(screen.getByText('evaluation.metrics.description')).toBeInTheDocument() + }) + + it('should render the all-nodes label when a builtin metric has no node selection', () => { + // Arrange + act(() => { + useEvaluationStore.getState().addBuiltinMetric(resourceType, resourceId, 'answer-correctness', []) + }) + + // Act + renderMetricSection() + + // Assert + expect(screen.getByText('evaluation.metrics.nodesAll')).toBeInTheDocument() + }) + + it('should collapse and expand the node section when the metric header is clicked', () => { + // Arrange + act(() => { + useEvaluationStore.getState().addBuiltinMetric(resourceType, resourceId, 'answer-correctness', [ + { node_id: 'node-answer', title: 'Answer Node', type: 'llm' }, + ]) + }) + + // Act + renderMetricSection() + + const toggleButton = screen.getByRole('button', { name: 'evaluation.metrics.collapseNodes' }) + fireEvent.click(toggleButton) + + // Assert + expect(screen.queryByText('Answer Node')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'evaluation.metrics.expandNodes' })).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.expandNodes' })) + + expect(screen.getByText('Answer Node')).toBeInTheDocument() + }) + + it('should show only unselected nodes in the add-node dropdown and append the selected node', () => { + // Arrange + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'answer-correctness': [ + { node_id: 'node-1', title: 'LLM 1', type: 'llm' }, + { node_id: 'node-2', title: 'LLM 2', type: 'llm' }, + ], + }) + }, + }) + + act(() => { + useEvaluationStore.getState().addBuiltinMetric(resourceType, resourceId, 'answer-correctness', [ + { node_id: 'node-1', title: 'LLM 1', type: 'llm' }, + ]) + }) + + // Act + renderMetricSection() + + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.addNode' })) + + // Assert + expect(screen.queryByRole('menuitem', { name: 'LLM 1' })).not.toBeInTheDocument() + fireEvent.click(screen.getByRole('menuitem', { name: 'LLM 2' })) + + expect(screen.getByText('LLM 2')).toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'evaluation.metrics.addNode' })).not.toBeInTheDocument() + }) + + it('should hide the add-node button when the builtin metric already targets all nodes', () => { + // Arrange + mockUseEvaluationNodeInfoMutation.mockReturnValue({ + isPending: false, + mutate: (_input: unknown, options?: { onSuccess?: (data: Record>) => void }) => { + options?.onSuccess?.({ + 'answer-correctness': [ + { node_id: 'node-1', title: 'LLM 1', type: 'llm' }, + { node_id: 'node-2', title: 'LLM 2', type: 'llm' }, + ], + }) + }, + }) + + act(() => { + useEvaluationStore.getState().addBuiltinMetric(resourceType, resourceId, 'answer-correctness', []) + }) + + // Act + renderMetricSection() + + // Assert + expect(screen.getByText('evaluation.metrics.nodesAll')).toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'evaluation.metrics.addNode' })).not.toBeInTheDocument() + }) + }) + + // Verify the extracted custom metric editor card renders inside the metric card. + describe('Custom Metric Card', () => { + it('should render the custom metric editor card when a custom metric is added', () => { + act(() => { + useEvaluationStore.getState().addCustomMetric(resourceType, resourceId) + }) + + renderMetricSection() + + expect(screen.getByText('Custom Evaluator')).toBeInTheDocument() + expect(screen.getByText('evaluation.metrics.custom.warningBadge')).toBeInTheDocument() + expect(screen.getByText('evaluation.metrics.custom.workflowPlaceholder')).toBeInTheDocument() + expect(screen.getByText('evaluation.metrics.custom.mappingTitle')).toBeInTheDocument() + }) + + it('should disable adding another custom metric when one already exists', () => { + // Arrange + act(() => { + useEvaluationStore.getState().addCustomMetric(resourceType, resourceId) + }) + + // Act + renderMetricSection() + fireEvent.click(screen.getByRole('button', { name: 'evaluation.metrics.add' })) + + // Assert + expect(screen.getByRole('button', { name: /evaluation.metrics.custom.footerTitle/i })).toBeDisabled() + expect(screen.getByText('evaluation.metrics.custom.limitDescription')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/evaluation/components/metric-section/builtin-metric-card.tsx b/web/app/components/evaluation/components/metric-section/builtin-metric-card.tsx new file mode 100644 index 0000000000..f8103aa2cc --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/builtin-metric-card.tsx @@ -0,0 +1,161 @@ +'use client' + +import type { EvaluationMetric, EvaluationResourceProps } from '../../types' +import type { NodeInfo } from '@/types/evaluation' +import { cn } from '@langgenius/dify-ui/cn' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { Button } from '@/app/components/base/ui/button' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/app/components/base/ui/dropdown-menu' +import { useEvaluationStore } from '../../store' +import { dedupeNodeInfoList, getMetricVisual, getNodeVisual, getToneClasses } from '../metric-selector/utils' + +type BuiltinMetricCardProps = EvaluationResourceProps & { + metric: EvaluationMetric + availableNodeInfoList?: NodeInfo[] +} + +const BuiltinMetricCard = ({ + resourceType, + resourceId, + metric, + availableNodeInfoList = [], +}: BuiltinMetricCardProps) => { + const { t } = useTranslation('evaluation') + const updateBuiltinMetric = useEvaluationStore(state => state.addBuiltinMetric) + const removeMetric = useEvaluationStore(state => state.removeMetric) + const [isExpanded, setIsExpanded] = useState(true) + const metricVisual = getMetricVisual(metric.optionId) + const metricToneClasses = getToneClasses(metricVisual.tone) + const selectedNodeInfoList = metric.nodeInfoList ?? [] + const selectedNodeIdSet = new Set(selectedNodeInfoList.map(nodeInfo => nodeInfo.node_id)) + const selectableNodeInfoList = selectedNodeInfoList.length > 0 + ? availableNodeInfoList.filter(nodeInfo => !selectedNodeIdSet.has(nodeInfo.node_id)) + : [] + const shouldShowAddNode = selectableNodeInfoList.length > 0 + + return ( +
+
+ + + +
+ + {isExpanded && ( +
+ {selectedNodeInfoList.length + ? selectedNodeInfoList.map((nodeInfo) => { + const nodeVisual = getNodeVisual(nodeInfo) + const nodeToneClasses = getToneClasses(nodeVisual.tone) + + return ( +
+
+
+ {nodeInfo.title} + +
+ ) + }) + : ( + {t('metrics.nodesAll')} + )} + + {shouldShowAddNode && ( + + + )} + > + + + {selectableNodeInfoList.map((nodeInfo) => { + const nodeVisual = getNodeVisual(nodeInfo) + const nodeToneClasses = getToneClasses(nodeVisual.tone) + + return ( + updateBuiltinMetric( + resourceType, + resourceId, + metric.optionId, + dedupeNodeInfoList([...selectedNodeInfoList, nodeInfo]), + )} + > +
+
+
+ {nodeInfo.title} +
+
+ ) + })} +
+
+ )} +
+ )} +
+ ) +} + +export default BuiltinMetricCard diff --git a/web/app/components/evaluation/components/metric-section/custom-metric-card.tsx b/web/app/components/evaluation/components/metric-section/custom-metric-card.tsx new file mode 100644 index 0000000000..d0c04bac04 --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/custom-metric-card.tsx @@ -0,0 +1,63 @@ +'use client' + +import type { EvaluationMetric, EvaluationResourceProps } from '../../types' +import { cn } from '@langgenius/dify-ui/cn' +import { useTranslation } from 'react-i18next' +import Badge from '@/app/components/base/badge' +import { Button } from '@/app/components/base/ui/button' +import { isCustomMetricConfigured, useEvaluationStore } from '../../store' +import CustomMetricEditorCard from '../custom-metric-editor' +import { getToneClasses } from '../metric-selector/utils' + +type CustomMetricCardProps = EvaluationResourceProps & { + metric: EvaluationMetric +} + +const CustomMetricCard = ({ + resourceType, + resourceId, + metric, +}: CustomMetricCardProps) => { + const { t } = useTranslation('evaluation') + const removeMetric = useEvaluationStore(state => state.removeMetric) + const isCustomMetricInvalid = !isCustomMetricConfigured(metric) + const metricToneClasses = getToneClasses('indigo') + + return ( +
+
+
+
+
+
{metric.label}
+
+ +
+ {isCustomMetricInvalid && ( + + {t('metrics.custom.warningBadge')} + + )} + +
+
+ + +
+ ) +} + +export default CustomMetricCard diff --git a/web/app/components/evaluation/components/metric-section/index.tsx b/web/app/components/evaluation/components/metric-section/index.tsx new file mode 100644 index 0000000000..7887b2733b --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/index.tsx @@ -0,0 +1,95 @@ +'use client' + +import type { EvaluationResourceProps } from '../../types' +import type { NodeInfo } from '@/types/evaluation' +import { useEffect, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useAvailableEvaluationMetrics, useEvaluationNodeInfoMutation } from '@/service/use-evaluation' +import { useEvaluationResource } from '../../store' +import MetricSelector from '../metric-selector' +import { toEvaluationTargetType } from '../metric-selector/utils' +import { InlineSectionHeader } from '../section-header' +import MetricCard from './metric-card' +import MetricSectionEmptyState from './metric-section-empty-state' + +const MetricSection = ({ + resourceType, + resourceId, +}: EvaluationResourceProps) => { + const { t } = useTranslation('evaluation') + const resource = useEvaluationResource(resourceType, resourceId) + const [nodeInfoMap, setNodeInfoMap] = useState>({}) + const hasMetrics = resource.metrics.length > 0 + const hasBuiltinMetrics = resource.metrics.some(metric => metric.kind === 'builtin') + const shouldLoadNodeInfo = resourceType !== 'datasets' && !!resourceId && hasBuiltinMetrics + const { data: availableMetricsData } = useAvailableEvaluationMetrics(shouldLoadNodeInfo) + const { mutate: loadNodeInfo } = useEvaluationNodeInfoMutation() + const availableMetricIds = useMemo(() => availableMetricsData?.metrics ?? [], [availableMetricsData?.metrics]) + const availableMetricIdsKey = availableMetricIds.join(',') + const resolvedNodeInfoMap = shouldLoadNodeInfo ? nodeInfoMap : {} + + useEffect(() => { + if (!shouldLoadNodeInfo || availableMetricIds.length === 0) + return + + let isActive = true + + loadNodeInfo( + { + params: { + targetType: toEvaluationTargetType(resourceType), + targetId: resourceId, + }, + body: { + metrics: availableMetricIds, + }, + }, + { + onSuccess: (data) => { + if (!isActive) + return + + setNodeInfoMap(data) + }, + onError: () => { + if (!isActive) + return + + setNodeInfoMap({}) + }, + }, + ) + + return () => { + isActive = false + } + }, [availableMetricIds, availableMetricIdsKey, loadNodeInfo, resourceId, resourceType, shouldLoadNodeInfo]) + + return ( +
+ +
+ {!hasMetrics && } + {resource.metrics.map(metric => ( + + ))} + +
+
+ ) +} + +export default MetricSection diff --git a/web/app/components/evaluation/components/metric-section/metric-card.tsx b/web/app/components/evaluation/components/metric-section/metric-card.tsx new file mode 100644 index 0000000000..5d2088d52e --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/metric-card.tsx @@ -0,0 +1,39 @@ +'use client' + +import type { EvaluationMetric, EvaluationResourceProps } from '../../types' +import type { NodeInfo } from '@/types/evaluation' +import BuiltinMetricCard from './builtin-metric-card' +import CustomMetricCard from './custom-metric-card' + +type MetricCardProps = EvaluationResourceProps & { + metric: EvaluationMetric + availableNodeInfoList?: NodeInfo[] +} + +const MetricCard = ({ + resourceType, + resourceId, + metric, + availableNodeInfoList, +}: MetricCardProps) => { + if (metric.kind === 'custom-workflow') { + return ( + + ) + } + + return ( + + ) +} + +export default MetricCard diff --git a/web/app/components/evaluation/components/metric-section/metric-section-empty-state.tsx b/web/app/components/evaluation/components/metric-section/metric-section-empty-state.tsx new file mode 100644 index 0000000000..918a93e430 --- /dev/null +++ b/web/app/components/evaluation/components/metric-section/metric-section-empty-state.tsx @@ -0,0 +1,18 @@ +type MetricSectionEmptyStateProps = { + description: string +} + +const MetricSectionEmptyState = ({ description }: MetricSectionEmptyStateProps) => { + return ( +
+
+
+
+ {description} +
+
+ ) +} + +export default MetricSectionEmptyState diff --git a/web/app/components/evaluation/components/metric-selector/index.tsx b/web/app/components/evaluation/components/metric-selector/index.tsx new file mode 100644 index 0000000000..0f91271a9f --- /dev/null +++ b/web/app/components/evaluation/components/metric-selector/index.tsx @@ -0,0 +1,151 @@ +'use client' + +import type { ChangeEvent } from 'react' +import type { MetricSelectorProps } from './types' +import { cn } from '@langgenius/dify-ui/cn' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import Input from '@/app/components/base/input' +import { Button } from '@/app/components/base/ui/button' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/app/components/base/ui/popover' +import { useEvaluationResource, useEvaluationStore } from '../../store' +import SelectorEmptyState from './selector-empty-state' +import SelectorFooter from './selector-footer' +import SelectorMetricSection from './selector-metric-section' +import { useMetricSelectorData } from './use-metric-selector-data' + +const MetricSelector = ({ + resourceType, + resourceId, + triggerClassName, + triggerStyle = 'button', +}: MetricSelectorProps) => { + const { t } = useTranslation('evaluation') + const resource = useEvaluationResource(resourceType, resourceId) + const addCustomMetric = useEvaluationStore(state => state.addCustomMetric) + const [open, setOpen] = useState(false) + const [query, setQuery] = useState('') + const [nodeInfoMap, setNodeInfoMap] = useState>>({}) + const [collapsedMetricMap, setCollapsedMetricMap] = useState>({}) + const [expandedMetricNodesMap, setExpandedMetricNodesMap] = useState>({}) + const hasCustomMetric = resource.metrics.some(metric => metric.kind === 'custom-workflow') + + const { + builtinMetricMap, + filteredSections, + isRemoteLoading, + toggleNodeSelection, + } = useMetricSelectorData({ + open, + query, + resourceType, + resourceId, + nodeInfoMap, + setNodeInfoMap, + }) + + const handleOpenChange = (nextOpen: boolean) => { + setOpen(nextOpen) + + if (nextOpen) { + setQuery('') + setCollapsedMetricMap({}) + setExpandedMetricNodesMap({}) + } + } + + const handleQueryChange = (event: ChangeEvent) => { + setQuery(event.target.value) + } + + return ( + + +