chore(api): migrate mail task and OAuth data source to use Session(db… (#35235)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
jerryzai
2026-04-17 04:52:27 -04:00
committed by GitHub
parent dfcc0f8863
commit b9c300d570
2 changed files with 128 additions and 124 deletions

View File

@@ -6,8 +6,8 @@ from flask_login import current_user
 from pydantic import TypeAdapter
 from sqlalchemy import select
+from core.db.session_factory import session_factory
 from core.helper.http_client_pooling import get_pooled_http_client
-from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models.source import DataSourceOauthBinding
@@ -95,7 +95,8 @@ class NotionOAuth(OAuthDataSource):
             pages=pages,
         )
         # save data source binding
-        data_source_binding = db.session.scalar(
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
             select(DataSourceOauthBinding).where(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                 DataSourceOauthBinding.provider == "notion",
@@ -106,7 +107,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
             data_source_binding.disabled = False
             data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
+            session.commit()
         else:
             new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
@@ -114,8 +115,8 @@ class NotionOAuth(OAuthDataSource):
                 source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
                 provider="notion",
             )
-            db.session.add(new_data_source_binding)
-            db.session.commit()
+            session.add(new_data_source_binding)
+            session.commit()

     def save_internal_access_token(self, access_token: str) -> None:
         workspace_name = self.notion_workspace_name(access_token)
@@ -130,7 +131,8 @@ class NotionOAuth(OAuthDataSource):
             pages=pages,
         )
         # save data source binding
-        data_source_binding = db.session.scalar(
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
             select(DataSourceOauthBinding).where(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                 DataSourceOauthBinding.provider == "notion",
@@ -141,7 +143,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
             data_source_binding.disabled = False
             data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
+            session.commit()
         else:
             new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
@@ -149,12 +151,13 @@ class NotionOAuth(OAuthDataSource):
                 source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
                 provider="notion",
             )
-            db.session.add(new_data_source_binding)
-            db.session.commit()
+            session.add(new_data_source_binding)
+            session.commit()

     def sync_data_source(self, binding_id: str) -> None:
         # save data source binding
-        data_source_binding = db.session.scalar(
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
             select(DataSourceOauthBinding).where(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                 DataSourceOauthBinding.provider == "notion",
@@ -176,7 +179,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
             data_source_binding.disabled = False
             data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
+            session.commit()
         else:
             raise ValueError("Data source binding not found")

View File

@@ -7,8 +7,8 @@ from sqlalchemy import select
 import app
 from configs import dify_config
+from core.db.session_factory import session_factory
 from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
 from extensions.ext_mail import mail
 from libs.email_i18n import EmailType, get_email_i18n_service
 from models import Account, Tenant, TenantAccountJoin
@@ -33,8 +33,9 @@ def mail_clean_document_notify_task():
     # send document clean notify mail
     try:
-        dataset_auto_disable_logs = db.session.scalars(
-            select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
+        with session_factory.create_session() as session:
+            dataset_auto_disable_logs = session.scalars(
+                select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
         ).all()

         # group by tenant_id
         dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
@@ -49,18 +50,18 @@ def mail_clean_document_notify_task():
                 if plan != CloudPlan.SANDBOX:
                     knowledge_details = []
                     # check tenant
-                    tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
+                    tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
                     if not tenant:
                         continue

                     # check current owner
-                    current_owner_join = db.session.scalar(
+                    current_owner_join = session.scalar(
                         select(TenantAccountJoin)
                         .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
                         .limit(1)
                     )
                     if not current_owner_join:
                         continue
-                    account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
+                    account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
                     if not account:
                         continue
@@ -73,7 +74,7 @@ def mail_clean_document_notify_task():
                     )

                     for dataset_id, document_ids in dataset_auto_dataset_map.items():
-                        dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
+                        dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
                         if dataset:
                             document_count = len(document_ids)
                             knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
@@ -93,7 +94,7 @@ def mail_clean_document_notify_task():
                     # update notified to True
                     for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                         dataset_auto_disable_log.notified = True
-                    db.session.commit()
+                    session.commit()

         end_at = time.perf_counter()
         logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
     except Exception: