mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 23:40:16 +08:00
refactor: enhance trigger db session management
Some checks failed
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/amd64, ubuntu-latest, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/amd64, ubuntu-latest, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Some checks failed
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/amd64, ubuntu-latest, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/amd64, ubuntu-latest, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
This commit is contained in:
@@ -89,7 +89,10 @@ class AsyncWorkflowService:
|
||||
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
|
||||
|
||||
# 2. Get workflow
|
||||
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
|
||||
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
|
||||
|
||||
# commit read only session before starting the billig rpc call
|
||||
session.commit()
|
||||
|
||||
# 3. Get dispatcher based on tenant subscription
|
||||
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
|
||||
@@ -302,13 +305,21 @@ class AsyncWorkflowService:
|
||||
return [log.to_dict() for log in logs]
|
||||
|
||||
@staticmethod
|
||||
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
|
||||
def _get_workflow(
|
||||
workflow_service: WorkflowService,
|
||||
app_model: App,
|
||||
workflow_id: str | None = None,
|
||||
session: Session | None = None,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Get workflow for the app
|
||||
|
||||
Args:
|
||||
app_model: App model instance
|
||||
workflow_id: Optional specific workflow ID
|
||||
session: Reuse this SQLAlchemy session for the lookup when provided,
|
||||
so the caller's explicit session bears the connection cost
|
||||
instead of Flask's request-scoped ``db.session``.
|
||||
|
||||
Returns:
|
||||
Workflow instance
|
||||
@@ -318,12 +329,12 @@ class AsyncWorkflowService:
|
||||
"""
|
||||
if workflow_id:
|
||||
# Get specific published workflow
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
|
||||
else:
|
||||
# Get default published workflow
|
||||
workflow = workflow_service.get_published_workflow(app_model)
|
||||
workflow = workflow_service.get_published_workflow(app_model, session=session)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")
|
||||
|
||||
|
||||
@@ -799,50 +799,47 @@ class WebhookService:
|
||||
Exception: If workflow execution fails
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Prepare inputs for the webhook node
|
||||
# The webhook node expects webhook_data in the inputs
|
||||
workflow_inputs = cls.build_workflow_inputs(webhook_data)
|
||||
workflow_inputs = cls.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Create trigger data
|
||||
trigger_data = WebhookTriggerData(
|
||||
app_id=webhook_trigger.app_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=webhook_trigger.node_id, # Start from the webhook node
|
||||
inputs=workflow_inputs,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
trigger_data = WebhookTriggerData(
|
||||
app_id=webhook_trigger.app_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=webhook_trigger.node_id,
|
||||
inputs=workflow_inputs,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
)
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
app_id=webhook_trigger.app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping webhook trigger %s",
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.webhook_id,
|
||||
)
|
||||
raise
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
app_id=webhook_trigger.app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# reserve quota before triggering workflow execution
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping webhook trigger %s",
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.webhook_id,
|
||||
)
|
||||
raise
|
||||
|
||||
# Trigger workflow execution asynchronously
|
||||
try:
|
||||
try:
|
||||
# NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
|
||||
# trigger_workflow_async need to handle multipe session commits internally
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
|
||||
|
||||
@@ -156,11 +156,18 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
|
||||
def get_published_workflow_by_id(
|
||||
self, app_model: App, workflow_id: str, session: Session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
|
||||
When ``session`` is provided, reuse it so callers that already hold a
|
||||
Session avoid checking out an extra request-scoped ``db.session``
|
||||
connection. Falls back to ``db.session`` for backward compatibility.
|
||||
"""
|
||||
workflow = db.session.scalar(
|
||||
bind = session if session is not None else db.session
|
||||
workflow = bind.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
@@ -178,16 +185,20 @@ class WorkflowService:
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Workflow | None:
|
||||
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow
|
||||
|
||||
When ``session`` is provided, reuse it so callers that already hold a
|
||||
Session avoid checking out an extra request-scoped ``db.session``
|
||||
connection. Falls back to ``db.session`` for backward compatibility.
|
||||
"""
|
||||
|
||||
if not app_model.workflow_id:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = db.session.scalar(
|
||||
bind = session if session is not None else db.session
|
||||
workflow = bind.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
|
||||
@@ -259,59 +259,60 @@ def dispatch_triggered_workflow(
|
||||
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
|
||||
)
|
||||
trigger_entity: TriggerProviderEntity = provider_controller.entity
|
||||
|
||||
# Ensure expire_on_commit is set to False to remain workflows available
|
||||
with session_factory.create_session() as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=subscription.tenant_id,
|
||||
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
|
||||
user_id=user_id,
|
||||
)
|
||||
for plugin_trigger in subscribers:
|
||||
# Get workflow from mapping
|
||||
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
|
||||
if not workflow:
|
||||
logger.error(
|
||||
"Workflow not found for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=subscription.tenant_id,
|
||||
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Find the trigger node in the workflow
|
||||
event_node = None
|
||||
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
|
||||
if node_id == plugin_trigger.node_id:
|
||||
event_node = node_config
|
||||
break
|
||||
|
||||
if not event_node:
|
||||
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
|
||||
continue
|
||||
|
||||
# invoke trigger
|
||||
trigger_metadata = PluginTriggerMetadata(
|
||||
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
event_name=event_name,
|
||||
icon_filename=trigger_entity.identity.icon or "",
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
for plugin_trigger in subscribers:
|
||||
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
|
||||
if not workflow:
|
||||
logger.error(
|
||||
"Workflow not found for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# reserve quota before invoking trigger
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
|
||||
)
|
||||
return 0
|
||||
event_node = None
|
||||
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
|
||||
if node_id == plugin_trigger.node_id:
|
||||
event_node = node_config
|
||||
break
|
||||
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
|
||||
invoke_response: TriggerInvokeEventResponse | None = None
|
||||
if not event_node:
|
||||
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
|
||||
continue
|
||||
|
||||
trigger_metadata = PluginTriggerMetadata(
|
||||
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
event_name=event_name,
|
||||
icon_filename=trigger_entity.identity.icon or "",
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
)
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
|
||||
)
|
||||
return dispatched_count
|
||||
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
|
||||
invoke_response: TriggerInvokeEventResponse | None = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
invoke_response = TriggerManager.invoke_trigger_event(
|
||||
tenant_id=subscription.tenant_id,
|
||||
@@ -403,7 +404,7 @@ def dispatch_triggered_workflow(
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
|
||||
return dispatched_count
|
||||
return dispatched_count
|
||||
|
||||
|
||||
def dispatch_triggered_workflows(
|
||||
|
||||
@@ -33,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
TenantOwnerNotFoundError: If no owner/admin for tenant
|
||||
ScheduleExecutionError: If workflow trigger fails
|
||||
"""
|
||||
# Ensure expire_on_commit is set to False to remain schedule/tenant_owner available
|
||||
with session_factory.create_session() as session:
|
||||
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||
if not schedule:
|
||||
@@ -42,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
if not tenant_owner:
|
||||
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
return
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
return
|
||||
|
||||
try:
|
||||
# Production dispatch: Trigger the workflow normally
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
response = AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=tenant_owner,
|
||||
@@ -62,10 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
tenant_id=schedule.tenant_id,
|
||||
),
|
||||
)
|
||||
quota_charge.commit()
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
raise ScheduleExecutionError(
|
||||
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
|
||||
) from e
|
||||
quota_charge.commit()
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
raise ScheduleExecutionError(
|
||||
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
|
||||
) from e
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
|
||||
from enums.quota_type import QuotaType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import AppTriggerStatus, AppTriggerType
|
||||
from models.model import App
|
||||
@@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers:
|
||||
end_user = SimpleNamespace(id=str(uuid4()))
|
||||
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
|
||||
|
||||
quota_charge = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
return_value=end_user,
|
||||
),
|
||||
patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
|
||||
patch(
|
||||
"services.trigger.webhook_service.QuotaService.reserve",
|
||||
return_value=quota_charge,
|
||||
) as mock_reserve,
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
|
||||
):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
|
||||
mock_reserve.assert_called_once()
|
||||
reserve_args = mock_reserve.call_args.args
|
||||
assert reserve_args[0] == QuotaType.TRIGGER
|
||||
assert reserve_args[1] == webhook_trigger.tenant_id
|
||||
quota_charge.commit.assert_called_once()
|
||||
mock_trigger.assert_called_once()
|
||||
trigger_args = mock_trigger.call_args.args
|
||||
assert trigger_args[1] is end_user
|
||||
@@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
|
||||
return_value=SimpleNamespace(id=str(uuid4())),
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.QuotaType.TRIGGER.consume",
|
||||
"services.trigger.webhook_service.QuotaService.reserve",
|
||||
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
|
||||
),
|
||||
patch(
|
||||
|
||||
@@ -474,7 +474,9 @@ class TestAsyncWorkflowServiceGetWorkflow:
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(
|
||||
app_model, "workflow-123", session=None
|
||||
)
|
||||
workflow_service.get_published_workflow.assert_not_called()
|
||||
|
||||
def test_should_raise_when_specific_workflow_id_not_found(self):
|
||||
@@ -502,7 +504,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model)
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
|
||||
workflow_service.get_published_workflow_by_id.assert_not_called()
|
||||
|
||||
def test_should_raise_when_default_published_workflow_not_found(self):
|
||||
|
||||
Reference in New Issue
Block a user