diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 8b39d63385..ceda30e950 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -89,7 +89,10 @@ class AsyncWorkflowService: raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") # 2. Get workflow - workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session) + + # commit read-only session before starting the billing rpc call + session.commit() # 3. Get dispatcher based on tenant subscription dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) @@ -302,13 +305,21 @@ class AsyncWorkflowService: return [log.to_dict() for log in logs] @staticmethod - def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: + def _get_workflow( + workflow_service: WorkflowService, + app_model: App, + workflow_id: str | None = None, + session: Session | None = None, + ) -> Workflow: """ Get workflow for the app Args: app_model: App model instance workflow_id: Optional specific workflow ID + session: Reuse this SQLAlchemy session for the lookup when provided, + so the caller's explicit session bears the connection cost + instead of Flask's request-scoped ``db.session``. 
Returns: Workflow instance @@ -318,12 +329,12 @@ class AsyncWorkflowService: """ if workflow_id: # Get specific published workflow - workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session) if not workflow: raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") else: # Get default published workflow - workflow = workflow_service.get_published_workflow(app_model) + workflow = workflow_service.get_published_workflow(app_model, session=session) if not workflow: raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index d562220fa7..5d99900a04 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -799,50 +799,47 @@ class WebhookService: Exception: If workflow execution fails """ try: - with Session(db.engine) as session: - # Prepare inputs for the webhook node - # The webhook node expects webhook_data in the inputs - workflow_inputs = cls.build_workflow_inputs(webhook_data) + workflow_inputs = cls.build_workflow_inputs(webhook_data) - # Create trigger data - trigger_data = WebhookTriggerData( - app_id=webhook_trigger.app_id, - workflow_id=workflow.id, - root_node_id=webhook_trigger.node_id, # Start from the webhook node - inputs=workflow_inputs, - tenant_id=webhook_trigger.tenant_id, + trigger_data = WebhookTriggerData( + app_id=webhook_trigger.app_id, + workflow_id=workflow.id, + root_node_id=webhook_trigger.node_id, + inputs=workflow_inputs, + tenant_id=webhook_trigger.tenant_id, + ) + + end_user = EndUserService.get_or_create_end_user_by_type( + type=InvokeFrom.TRIGGER, + tenant_id=webhook_trigger.tenant_id, + app_id=webhook_trigger.app_id, + user_id=None, + ) + + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) + 
+ except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) + logger.info( + "Tenant %s rate limited, skipping webhook trigger %s", + webhook_trigger.tenant_id, + webhook_trigger.webhook_id, ) + raise - end_user = EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.TRIGGER, - tenant_id=webhook_trigger.tenant_id, - app_id=webhook_trigger.app_id, - user_id=None, - ) - - # reserve quota before triggering workflow execution - try: - quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) - logger.info( - "Tenant %s rate limited, skipping webhook trigger %s", - webhook_trigger.tenant_id, - webhook_trigger.webhook_id, - ) - raise - - # Trigger workflow execution asynchronously - try: + try: + # NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()` + # trigger_workflow_async needs to handle multiple session commits internally + with Session(db.engine, expire_on_commit=False) as session: AsyncWorkflowService.trigger_workflow_async( session, end_user, trigger_data, ) - quota_charge.commit() - except Exception: - quota_charge.refund() - raise + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d71223314e..d01331b588 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -156,11 +156,18 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: + def get_published_workflow_by_id( + self, app_model: App, workflow_id: str, session: Session | None = None + ) -> Workflow | None: """ fetch published workflow by 
workflow_id + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -178,16 +185,20 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Workflow | None: + def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None: """ Get published workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index b0cbc54db3..71ecf08689 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -259,59 +259,60 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity + + # Ensure expire_on_commit is set to False to keep workflows available with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) - end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( - type=InvokeFrom.TRIGGER, - tenant_id=subscription.tenant_id, - app_ids=[plugin_trigger.app_id for plugin_trigger in 
subscribers], - user_id=user_id, - ) - for plugin_trigger in subscribers: - # Get workflow from mapping - workflow: Workflow | None = workflows.get(plugin_trigger.app_id) - if not workflow: - logger.error( - "Workflow not found for app %s", - plugin_trigger.app_id, - ) - continue + end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( + type=InvokeFrom.TRIGGER, + tenant_id=subscription.tenant_id, + app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], + user_id=user_id, + ) - # Find the trigger node in the workflow - event_node = None - for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): - if node_id == plugin_trigger.node_id: - event_node = node_config - break - - if not event_node: - logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) - continue - - # invoke trigger - trigger_metadata = PluginTriggerMetadata( - plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", - endpoint_id=subscription.endpoint_id, - provider_id=subscription.provider_id, - event_name=event_name, - icon_filename=trigger_entity.identity.icon or "", - icon_dark_filename=trigger_entity.identity.icon_dark or "", + for plugin_trigger in subscribers: + workflow: Workflow | None = workflows.get(plugin_trigger.app_id) + if not workflow: + logger.error( + "Workflow not found for app %s", + plugin_trigger.app_id, ) + continue - # reserve quota before invoking trigger - quota_charge = unlimited() - try: - quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) - logger.info( - "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id - ) - return 0 + event_node = None + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): + if node_id == plugin_trigger.node_id: + event_node = node_config + break - node_data: 
TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) - invoke_response: TriggerInvokeEventResponse | None = None + if not event_node: + logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) + continue + + trigger_metadata = PluginTriggerMetadata( + plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", + endpoint_id=subscription.endpoint_id, + provider_id=subscription.provider_id, + event_name=event_name, + icon_filename=trigger_entity.identity.icon or "", + icon_dark_filename=trigger_entity.identity.icon_dark or "", + ) + + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) + logger.info( + "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id + ) + return dispatched_count + + node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) + invoke_response: TriggerInvokeEventResponse | None = None + + with session_factory.create_session() as session: try: invoke_response = TriggerManager.invoke_trigger_event( tenant_id=subscription.tenant_id, @@ -403,7 +404,7 @@ def dispatch_triggered_workflow( plugin_trigger.app_id, ) - return dispatched_count + return dispatched_count def dispatch_triggered_workflows( diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index dfb2fb3391..7638652000 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -33,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ + # Ensure expire_on_commit is set to False to keep schedule/tenant_owner available with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, 
schedule_id) if not schedule: @@ -42,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None: if not tenant_owner: raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}") - quota_charge = unlimited() - try: - quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) - logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) - return + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) + logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) + return - try: - # Production dispatch: Trigger the workflow normally + try: + with session_factory.create_session() as session: response = AsyncWorkflowService.trigger_workflow_async( session=session, user=tenant_owner, @@ -62,10 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) - quota_charge.commit() - logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) - except Exception as e: - quota_charge.refund() - raise ScheduleExecutionError( - f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" - ) from e + quota_charge.commit() + logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) + except Exception as e: + quota_charge.refund() + raise ScheduleExecutionError( + f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" + ) from e diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index 
ec10c51e04..85ce3a6ba6 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from enums.quota_type import QuotaType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import AppTriggerStatus, AppTriggerType from models.model import App @@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers: end_user = SimpleNamespace(id=str(uuid4())) webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + quota_charge = MagicMock() + with ( patch( "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", return_value=end_user, ), - patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume, + patch( + "services.trigger.webhook_service.QuotaService.reserve", + return_value=quota_charge, + ) as mock_reserve, patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, ): WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) - mock_consume.assert_called_once_with(webhook_trigger.tenant_id) + mock_reserve.assert_called_once() + reserve_args = mock_reserve.call_args.args + assert reserve_args[0] == QuotaType.TRIGGER + assert reserve_args[1] == webhook_trigger.tenant_id + quota_charge.commit.assert_called_once() mock_trigger.assert_called_once() trigger_args = mock_trigger.call_args.args assert trigger_args[1] is end_user @@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: return_value=SimpleNamespace(id=str(uuid4())), ), patch( - "services.trigger.webhook_service.QuotaType.TRIGGER.consume", + 
"services.trigger.webhook_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), ), patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index 361e95a557..c88da0fa85 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -474,7 +474,9 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow_by_id.assert_called_once_with( + app_model, "workflow-123", session=None + ) workflow_service.get_published_workflow.assert_not_called() def test_should_raise_when_specific_workflow_id_not_found(self): @@ -502,7 +504,7 @@ class TestAsyncWorkflowServiceGetWorkflow: # Assert assert result == workflow - workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None) workflow_service.get_published_workflow_by_id.assert_not_called() def test_should_raise_when_default_published_workflow_not_found(self):