chore(api): migrate mail task and OAuth data source to use Session(db… (#35235)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
jerryzai
2026-04-17 04:52:27 -04:00
committed by GitHub
parent dfcc0f8863
commit b9c300d570
2 changed files with 128 additions and 124 deletions

View File

@@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding from models.source import DataSourceOauthBinding
@@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
pages=pages, pages=pages,
) )
# save data source binding # save data source binding
data_source_binding = db.session.scalar( with session_factory.create_session() as session:
select(DataSourceOauthBinding).where( data_source_binding = session.scalar(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, select(DataSourceOauthBinding).where(
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
) )
) if data_source_binding:
if data_source_binding: data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False
data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now()
data_source_binding.updated_at = naive_utc_now() session.commit()
db.session.commit() else:
else: new_data_source_binding = DataSourceOauthBinding(
new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id, access_token=access_token,
access_token=access_token, source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion",
provider="notion", )
) session.add(new_data_source_binding)
db.session.add(new_data_source_binding) session.commit()
db.session.commit()
def save_internal_access_token(self, access_token: str) -> None: def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token) workspace_name = self.notion_workspace_name(access_token)
@@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
pages=pages, pages=pages,
) )
# save data source binding # save data source binding
data_source_binding = db.session.scalar( with session_factory.create_session() as session:
select(DataSourceOauthBinding).where( data_source_binding = session.scalar(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, select(DataSourceOauthBinding).where(
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
) )
) if data_source_binding:
if data_source_binding: data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False
data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now()
data_source_binding.updated_at = naive_utc_now() session.commit()
db.session.commit() else:
else: new_data_source_binding = DataSourceOauthBinding(
new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id, access_token=access_token,
access_token=access_token, source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion",
provider="notion", )
) session.add(new_data_source_binding)
db.session.add(new_data_source_binding) session.commit()
db.session.commit()
def sync_data_source(self, binding_id: str) -> None:
    """Refresh the stored page list for an existing (enabled) Notion binding.

    Re-fetches the authorized pages with the binding's saved access token and
    rewrites the binding's ``source_info`` in place, keeping workspace metadata
    (name/icon/id) from the previously stored info.

    :param binding_id: primary key of the DataSourceOauthBinding to sync.
    :raises ValueError: if no enabled "notion" binding with this id exists for
        the current tenant.
    """
    # save data source binding
    with session_factory.create_session() as session:
        data_source_binding = session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.id == binding_id,
                # .is_(False) is the SQLAlchemy boolean-IS idiom; keeps this
                # file consistent with the mail task's `notified.is_(False)`
                # filter and avoids the E712 `== False` comparison.
                DataSourceOauthBinding.disabled.is_(False),
            )
        )
        if data_source_binding:
            # get all authorized pages
            pages = self.get_authorized_pages(data_source_binding.access_token)
            source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
            # Rebuild source_info with the fresh page list but the workspace
            # identity recorded at bind time.
            new_source_info = self._build_source_info(
                workspace_name=source_info["workspace_name"],
                workspace_icon=source_info["workspace_icon"],
                workspace_id=source_info["workspace_id"],
                pages=pages,
            )
            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            session.commit()
        else:
            raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = [] pages: list[NotionPageSummary] = []

View File

@@ -7,8 +7,8 @@ from sqlalchemy import select
import app import app
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory
from enums.cloud_plan import CloudPlan from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin from models import Account, Tenant, TenantAccountJoin
@@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
# send document clean notify mail # send document clean notify mail
try: try:
dataset_auto_disable_logs = db.session.scalars( with session_factory.create_session() as session:
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) dataset_auto_disable_logs = session.scalars(
).all() select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
# group by tenant_id ).all()
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) # group by tenant_id
for dataset_auto_disable_log in dataset_auto_disable_logs: dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: for dataset_auto_disable_log in dataset_auto_disable_logs:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
url = f"{dify_config.CONSOLE_WEB_URL}/datasets" dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
features = FeatureService.get_features(tenant_id) for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
plan = features.billing.subscription.plan features = FeatureService.get_features(tenant_id)
if plan != CloudPlan.SANDBOX: plan = features.billing.subscription.plan
knowledge_details = [] if plan != CloudPlan.SANDBOX:
# check tenant knowledge_details = []
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) # check tenant
if not tenant: tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
continue if not tenant:
# check current owner continue
current_owner_join = db.session.scalar( # check current owner
select(TenantAccountJoin) current_owner_join = session.scalar(
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") select(TenantAccountJoin)
.limit(1) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
) .limit(1)
if not current_owner_join: )
continue if not current_owner_join:
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) continue
if not account: account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
continue if not account:
continue
dataset_auto_dataset_map = {} # type: ignore dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: dataset_auto_disable_log.notified = True
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] session.commit()
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception: except Exception: