diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e0e6a6f5c3..9df78a7830 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :return: """ with Session(db.engine, expire_on_commit=False) as session: - agent_thought: MessageAgentThought | None = ( - session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() + agent_thought: MessageAgentThought | None = session.scalar( + select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1) ) if agent_thought: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 143d1e696b..a5297fa33a 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -345,8 +345,8 @@ class DatasourceManager: @classmethod def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: with session_factory.create_session() as session: - upload_file = ( - session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first() + upload_file = session.scalar( + select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1) ) if not upload_file: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8071770c0f..aa258c9f89 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -467,7 +467,7 @@ class LLMGenerator: ): session = db.session() - app: App | None = session.query(App).where(App.id == flow_id).first() + app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1)) if not app: raise ValueError("App not found.") workflow = workflow_service.get_draft_workflow(app_model=app) diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 8c081ae225..a1f96b9edf 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -56,8 +56,10 @@ class BaseTraceInstance(ABC): if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") - current_tenant = ( - session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + current_tenant = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True)) + .limit(1) ) if not current_tenant: raise ValueError(f"Current tenant not found for account {service_account.id}") diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 2bd6db22bf..84f54d8a5a 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance): if not service_account: raise ValueError(f"Creator account not found for app {app_id}") - current_tenant = ( - session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + current_tenant = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True)) + .limit(1) ) if not current_tenant: raise ValueError(f"Current tenant not found for account {service_account.id}") diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f7e7b7e20e..f22602a400 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return agent_thought monkeypatch.setattr( @@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return None monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index b0c72ee42f..d338cadb77 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker): source_url="http://x", ) - class _Q: - def __init__(self, row): - self._row = row - - def where(self, *_args, **_kwargs): - return self - - def first(self): - return self._row - class _S: def __init__(self, row): self._row = row @@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q(self._row) + def scalar(self, *_args, **_kwargs): + return self._row mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row)) @@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker): def test_get_upload_file_by_id_raises_when_missing(mocker): - class _Q: - def where(self, *_args, **_kwargs): - return self - - def first(self): - return None - class _S: def __enter__(self): return self @@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q() + def scalar(self, *_args, **_kwargs): + return None mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 62e714deb6..7cdfb31189 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -346,13 +346,13 @@ class TestLLMGenerator: def test_instruction_modify_workflow_app_not_found(self): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = None + mock_session.return_value.scalar.return_value = None with pytest.raises(ValueError, match="App not found."): LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock()) def test_instruction_modify_workflow_no_workflow(self): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow_service = MagicMock() workflow_service.get_draft_workflow.return_value = None with pytest.raises(ValueError, match="Workflow not found for the given app model."): @@ -360,7 +360,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index 382e5dadc3..f67abba807 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -407,8 +407,7 @@ class TestTencentDataTrace: mock_db.engine = "engine" with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value - session.scalar.side_effect = [app, account] - session.query.return_value.filter_by.return_value.first.return_value = tenant_join + session.scalar.side_effect = [app, account, tenant_join] with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" diff --git a/api/tests/unit_tests/core/ops/test_base_trace_instance.py b/api/tests/unit_tests/core/ops/test_base_trace_instance.py index a8bee7dfa7..ac65d13454 100644 --- a/api/tests/unit_tests/core/ops/test_base_trace_instance.py +++ b/api/tests/unit_tests/core/ops/test_base_trace_instance.py @@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session): mock_account = MagicMock(spec=Account) mock_account.id = "creator_id" - mock_db_session.scalar.side_effect = [mock_app, mock_account] - - # session.query(TenantAccountJoin).filter_by(...).first() returns None - mock_db_session.query.return_value.filter_by.return_value.first.return_value = None + mock_db_session.scalar.side_effect = [mock_app, mock_account, None] config = MagicMock(spec=BaseTracingConfig) instance = ConcreteTraceInstance(config) @@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session): mock_account.id = "creator_id" mock_account.set_tenant_id = MagicMock() - mock_db_session.scalar.side_effect = [mock_app, mock_account] - mock_tenant_join = MagicMock(spec=TenantAccountJoin) mock_tenant_join.tenant_id = "tenant_id" - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join + + mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join] config = MagicMock(spec=BaseTracingConfig) instance = ConcreteTraceInstance(config)