From b1b977e284c2ac4536cfd0a1fff9ca16a3df7ef7 Mon Sep 17 00:00:00 2001 From: hj24 Date: Mon, 27 Apr 2026 09:49:40 +0800 Subject: [PATCH] refactor: quota v3 integration (#35436) Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/enums/quota_type.py | 188 ---------- api/services/app_generate_service.py | 6 +- api/services/async_workflow_service.py | 44 ++- api/services/billing_service.py | 98 ++++- api/services/feature_service.py | 2 +- api/services/quota_service.py | 233 ++++++++++++ api/services/trigger/webhook_service.py | 73 ++-- api/services/workflow_service.py | 21 +- api/tasks/trigger_processing_tasks.py | 97 ++--- api/tasks/workflow_schedule_tasks.py | 35 +- .../services/test_app_generate_service.py | 23 +- .../test_webhook_service_relationships.py | 16 +- .../trigger/test_trigger_e2e.py | 4 +- api/tests/unit_tests/enums/__init__.py | 0 api/tests/unit_tests/enums/test_quota_type.py | 349 ++++++++++++++++++ .../services/test_app_generate_service.py | 12 +- .../services/test_async_workflow_service.py | 26 +- .../services/test_billing_service.py | 160 +++++++- .../tasks/test_trigger_processing_tasks.py | 204 ++++++++++ 19 files changed, 1255 insertions(+), 336 deletions(-) create mode 100644 api/services/quota_service.py create mode 100644 api/tests/unit_tests/enums/__init__.py create mode 100644 api/tests/unit_tests/enums/test_quota_type.py create mode 100644 api/tests/unit_tests/tasks/test_trigger_processing_tasks.py diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger = logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. - - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. - - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. - """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 8ff53d143b..d6c01e9dcc 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,6 +117,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) + quota_charge.commit() effective_mode = ( AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode ) diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index a731d5c048..ceda30e950 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -88,7 +89,10 @@ class AsyncWorkflowService: raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") # 2. Get workflow - workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session) + + # commit read only session before starting the billig rpc call + session.commit() # 3. Get dispatcher based on tenant subscription dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) @@ -131,9 +135,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +158,18 @@ class AsyncWorkflowService: # 9. Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED @@ -295,13 +305,21 @@ class AsyncWorkflowService: return [log.to_dict() for log in logs] @staticmethod - def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: + def _get_workflow( + workflow_service: WorkflowService, + app_model: App, + workflow_id: str | None = None, + session: Session | None = None, + ) -> Workflow: """ Get workflow for the app Args: app_model: App model instance workflow_id: Optional specific workflow ID + session: Reuse this SQLAlchemy session for the lookup when provided, + so the caller's explicit session bears the connection cost + instead of Flask's request-scoped ``db.session``. Returns: Workflow instance @@ -311,12 +329,12 @@ class AsyncWorkflowService: """ if workflow_id: # Get specific published workflow - workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session) if not workflow: raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") else: # Get default published workflow - workflow = workflow_service.get_published_workflow(app_model) + workflow = workflow_service.get_published_workflow(app_model, session=session) if not workflow: raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index a1362ccad6..c0e23cdc6f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. + """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -149,11 +193,63 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 38518378f7..9477c28bf3 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -290,7 +290,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. + """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. + + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index ca4e43e516..5d99900a04 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -38,6 +38,7 @@ from models.workflow import Workflow from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData @@ -798,45 +799,47 @@ class WebhookService: Exception: If workflow execution fails """ try: - with Session(db.engine) as session: - # Prepare inputs for the webhook node - # The webhook node expects webhook_data in the inputs - workflow_inputs = cls.build_workflow_inputs(webhook_data) + workflow_inputs = cls.build_workflow_inputs(webhook_data) - # Create trigger data - trigger_data = WebhookTriggerData( - app_id=webhook_trigger.app_id, - workflow_id=workflow.id, - root_node_id=webhook_trigger.node_id, # Start from the webhook node - inputs=workflow_inputs, - tenant_id=webhook_trigger.tenant_id, + trigger_data = WebhookTriggerData( + app_id=webhook_trigger.app_id, + workflow_id=workflow.id, + root_node_id=webhook_trigger.node_id, + inputs=workflow_inputs, + tenant_id=webhook_trigger.tenant_id, + ) + + end_user = EndUserService.get_or_create_end_user_by_type( + type=InvokeFrom.TRIGGER, + tenant_id=webhook_trigger.tenant_id, + app_id=webhook_trigger.app_id, + user_id=None, + ) + + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) + logger.info( + "Tenant %s rate limited, skipping webhook trigger %s", + webhook_trigger.tenant_id, + webhook_trigger.webhook_id, ) + raise - end_user = EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.TRIGGER, - tenant_id=webhook_trigger.tenant_id, - app_id=webhook_trigger.app_id, - user_id=None, - ) - - # consume quota before triggering workflow execution - try: - QuotaType.TRIGGER.consume(webhook_trigger.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) - logger.info( - "Tenant %s rate limited, skipping webhook trigger %s", - webhook_trigger.tenant_id, - webhook_trigger.webhook_id, + try: + # NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()` + # trigger_workflow_async need to handle multipe session commits internally + with Session(db.engine, expire_on_commit=False) as session: + AsyncWorkflowService.trigger_workflow_async( + session, + end_user, + trigger_data, ) - raise - - # Trigger workflow execution asynchronously - AsyncWorkflowService.trigger_workflow_async( - session, - end_user, - trigger_data, - ) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d4b9095ce5..f97b85dc2b 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -156,11 +156,18 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: + def get_published_workflow_by_id( + self, app_model: App, workflow_id: str, session: Session | None = None + ) -> Workflow | None: """ fetch published workflow by workflow_id + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -178,16 +185,20 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Workflow | None: + def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None: """ Get published workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 25ea53dfac..8505375b6a 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -27,7 +27,7 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, @@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -258,59 +259,58 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity + + # Ensure expire_on_commit is set to False to remain workflows available with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) - end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( - type=InvokeFrom.TRIGGER, - tenant_id=subscription.tenant_id, - app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], - user_id=user_id, - ) - for plugin_trigger in subscribers: - # Get workflow from mapping - workflow: Workflow | None = workflows.get(plugin_trigger.app_id) - if not workflow: - logger.error( - "Workflow not found for app %s", - plugin_trigger.app_id, - ) - continue + end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( + type=InvokeFrom.TRIGGER, + tenant_id=subscription.tenant_id, + app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], + user_id=user_id, + ) - # Find the trigger node in the workflow - event_node = None - for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): - if node_id == plugin_trigger.node_id: - event_node = node_config - break - - if not event_node: - logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) - continue - - # invoke trigger - trigger_metadata = PluginTriggerMetadata( - plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", - endpoint_id=subscription.endpoint_id, - provider_id=subscription.provider_id, - event_name=event_name, - icon_filename=trigger_entity.identity.icon or "", - icon_dark_filename=trigger_entity.identity.icon_dark or "", + for plugin_trigger in subscribers: + workflow: Workflow | None = workflows.get(plugin_trigger.app_id) + if not workflow: + logger.error( + "Workflow not found for app %s", + plugin_trigger.app_id, ) + continue - # consume quota before invoking trigger - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) - logger.info( - "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id - ) - return 0 + event_node = None + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): + if node_id == plugin_trigger.node_id: + event_node = node_config + break - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) - invoke_response: TriggerInvokeEventResponse | None = None + if not event_node: + logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) + continue + + trigger_metadata = PluginTriggerMetadata( + plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", + endpoint_id=subscription.endpoint_id, + provider_id=subscription.provider_id, + event_name=event_name, + icon_filename=trigger_entity.identity.icon or "", + icon_dark_filename=trigger_entity.identity.icon_dark or "", + ) + + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) + logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id) + return dispatched_count + + node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) + invoke_response: TriggerInvokeEventResponse | None = None + + with session_factory.create_session() as session: try: invoke_response = TriggerManager.invoke_trigger_event( tenant_id=subscription.tenant_id, @@ -387,6 +387,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", @@ -401,7 +402,7 @@ def dispatch_triggered_workflow( plugin_trigger.app_id, ) - return dispatched_count + return dispatched_count def dispatch_triggered_workflows( diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..7638652000 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ + # Ensure expire_on_commit is set to False to remain schedule/tenant_owner available with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule: @@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None: if not tenant_owner: raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}") - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) - logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) - return + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) + logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) + return - try: - # Production dispatch: Trigger the workflow normally + try: + with session_factory.create_session() as session: response = AsyncWorkflowService.trigger_workflow_async( session=session, user=tenant_owner, @@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) - logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) - except Exception as e: - quota_charge.refund() - raise ScheduleExecutionError( - f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" - ) from e + quota_charge.commit() + logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) + except Exception as e: + quota_charge.refund() + raise ScheduleExecutionError( + f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" + ) from e diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..3229693fd4 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -36,12 +36,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +108,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +127,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -465,6 +475,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +489,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index ec10c51e04..85ce3a6ba6 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from enums.quota_type import QuotaType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import AppTriggerStatus, AppTriggerType from models.model import App @@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers: end_user = SimpleNamespace(id=str(uuid4())) webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + quota_charge = MagicMock() + with ( patch( "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", return_value=end_user, ), - patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume, + patch( + "services.trigger.webhook_service.QuotaService.reserve", + return_value=quota_charge, + ) as mock_reserve, patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, ): WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) - mock_consume.assert_called_once_with(webhook_trigger.tenant_id) + mock_reserve.assert_called_once() + reserve_args = mock_reserve.call_args.args + assert reserve_args[0] == QuotaType.TRIGGER + assert reserve_args[1] == webhook_trigger.tenant_id + quota_charge.commit.assert_called_once() mock_trigger.assert_called_once() trigger_args = mock_trigger.call_args.args assert trigger_args[1] is end_user @@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: return_value=SimpleNamespace(id=str(uuid4())), ), patch( - "services.trigger.webhook_service.QuotaType.TRIGGER.consume", + "services.trigger.webhook_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), ), patch( diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 55aec49878..9c20118e27 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/enums/__init__.py b/api/tests/unit_tests/enums/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py new file mode 100644 index 0000000000..f256ff3b4e --- /dev/null +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -0,0 +1,349 @@ +"""Unit tests for QuotaType, QuotaService, and QuotaCharge.""" + +from unittest.mock import patch + +import pytest + +from enums.quota_type import QuotaType +from services.quota_service import QuotaCharge, QuotaService, unlimited + + +class TestQuotaType: + def test_billing_key_trigger(self): + assert QuotaType.TRIGGER.billing_key == "trigger_event" + + def test_billing_key_workflow(self): + assert QuotaType.WORKFLOW.billing_key == "api_rate_limit" + + def test_billing_key_unlimited_raises(self): + with pytest.raises(ValueError, match="Invalid quota type"): + _ = QuotaType.UNLIMITED.billing_key + + +class TestQuotaService: + def test_reserve_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService"), + ): + mock_cfg.BILLING_ENABLED = False + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_reserve_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0) + + def test_reserve_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99} + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1) + + assert charge.success is True + assert charge.charge_id == "rid-1" + assert charge._tenant_id == "t1" + assert charge._feature_key == "trigger_event" + assert charge._amount == 1 + mock_bs.quota_reserve.assert_called_once() + + def test_reserve_no_reservation_id_raises(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {} + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_quota_exceeded_propagates(self): + from services.errors.app import QuotaExceededError + + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1) + + with pytest.raises(QuotaExceededError): + QuotaService.reserve(QuotaType.TRIGGER, "t1") + + def test_reserve_api_exception_returns_unlimited(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.side_effect = RuntimeError("network") + + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") + assert charge.success is True + assert charge.charge_id is None + + def test_consume_calls_reserve_and_commit(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"} + mock_bs.quota_commit.return_value = {} + + charge = QuotaService.consume(QuotaType.TRIGGER, "t1") + assert charge.success is True + mock_bs.quota_commit.assert_called_once() + + def test_check_billing_disabled(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = False + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_check_zero_amount_raises(self): + with patch("services.quota_service.dify_config") as mock_cfg: + mock_cfg.BILLING_ENABLED = True + with pytest.raises(ValueError, match="greater than 0"): + QuotaService.check(QuotaType.TRIGGER, "t1", amount=0) + + def test_check_sufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=100), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True + + def test_check_insufficient_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=5), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False + + def test_check_unlimited_quota(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=-1), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True + + def test_check_exception_returns_true(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", side_effect=RuntimeError), + ): + mock_cfg.BILLING_ENABLED = True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True + + def test_release_billing_disabled(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = False + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_empty_reservation(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event") + mock_bs.quota_release.assert_not_called() + + def test_release_success(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.return_value = {} + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + mock_bs.quota_release.assert_called_once_with( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1" + ) + + def test_release_exception_swallowed(self): + with ( + patch("services.quota_service.dify_config") as mock_cfg, + patch("services.billing_service.BillingService") as mock_bs, + ): + mock_cfg.BILLING_ENABLED = True + mock_bs.quota_release.side_effect = RuntimeError("fail") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_get_remaining_normal(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70 + + def test_get_remaining_unlimited(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_over_limit_returns_zero(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_exception_returns_neg1(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.side_effect = RuntimeError + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 + + def test_get_remaining_empty_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_non_dict_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = "invalid" + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + def test_get_remaining_feature_not_in_response(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} + remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1") + assert remaining == 0 + + def test_get_remaining_non_dict_feature_info(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 + + +class TestQuotaCharge: + def test_commit_success(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + mock_bs.quota_commit.assert_called_once_with( + tenant_id="t1", + feature_key="trigger_event", + reservation_id="rid-1", + actual_amount=1, + ) + assert charge._committed is True + + def test_commit_with_actual_amount(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=10, + ) + charge.commit(actual_amount=5) + call_kwargs = mock_bs.quota_commit.call_args[1] + assert call_kwargs["actual_amount"] == 5 + + def test_commit_idempotent(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.return_value = {} + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + charge.commit() + assert mock_bs.quota_commit.call_count == 1 + + def test_commit_no_charge_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_no_tenant_id_noop(self): + with patch("services.billing_service.BillingService") as mock_bs: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + _feature_key="trigger_event", + ) + charge.commit() + mock_bs.quota_commit.assert_not_called() + + def test_commit_exception_swallowed(self): + with patch("services.billing_service.BillingService") as mock_bs: + mock_bs.quota_commit.side_effect = RuntimeError("fail") + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + _amount=1, + ) + charge.commit() + + def test_refund_success(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id="t1", + _feature_key="trigger_event", + ) + charge.refund() + mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") + + def test_refund_no_charge_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) + charge.refund() + mock_rel.assert_not_called() + + def test_refund_no_tenant_id_noop(self): + with patch.object(QuotaService, "release") as mock_rel: + charge = QuotaCharge( + success=True, + charge_id="rid-1", + _quota_type=QuotaType.TRIGGER, + _tenant_id=None, + ) + charge.refund() + mock_rel.assert_not_called() + + +class TestUnlimited: + def test_unlimited_returns_success_with_no_charge_id(self): + charge = unlimited() + assert charge.success is True + assert charge.charge_id is None + assert charge._quota_type == QuotaType.UNLIMITED diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index 119a7adc45..d3f9c5dd9f 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -23,6 +23,7 @@ import pytest import services.app_generate_service as ags_module from core.app.entities.app_invoke_entities import InvokeFrom +from enums.quota_type import QuotaType from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError @@ -448,8 +449,8 @@ class TestGenerateBilling: def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() - consume_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + reserve_mock = mocker.patch( + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( @@ -468,7 +469,8 @@ class TestGenerateBilling: invoke_from=InvokeFrom.SERVICE_API, streaming=False, ) - consume_mock.assert_called_once_with("tenant-id") + reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id") + quota_charge.commit.assert_called_once() def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): from services.errors.app import QuotaExceededError @@ -476,7 +478,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -493,7 +495,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.consume", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index ca6ff9dc63..1b9cc8a2ff 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -57,7 +57,7 @@ class TestAsyncWorkflowService: - repo: SQLAlchemyWorkflowTriggerLogRepository - dispatcher_manager_class: QueueDispatcherManager class - dispatcher: dispatcher instance - - quota_workflow: QuotaType.WORKFLOW + - quota_service: QuotaService mock - get_workflow: AsyncWorkflowService._get_workflow method - professional_task: execute_workflow_professional - team_task: execute_workflow_team @@ -72,7 +72,7 @@ class TestAsyncWorkflowService: mock_repo.create.side_effect = _create_side_effect mock_dispatcher = MagicMock() - quota_workflow = MagicMock() + mock_quota_service = MagicMock() with ( patch.object( @@ -88,8 +88,8 @@ class TestAsyncWorkflowService: ) as mock_get_workflow, patch.object( async_workflow_service_module, - "QuotaType", - new=SimpleNamespace(WORKFLOW=quota_workflow), + "QuotaService", + new=mock_quota_service, ), patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, @@ -102,7 +102,7 @@ class TestAsyncWorkflowService: "repo": mock_repo, "dispatcher_manager_class": mock_dispatcher_manager_class, "dispatcher": mock_dispatcher, - "quota_workflow": quota_workflow, + "quota_service": mock_quota_service, "get_workflow": mock_get_workflow, "professional_task": mock_professional_task, "team_task": mock_team_task, @@ -141,6 +141,9 @@ class TestAsyncWorkflowService: mocks["team_task"].delay.return_value = task_result mocks["sandbox_task"].delay.return_value = task_result + quota_charge_mock = MagicMock() + mocks["quota_service"].reserve.return_value = quota_charge_mock + class DummyAccount: def __init__(self, user_id: str): self.id = user_id @@ -158,8 +161,9 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") - assert session.commit.call_count == 2 + mocks["quota_service"].reserve.assert_called_once() + quota_charge_mock.commit.assert_called_once() + assert session.commit.call_count == 3 created_log = mocks["repo"].create.call_args[0][0] assert created_log.status == WorkflowTriggerStatus.QUEUED @@ -245,7 +249,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + mocks["quota_service"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1, @@ -262,7 +266,7 @@ class TestAsyncWorkflowService: trigger_data=trigger_data, ) - assert session.commit.call_count == 2 + assert session.commit.call_count == 3 updated_log = mocks["repo"].update.call_args[0][0] assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED assert "Quota limit reached" in updated_log.error @@ -465,7 +469,7 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None) workflow_service.get_published_workflow.assert_not_called() def test_should_raise_when_specific_workflow_id_not_found(self): @@ -493,7 +497,7 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None) workflow_service.get_published_workflow_by_id.assert_not_called() def test_should_raise_when_default_published_workflow_not_found(self): diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 9ab0171eac..36592196c6 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -425,7 +425,7 @@ class TestBillingServiceUsageCalculation: yield mock def test_get_tenant_feature_plan_usage_info(self, mock_send_request): - """Test retrieval of tenant feature plan usage information.""" + """Test retrieval of tenant feature plan usage information (legacy endpoint).""" # Arrange tenant_id = "tenant-123" expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} @@ -438,6 +438,20 @@ class TestBillingServiceUsageCalculation: assert result == expected_response mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + def test_get_quota_info(self, mock_send_request): + """Test retrieval of quota info from new endpoint.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_quota_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) + def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): """Test updating tenant feature usage with positive delta (adding credits).""" # Arrange @@ -515,6 +529,150 @@ class TestBillingServiceUsageCalculation: ) +class TestBillingServiceQuotaOperations: + """Unit tests for quota reserve/commit/release operations.""" + + @pytest.fixture + def mock_send_request(self): + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_quota_reserve_success(self, mock_send_request): + expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/reserve", + json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1}, + ) + + def test_quota_reserve_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"} + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1) + + assert result["available"] == 99 + assert isinstance(result["available"], int) + assert result["reserved"] == 1 + assert isinstance(result["reserved"], int) + + def test_quota_reserve_with_meta(self, mock_send_request): + mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1} + meta = {"source": "webhook"} + + BillingService.quota_reserve( + tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"source": "webhook"} + + def test_quota_commit_success(self, mock_send_request): + expected = {"available": 98, "reserved": 0, "refunded": 0} + mock_send_request.return_value = expected + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1 + ) + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/commit", + json={ + "tenant_id": "t1", + "feature_key": "trigger_event", + "reservation_id": "rid-1", + "actual_amount": 1, + }, + ) + + def test_quota_commit_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"} + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1 + ) + + assert result["available"] == 97 + assert isinstance(result["available"], int) + assert result["refunded"] == 1 + assert isinstance(result["refunded"], int) + + def test_quota_commit_with_meta(self, mock_send_request): + mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0} + meta = {"reason": "partial"} + + BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta + ) + + call_json = mock_send_request.call_args[1]["json"] + assert call_json["meta"] == {"reason": "partial"} + + def test_quota_release_success(self, mock_send_request): + expected = {"available": 100, "reserved": 0, "released": 1} + mock_send_request.return_value = expected + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1") + + assert result == expected + mock_send_request.assert_called_once_with( + "POST", + "/quota/release", + json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"}, + ) + + def test_quota_release_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"} + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s") + + assert result["available"] == 100 + assert isinstance(result["available"], int) + assert result["released"] == 1 + assert isinstance(result["released"], int) + + def test_get_quota_info_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int for get_quota_info.""" + mock_send_request.return_value = { + "trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"}, + "api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"}, + } + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert isinstance(result["trigger_event"]["usage"], int) + assert result["trigger_event"]["limit"] == 3000 + assert isinstance(result["trigger_event"]["limit"], int) + assert result["trigger_event"]["reset_date"] == 1700000000 + assert isinstance(result["trigger_event"]["reset_date"], int) + assert result["api_rate_limit"]["limit"] == -1 + assert isinstance(result["api_rate_limit"]["limit"], int) + + def test_get_quota_info_accepts_int_values(self, mock_send_request): + """Test that get_quota_info works with native int values.""" + expected = { + "trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000}, + "api_rate_limit": {"usage": 0, "limit": -1}, + } + mock_send_request.return_value = expected + + result = BillingService.get_quota_info("t1") + + assert result["trigger_event"]["usage"] == 42 + assert result["trigger_event"]["limit"] == 3000 + assert result["api_rate_limit"]["limit"] == -1 + + class TestBillingServiceRateLimitEnforcement: """Unit tests for rate limit enforcement mechanisms. diff --git a/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py b/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py new file mode 100644 index 0000000000..59da5cc7a2 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_trigger_processing_tasks.py @@ -0,0 +1,204 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import tasks.trigger_processing_tasks as trigger_processing_tasks_module +from services.errors.app import QuotaExceededError +from tasks.trigger_processing_tasks import dispatch_triggered_workflow + + +class TestDispatchTriggeredWorkflow: + """Unit tests covering branch behaviours of ``dispatch_triggered_workflow``. + + The covered branches are: + - workflow missing for ``plugin_trigger.app_id`` → log + ``continue`` + - ``QuotaService.reserve`` raising ``QuotaExceededError`` → + ``mark_tenant_triggers_rate_limited`` + early ``return`` + - ``trigger_workflow_async`` succeeds → + ``quota_charge.commit()`` + ``dispatched_count`` increments + """ + + @pytest.fixture + def subscription(self): + sub = MagicMock() + sub.id = "subscription-123" + sub.tenant_id = "tenant-123" + sub.provider_id = "langgenius/test_plugin/test_plugin" + sub.endpoint_id = "endpoint-123" + sub.credentials = {} + sub.credential_type = "api_key" + return sub + + @pytest.fixture + def plugin_trigger(self): + trigger = MagicMock() + trigger.id = "plugin-trigger-123" + trigger.app_id = "app-123" + trigger.node_id = "node-123" + return trigger + + @pytest.fixture + def provider_controller(self): + controller = MagicMock() + controller.plugin_unique_identifier = "langgenius/test_plugin:0.0.1" + controller.entity.identity.name = "Test Plugin" + controller.entity.identity.icon = "icon.svg" + controller.entity.identity.icon_dark = "icon_dark.svg" + return controller + + @pytest.fixture + def dispatch_mocks(self, subscription, plugin_trigger, provider_controller): + """Patch all external dependencies reached by ``dispatch_triggered_workflow``. + + Defaults are configured so the code flow can reach the final async + trigger block (line ~385); each test overrides specific handles + (``get_workflows``, ``reserve``, ``create_end_user_batch``, ...) to + drive the path it targets. + """ + session_cm = MagicMock() + session_cm.__enter__.return_value = MagicMock() + session_cm.__exit__.return_value = False + + invoke_response = MagicMock() + invoke_response.cancelled = False + invoke_response.variables = {} + + quota_charge = MagicMock() + + with ( + patch.object( + trigger_processing_tasks_module.TriggerHttpRequestCachingService, + "get_request", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module.TriggerHttpRequestCachingService, + "get_payload", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module.TriggerSubscriptionOperatorService, + "get_subscriber_triggers", + return_value=[plugin_trigger], + ), + patch.object( + trigger_processing_tasks_module.TriggerManager, + "get_trigger_provider", + return_value=provider_controller, + ), + patch.object( + trigger_processing_tasks_module.TriggerManager, + "invoke_trigger_event", + return_value=invoke_response, + ) as invoke_trigger_event, + patch.object( + trigger_processing_tasks_module.TriggerEventNodeData, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + trigger_processing_tasks_module, + "_get_latest_workflows_by_app_ids", + ) as get_workflows, + patch.object( + trigger_processing_tasks_module.EndUserService, + "create_end_user_batch", + return_value={}, + ) as create_end_user_batch, + patch.object( + trigger_processing_tasks_module.session_factory, + "create_session", + return_value=session_cm, + ), + patch.object( + trigger_processing_tasks_module.QuotaService, + "reserve", + return_value=quota_charge, + ) as reserve, + patch.object( + trigger_processing_tasks_module.AppTriggerService, + "mark_tenant_triggers_rate_limited", + ) as mark_rate_limited, + patch.object( + trigger_processing_tasks_module.AsyncWorkflowService, + "trigger_workflow_async", + ) as trigger_workflow_async, + ): + yield { + "get_workflows": get_workflows, + "reserve": reserve, + "quota_charge": quota_charge, + "mark_rate_limited": mark_rate_limited, + "invoke_trigger_event": invoke_trigger_event, + "invoke_response": invoke_response, + "create_end_user_batch": create_end_user_batch, + "trigger_workflow_async": trigger_workflow_async, + } + + def test_dispatch_skips_when_workflow_missing(self, subscription, dispatch_mocks): + """Covers missing workflow → log + ``continue``.""" + dispatch_mocks["get_workflows"].return_value = {} + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 0 + dispatch_mocks["reserve"].assert_not_called() + dispatch_mocks["invoke_trigger_event"].assert_not_called() + dispatch_mocks["mark_rate_limited"].assert_not_called() + + def test_dispatch_marks_rate_limited_when_quota_exceeded(self, subscription, plugin_trigger, dispatch_mocks): + """Covers QuotaExceededError → mark rate-limited + early return.""" + workflow_mock = MagicMock() + workflow_mock.walk_nodes.return_value = iter( + [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})] + ) + dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock} + dispatch_mocks["reserve"].side_effect = QuotaExceededError( + feature="trigger", tenant_id=subscription.tenant_id, required=1 + ) + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 0 + dispatch_mocks["reserve"].assert_called_once() + dispatch_mocks["mark_rate_limited"].assert_called_once_with(subscription.tenant_id) + dispatch_mocks["invoke_trigger_event"].assert_not_called() + + def test_dispatch_commits_quota_and_counts_when_workflow_triggered( + self, subscription, plugin_trigger, dispatch_mocks + ): + """Happy path: end user exists and async trigger succeeds.""" + workflow_mock = MagicMock() + workflow_mock.id = "workflow-123" + workflow_mock.walk_nodes.return_value = iter( + [(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})] + ) + dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock} + + end_user_mock = MagicMock() + dispatch_mocks["create_end_user_batch"].return_value = {plugin_trigger.app_id: end_user_mock} + + dispatched = dispatch_triggered_workflow( + user_id="user-123", + subscription=subscription, + event_name="test_event", + request_id="request-123", + ) + + assert dispatched == 1 + dispatch_mocks["trigger_workflow_async"].assert_called_once() + _, kwargs = dispatch_mocks["trigger_workflow_async"].call_args + assert kwargs["user"] is end_user_mock + dispatch_mocks["quota_charge"].commit.assert_called_once() + dispatch_mocks["quota_charge"].refund.assert_not_called() + dispatch_mocks["mark_rate_limited"].assert_not_called()