Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-17 16:26:19 +08:00
109 changed files with 2431 additions and 1752 deletions

View File

@@ -2,6 +2,7 @@ import base64
import secrets
import click
from sqlalchemy.orm import Session
from constants.languages import languages
from extensions.ext_database import db
@@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
with Session(db.engine) as session:
account = session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
with Session(db.engine) as session:
account = session.merge(account)
account.email = normalized_new_email
session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))

View File

@@ -299,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
tool_instance = tool_instances.get(prompt_tool.name)
if tool_instance:
self.update_prompt_message_tool(tool_instance, prompt_tool)
iteration_step += 1

View File

@@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:param session: optional database session
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
# fix origin data
credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credential_name = self._generate_provider_credential_name(pre_session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
credentials = self.validate_provider_credentials(credentials=credentials)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
# If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
credentials = self.validate_provider_credentials(
credentials=credentials, credential_id=credential_id, session=session
)
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
"""
Validate custom model credentials.
@@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
model=model, model_type=model_type, session=pre_session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
)
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
# save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
session=session,
session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
session=session,
)
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:

View File

@@ -1,8 +1,8 @@
from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from pydantic import BaseModel
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
from core.ops.utils import validate_project_name, validate_url
class TracingProviderEnum(StrEnum):
@@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""
api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# aliyun uses two URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
try:
match provider:
case TracingProviderEnum.LANGFUSE:
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
case TracingProviderEnum.LANGSMITH:
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
case TracingProviderEnum.OPIK:
from dify_trace_opik.config import OpikConfig
from dify_trace_opik.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
case TracingProviderEnum.WEAVE:
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from core.ops.entities.config_entity import MLflowConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from dify_trace_mlflow.config import MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from core.ops.entities.config_entity import DatabricksConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from dify_trace_mlflow.config import DatabricksConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
case TracingProviderEnum.TENCENT:
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
except ImportError:
raise ImportError(f"Provider {provider} is not installed.")
provider_config_map = OpsTraceProviderConfigMap()

View File

@@ -10,8 +10,8 @@ from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.workflow.file_reference import build_file_reference
from extensions.ext_database import db
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
from models import ToolFile, UploadFile
@@ -135,29 +135,30 @@ def _build_from_local_file(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
with session_factory.create_session() as session:
row = session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
def _build_from_remote_url(
@@ -179,32 +180,33 @@ def _build_from_remote_url(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
with session_factory.create_session() as session:
upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
url = mapping.get("url") or mapping.get("remote_url")
if not url:
@@ -247,30 +249,31 @@ def _build_from_tool_file(
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
with session_factory.create_session() as session:
tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
return File(
id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
def _build_from_datasource_file(
@@ -289,31 +292,32 @@ def _build_from_datasource_file(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
with session_factory.create_session() as session:
datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
return File(
id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:

View File

@@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
)
class DocumentSegmentSummary(Base):
class DocumentSegmentSummary(TypeBase):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base):
sa.Index("document_segment_summaries_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
status: Mapped[SummaryStatus] = mapped_column(
EnumText(SummaryStatus, length=32),
nullable=False,
server_default=sa.text("'generating'"),
default=SummaryStatus.GENERATING,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
def __repr__(self):

View File

@@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Difys API
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
## Excluding Providers
In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.

View File

@@ -0,0 +1,78 @@
# Trace providers
This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
## Architecture
| Layer | Location | Role |
|--------|----------|------|
| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |
At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
## What you implement
### 1. Config model (`BaseTracingConfig`)
Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
Fields fall into two groups used by the manager:
- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
### 2. Trace instance (`BaseTraceInstance`)
Subclass `BaseTraceInstance` and implement:
```python
def trace(self, trace_info: BaseTraceInfo) -> None:
...
```
Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
You may ignore categories your backend does not support; existing providers often no-op unhandled types.
Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
### 3. Register in the API core
Upstream changes are required so Dify knows your provider exists:
1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
- `config_class`: your Pydantic config type
- `secret_keys` / `other_keys`: lists of field names as above
- `trace_instance`: your `BaseTraceInstance` subclass
Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
## Package layout
Each provider is a normal uv workspace member, for example:
- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
- `api/providers/trace/trace-<name>/src/dify_trace_<name>/``config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
## Wiring into the `api` workspace
In `api/pyproject.toml`:
1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
After changing metadata, run **`uv sync`** from `api/`.

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-trace-aliyun"
version = "0.0.1"
dependencies = [
# versions inherited from parent
"opentelemetry-api",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
"opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Aliyun)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -4,7 +4,20 @@ from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
@@ -12,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
convert_to_trace_id,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -32,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.aliyun_trace.utils import (
from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -44,19 +57,6 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey

View File

@@ -0,0 +1,32 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# aliyun uses two URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")

View File

@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000

View File

@@ -4,7 +4,8 @@ from typing import Any, TypedDict
from opentelemetry.trace import Link, Status, StatusCode
from core.ops.aliyun_trace.entities.semconv import (
from core.rag.models.document import Document
from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -13,7 +14,6 @@ from core.ops.aliyun_trace.entities.semconv import (
OUTPUT_VALUE,
GenAISpanKind,
)
from core.rag.models.document import Document
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
from dify_trace_aliyun.data_exporter.traceclient import create_link
links = []
if trace_id:

View File

@@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import SpanKind, Status, StatusCode
from core.ops.aliyun_trace.data_exporter.traceclient import (
from dify_trace_aliyun.data_exporter.traceclient import (
INVALID_SPAN_ID,
SpanBuilder,
TraceClient,
@@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
create_link,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import SpanKind, Status, StatusCode
@pytest.fixture
@@ -41,8 +40,8 @@ def trace_client_factory():
class TestTraceClient:
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
mock_gethostname.return_value = "test-host"
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -56,7 +55,7 @@ class TestTraceClient:
client.shutdown()
assert client.done is True
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -64,8 +63,8 @@ class TestTraceClient:
client.export(spans)
mock_exporter.export.assert_called_once_with(spans)
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 405
@@ -74,8 +73,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is True
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 500
@@ -84,8 +83,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is False
@patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
mock_head.side_effect = httpx.RequestError("Connection error")
@@ -93,12 +92,12 @@ class TestTraceClient:
with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
client.api_check()
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_add_span(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(
service_name="test-service",
@@ -134,8 +133,8 @@ class TestTraceClient:
assert len(client.queue) == 2
mock_notify.assert_called_once()
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.logger")
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
@@ -159,7 +158,7 @@ class TestTraceClient:
assert len(client.queue) == 1
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
mock_exporter.export.side_effect = Exception("Export failed")
@@ -168,11 +167,11 @@ class TestTraceClient:
mock_span = MagicMock(spec=ReadableSpan)
client.queue.append(mock_span)
with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
client._export_batch()
mock_logger.warning.assert_called()
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
# We need to test the wait timeout in _worker
# But _worker runs in a thread. Let's mock condition.wait.
@@ -189,7 +188,7 @@ class TestTraceClient:
# mock_wait might have been called
assert mock_wait.called or client.done
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
@patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -268,7 +267,7 @@ def test_generate_span_id():
assert span_id != INVALID_SPAN_ID
# Test retry loop
with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand:
mock_rand.side_effect = [INVALID_SPAN_ID, 999]
span_id = generate_span_id()
assert span_id == 999
@@ -290,7 +289,7 @@ def test_convert_to_trace_id():
def test_convert_string_to_id():
assert convert_string_to_id("test") > 0
# Test with None string
with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen:
mock_gen.return_value = 12345
assert convert_string_to_id(None) == 12345

View File

@@ -1,11 +1,10 @@
import pytest
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import ValidationError
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
class TestTraceMetadata:
def test_trace_metadata_init(self):

View File

@@ -1,4 +1,4 @@
from core.ops.aliyun_trace.entities.semconv import (
from dify_trace_aliyun.entities.semconv import (
ACS_ARMS_SERVICE_FEATURE,
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,

View File

@@ -4,12 +4,11 @@ from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.aliyun_trace.entities.semconv import (
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
@@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.entities.config_entity import AliyunConfig
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,

View File

@@ -1,9 +1,7 @@
import json
from unittest.mock import MagicMock
from opentelemetry.trace import Link, StatusCode
from core.ops.aliyun_trace.entities.semconv import (
from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import (
INPUT_VALUE,
OUTPUT_VALUE,
)
from core.ops.aliyun_trace.utils import (
from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
from opentelemetry.trace import Link, StatusCode
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = end_user_data
from core.ops.aliyun_trace.utils import db
from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = None
from core.ops.aliyun_trace.utils import db
from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -112,9 +112,9 @@ def test_get_workflow_node_status():
def test_create_links_from_trace_id(monkeypatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
import core.ops.aliyun_trace.data_exporter.traceclient
import dify_trace_aliyun.data_exporter.traceclient
monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
# Trace ID None
assert create_links_from_trace_id(None) == []

View File

@@ -0,0 +1,85 @@
import pytest
from dify_trace_aliyun.config import AliyunConfig
from pydantic import ValidationError
class TestAliyunConfig:
"""Test cases for AliyunConfig"""
def test_valid_config(self):
"""Test valid Aliyun configuration"""
config = AliyunConfig(
app_name="test_app",
license_key="test_license_key",
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
)
assert config.app_name == "test_app"
assert config.license_key == "test_license_key"
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
assert config.app_name == "dify_app"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
)
assert config.app_name == "dify_app"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = AliyunConfig(license_key="test_license", endpoint="")
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation preserves path for Aliyun endpoints"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
)
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_license_key_required(self):
"""Test that license_key is required and cannot be empty"""
with pytest.raises(ValidationError):
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_valid_endpoint_format_examples(self):
"""Test valid endpoint format examples from comments"""
valid_endpoints = [
# cms2.0 public endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
# cms2.0 intranet endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
# xtrace public endpoint
"http://tracing-cn-heyuan.arms.aliyuncs.com",
# xtrace intranet endpoint
"http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
]
for endpoint in valid_endpoints:
config = AliyunConfig(license_key="test_license", endpoint=endpoint)
assert config.endpoint == endpoint

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-arize-phoenix"
version = "0.0.1"
dependencies = [
"arize-phoenix-otel~=0.15.0",
]
description = "Dify ops tracing provider (Arize / Phoenix)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -25,7 +25,6 @@ from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -39,6 +38,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile

View File

@@ -0,0 +1,45 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")

View File

@@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from opentelemetry.sdk.trace import Tracer
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import StatusCode
from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
from dify_trace_arize_phoenix.arize_phoenix_trace import (
ArizePhoenixDataTrace,
datetime_to_nanos,
error_to_string,
@@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
setup_tracer,
wrap_span_metadata,
)
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from opentelemetry.sdk.trace import Tracer
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -80,7 +80,7 @@ def test_datetime_to_nanos():
expected = int(dt.timestamp() * 1_000_000_000)
assert datetime_to_nanos(dt) == expected
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt:
mock_now = MagicMock()
mock_now.timestamp.return_value = 1704110400.0
mock_dt.now.return_value = mock_now
@@ -142,8 +142,8 @@ def test_wrap_span_metadata():
assert res == {"a": 1, "b": 2, "created_from": "Dify"}
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_arize(mock_provider, mock_exporter):
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
setup_tracer(config)
@@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter):
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
config = PhoenixConfig(endpoint="http://p.com", project="p")
setup_tracer(config)
@@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter):
def test_setup_tracer_exception():
config = ArizeConfig(endpoint="http://a.com", project="p")
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with pytest.raises(Exception, match="boom"):
setup_tracer(config)
@@ -172,7 +172,7 @@ def test_setup_tracer_exception():
@pytest.fixture
def trace_instance():
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup:
mock_tracer = MagicMock(spec=Tracer)
mock_processor = MagicMock()
mock_setup.return_value = (mock_tracer, mock_processor)
@@ -228,9 +228,9 @@ def test_trace_exception(trace_instance):
trace_instance.trace(_make_workflow_info())
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac
assert trace_instance.tracer.start_span.call_count >= 2
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_no_app_id(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance):
trace_instance.workflow_trace(info)
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_success(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
@@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance):
assert trace_instance.tracer.start_span.call_count >= 1
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_with_error(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()

View File

@@ -1,6 +1,6 @@
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues
from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes

View File

@@ -0,0 +1,88 @@
import pytest
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from pydantic import ValidationError
class TestArizeConfig:
"""Test cases for ArizeConfig"""
def test_valid_config(self):
"""Test valid Arize configuration"""
config = ArizeConfig(
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
)
assert config.api_key == "test_key"
assert config.space_id == "test_space"
assert config.project == "test_project"
assert config.endpoint == "https://custom.arize.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = ArizeConfig()
assert config.api_key is None
assert config.space_id is None
assert config.project is None
assert config.endpoint == "https://otlp.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = ArizeConfig(project="")
assert config.project == "default"
def test_project_validation_none(self):
"""Test project validation with None value"""
config = ArizeConfig(project=None)
assert config.project == "default"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = ArizeConfig(endpoint="")
assert config.endpoint == "https://otlp.arize.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
assert config.endpoint == "https://custom.arize.com"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="ftp://invalid.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="invalid.com")
class TestPhoenixConfig:
"""Test cases for PhoenixConfig"""
def test_valid_config(self):
"""Test valid Phoenix configuration"""
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.endpoint == "https://custom.phoenix.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = PhoenixConfig()
assert config.api_key is None
assert config.project is None
assert config.endpoint == "https://app.phoenix.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = PhoenixConfig(project="")
assert config.project == "default"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation with path"""
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
def test_endpoint_validation_without_path(self):
"""Test endpoint validation without path"""
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
assert config.endpoint == "https://app.phoenix.arize.com"

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-langfuse"
version = "0.0.1"
dependencies = [
"langfuse>=4.2.0,<5.0.0",
]
description = "Dify ops tracing provider (Langfuse)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,19 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")

View File

@@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.entities.langfuse_trace_entity import (
GenerationUsage,
LangfuseGeneration,
LangfuseSpan,
@@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LevelEnum,
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser, WorkflowNodeExecutionTriggeredFrom

View File

@@ -5,8 +5,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.entities.langfuse_trace_entity import (
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
from models import EndUser
from models.enums import MessageStatus
@@ -43,7 +43,7 @@ def langfuse_config():
def trace_instance(langfuse_config, monkeypatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
instance = LangFuseDataTrace(langfuse_config)
return instance
@@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch):
def test_init(langfuse_config, monkeypatch):
mock_langfuse = MagicMock()
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangFuseDataTrace(langfuse_config)
@@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
# Mock DB and Repositories
mock_session = MagicMock()
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="log-1",
error="",
)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock()
trace_instance.add_generation = MagicMock()
@@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()

View File

@@ -0,0 +1,42 @@
import pytest
from dify_trace_langfuse.config import LangfuseConfig
from pydantic import ValidationError
class TestLangfuseConfig:
"""Test cases for LangfuseConfig"""
def test_valid_config(self):
"""Test valid Langfuse configuration"""
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
assert config.public_key == "public_key"
assert config.secret_key == "secret_key"
assert config.host == "https://custom.langfuse.com"
def test_valid_config_with_path(self):
host = "https://custom.langfuse.com/api/v1"
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
assert config.public_key == "public_key"
assert config.secret_key == "secret_key"
assert config.host == host
def test_default_values(self):
"""Test default values are set correctly"""
config = LangfuseConfig(public_key="public", secret_key="secret")
assert config.host == "https://api.langfuse.com"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
def test_host_validation_empty(self):
"""Test host validation with empty value"""
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
assert config.host == "https://api.langfuse.com"

View File

@@ -4,14 +4,15 @@ from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.ops.entities.config_entity import LangfuseConfig
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
def _create_trace_instance() -> LangFuseDataTrace:
with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True):
with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True):
return LangFuseDataTrace(
LangfuseConfig(
public_key="public-key",
@@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime:
patch.object(trace, "add_span"),
patch.object(trace, "add_generation") as add_generation,
patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()),
patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()),
patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()),
patch(
"core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
"dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=repository,
),
):

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-langsmith"
version = "0.0.1"
dependencies = [
"langsmith~=0.7.30",
]
description = "Dify ops tracing provider (LangSmith)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,20 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url
class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))

View File

@@ -9,7 +9,6 @@ from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -21,13 +20,14 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@@ -3,8 +3,14 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
@@ -38,7 +38,7 @@ def langsmith_config():
def trace_instance(langsmith_config, monkeypatch):
# Mock LangSmith client
mock_client = MagicMock()
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
instance = LangSmithDataTrace(langsmith_config)
return instance
@@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch):
def test_init(langsmith_config, monkeypatch):
mock_client_class = MagicMock()
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangSmithDataTrace(langsmith_config)
@@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch):
# Mock dependencies
mock_session = MagicMock()
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
)
mock_session = MagicMock()
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
@@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_info.error = ""
mock_session = MagicMock()
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
# Mock EndUser lookup
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_run = MagicMock()
@@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()

View File

@@ -0,0 +1,35 @@
import pytest
from dify_trace_langsmith.config import LangSmithConfig
from pydantic import ValidationError
class TestLangSmithConfig:
"""Test cases for LangSmithConfig"""
def test_valid_config(self):
"""Test valid LangSmith configuration"""
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.endpoint == "https://custom.smith.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = LangSmithConfig(api_key="key", project="project")
assert config.endpoint == "https://api.smith.langchain.com"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-mlflow"
version = "0.0.1"
dependencies = [
"mlflow-skinny>=3.11.1",
]
description = "Dify ops tracing provider (MLflow / Databricks)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,46 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_integer_id, validate_url_with_path
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)

View File

@@ -11,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser

View File

@@ -1,4 +1,4 @@
"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module."""
"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module."""
from __future__ import annotations
@@ -9,8 +9,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from graphon.enums import BuiltinNodeTypes
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -179,7 +179,7 @@ def _make_node(**overrides):
@pytest.fixture
def mock_mlflow():
with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock:
with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock:
yield mock
@@ -187,10 +187,10 @@ def mock_mlflow():
def mock_tracing():
"""Patch all MLflow tracing functions used by the module."""
with (
patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start,
patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update,
patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set,
patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach,
patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start,
patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update,
patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set,
patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach,
):
yield {
"start": mock_start,
@@ -202,7 +202,7 @@ def mock_tracing():
@pytest.fixture
def mock_db():
with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock:
with patch("dify_trace_mlflow.mlflow_trace.db") as mock:
yield mock

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-opik"
version = "0.0.1"
dependencies = [
"opik~=1.11.2",
]
description = "Dify ops tracing provider (Opik)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,25 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")

View File

@@ -10,7 +10,6 @@ from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -23,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_opik.config import OpikConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@@ -5,8 +5,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from dify_trace_opik.config import OpikConfig
from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
from models.enums import MessageStatus
@@ -37,7 +37,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
mock_client = MagicMock()
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
instance = OpikDataTrace(opik_config)
return instance
@@ -67,7 +67,7 @@ def test_prepare_opik_uuid():
def test_init(opik_config, monkeypatch):
mock_opik = MagicMock()
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = OpikDataTrace(opik_config)
@@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
)
mock_session = MagicMock()
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
node_llm = MagicMock()
node_llm.id = LLM_NODE_ID
@@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
error="",
)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
trace_instance.add_span = MagicMock()
@@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()

View File

@@ -0,0 +1,48 @@
import pytest
from dify_trace_opik.config import OpikConfig
from pydantic import ValidationError
class TestOpikConfig:
"""Test cases for OpikConfig"""
def test_valid_config(self):
"""Test valid Opik configuration"""
config = OpikConfig(
api_key="test_key",
project="test_project",
workspace="test_workspace",
url="https://custom.comet.com/opik/api/",
)
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.workspace == "test_workspace"
assert config.url == "https://custom.comet.com/opik/api/"
def test_default_values(self):
"""Test default values are set correctly"""
config = OpikConfig()
assert config.api_key is None
assert config.project is None
assert config.workspace is None
assert config.url == "https://www.comet.com/opik/api/"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = OpikConfig(project="")
assert config.project == "Default Project"
def test_url_validation_empty(self):
"""Test URL validation with empty value"""
config = OpikConfig(url="")
assert config.url == "https://www.comet.com/opik/api/"
def test_url_validation_missing_suffix(self):
"""Test URL validation requires /api/ suffix"""
with pytest.raises(ValidationError, match="URL should end with /api/"):
OpikConfig(url="https://custom.comet.com/opik/")
def test_url_validation_invalid_scheme(self):
"""Test URL validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
OpikConfig(url="ftp://custom.comet.com/opik/api/")

View File

@@ -14,8 +14,9 @@ import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo
from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
# A stable UUID4 used as the workflow_run_id throughout all tests.
_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6"
@@ -56,8 +57,8 @@ def _make_workflow_trace_info(
def _make_opik_trace_instance() -> OpikDataTrace:
"""Construct an OpikDataTrace with the Opik SDK client mocked out."""
with patch("core.ops.opik_trace.opik_trace.Opik"):
from core.ops.entities.config_entity import OpikConfig
with patch("dify_trace_opik.opik_trace.Opik"):
from dify_trace_opik.config import OpikConfig
config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/")
instance = OpikDataTrace(config)
@@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
patch("core.ops.opik_trace.opik_trace.db") as mock_db,
patch("core.ops.opik_trace.opik_trace.sessionmaker"),
patch("dify_trace_opik.opik_trace.db") as mock_db,
patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
"core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
"dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):
@@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
patch("core.ops.opik_trace.opik_trace.db") as mock_db,
patch("core.ops.opik_trace.opik_trace.sessionmaker"),
patch("dify_trace_opik.opik_trace.db") as mock_db,
patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
"core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
"dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-trace-tencent"
version = "0.0.1"
dependencies = [
# versions inherited from parent
"opentelemetry-api",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
"opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Tencent APM)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,30 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")

View File

@@ -14,7 +14,8 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.semconv import (
from core.rag.models.document import Document
from dify_trace_tencent.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
@@ -38,9 +39,8 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from dify_trace_tencent.utils import TencentTraceUtils
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus

View File

@@ -8,7 +8,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -19,11 +18,12 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from dify_trace_tencent.client import TencentTraceClient
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from dify_trace_tencent.span_builder import TencentSpanBuilder
from dify_trace_tencent.utils import TencentTraceUtils
from extensions.ext_database import db
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,

View File

@@ -8,13 +8,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from core.ops.tencent_trace import client as client_module
from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []

View File

@@ -1,15 +1,7 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from opentelemetry.trace import StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.semconv import (
from dify_trace_tencent.entities.semconv import (
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
@@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from dify_trace_tencent.span_builder import TencentSpanBuilder
from opentelemetry.trace import StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio
class TestTencentSpanBuilder:
def test_get_time_nanoseconds(self):
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
mock_convert.return_value = 123456789
dt = datetime.now()
result = TencentSpanBuilder._get_time_nanoseconds(dt)
@@ -48,7 +48,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {"answer": "world"}
trace_info.metadata = {"conversation_id": "conv_id"}
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -70,7 +70,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {}
trace_info.metadata = {} # No conversation_id
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 1
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -98,7 +98,7 @@ class TestTencentSpanBuilder:
}
node_execution.outputs = {"text": "world"}
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -123,7 +123,7 @@ class TestTencentSpanBuilder:
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
}
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -142,7 +142,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {"conversation_id": "conv_id"}
trace_info.is_streaming_request = True
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -162,7 +162,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {}
trace_info.is_streaming_request = False
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -182,7 +182,7 @@ class TestTencentSpanBuilder:
trace_info.tool_inputs = {"i": 2}
trace_info.tool_outputs = "result"
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 101
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
@@ -204,7 +204,7 @@ class TestTencentSpanBuilder:
)
trace_info.documents = [doc]
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -222,7 +222,7 @@ class TestTencentSpanBuilder:
trace_info.end_time = datetime.now()
trace_info.documents = []
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -264,7 +264,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -286,7 +286,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -307,7 +307,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -329,7 +329,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -350,7 +350,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 505
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)

View File

@@ -2,8 +2,9 @@ import logging
from unittest.mock import MagicMock, patch
import pytest
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.tencent_trace import TencentDataTrace
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -13,7 +14,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes
from models import Account, App, TenantAccountJoin
@@ -28,19 +28,19 @@ def tencent_config():
@pytest.fixture
def mock_trace_client():
with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock:
yield mock
@pytest.fixture
def mock_span_builder():
with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock:
yield mock
@pytest.fixture
def mock_trace_utils():
with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock:
yield mock
@@ -198,9 +198,9 @@ class TestTencentDataTrace:
trace_info.workflow_run_id = "run-id"
with patch(
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.workflow_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
@@ -230,9 +230,9 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
with patch(
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.message_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
@@ -262,9 +262,9 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.tool_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
@@ -294,22 +294,22 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
"core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.dataset_retrieval_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
def test_suggested_question_trace(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
def test_suggested_question_trace_exception(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
@@ -342,7 +342,7 @@ class TestTencentDataTrace:
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
# The exception should be caught by the outer handler since convert_to_span_id is called first
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -351,7 +351,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -381,7 +381,7 @@ class TestTencentDataTrace:
node.id = "n1"
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
assert result is None
mock_log.assert_called_once()
@@ -403,15 +403,13 @@ class TestTencentDataTrace:
mock_executions = [MagicMock()]
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.engine = "engine"
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.side_effect = [app, account, tenant_join]
with patch(
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
) as mock_repo:
with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo:
mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions
results = tencent_data_trace._get_workflow_node_executions(trace_info)
@@ -423,7 +421,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {}
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -432,14 +430,14 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {"app_id": "app-1"}
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock() # Ensure init_app is mocked
mock_db.engine = "engine"
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.return_value = None
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -449,8 +447,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "tenant-1"
trace_info.metadata = {"user_id": "user-1"}
with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock()
mock_db.engine = MagicMock()
@@ -476,8 +474,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "t"
trace_info.metadata = {"user_id": "u"}
with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
@@ -519,7 +517,7 @@ class TestTencentDataTrace:
node.process_data = None
node.outputs = None
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_llm_metrics(node)
# Should not crash
@@ -557,7 +555,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.metadata = None
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_llm_metrics(trace_info)
# Should not crash
@@ -609,7 +607,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_workflow_trace_duration(trace_info)
def test_record_message_trace_duration(self, tencent_data_trace):
@@ -631,7 +629,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.start_time = None
with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
def test_del(self, tencent_data_trace):
@@ -641,6 +639,6 @@ class TestTencentDataTrace:
def test_del_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.__del__()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")

View File

@@ -8,10 +8,9 @@ from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from dify_trace_tencent.utils import TencentTraceUtils
from opentelemetry.trace import Link, TraceFlags
from core.ops.tencent_trace.utils import TencentTraceUtils
def test_convert_to_trace_id_with_valid_uuid() -> None:
uuid_str = "12345678-1234-5678-1234-567812345678"
@@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None:
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
uuid4_mock.assert_called_once()
@@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
assert isinstance(span_id, int)
uuid4_mock.assert_called_once()
@@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
def test_generate_span_id_skips_invalid_span_id() -> None:
with patch(
"core.ops.tencent_trace.utils.random.getrandbits",
"dify_trace_tencent.utils.random.getrandbits",
side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
) as bits_mock:
assert TencentTraceUtils.generate_span_id() == 42
@@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
expected = int(fixed.timestamp() * 1e9)
with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
with patch("dify_trace_tencent.utils.datetime") as datetime_mock:
datetime_mock.now.return_value = fixed
assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
datetime_mock.now.assert_called_once()
@@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type]
assert link.context.trace_id == fallback_uuid.int
uuid4_mock.assert_called_once()

View File

@@ -0,0 +1,10 @@
[project]
name = "dify-trace-weave"
version = "0.0.1"
dependencies = [
"weave>=0.52.36",
]
description = "Dify ops tracing provider (Weave)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -0,0 +1,29 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url
class WeaveConfig(BaseTracingConfig):
    """Configuration for the Weave (W&B) tracing provider."""

    api_key: str
    entity: str | None = None
    project: str
    endpoint: str = "https://trace.wandb.ai"
    host: str | None = None

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        # The Weave trace endpoint must be served over HTTPS only.
        return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))

    @field_validator("host")
    @classmethod
    def host_validator(cls, v, info: ValidationInfo):
        # An unset or blank host passes through untouched; anything else must
        # be a valid http(s) URL.
        if v is None or not v.strip():
            return v
        return validate_url(v, v, allowed_schemes=("https", "http"))

View File

@@ -17,7 +17,6 @@ from weave.trace_server.trace_server_interface import (
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -29,8 +28,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@@ -0,0 +1,61 @@
import pytest
from dify_trace_weave.config import WeaveConfig
from pydantic import ValidationError
class TestWeaveConfig:
    """Unit tests for the WeaveConfig pydantic model."""

    def test_valid_config(self):
        """A fully-specified config keeps every field verbatim."""
        cfg = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert cfg.api_key == "test_key"
        assert cfg.entity == "test_entity"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.wandb.ai"
        assert cfg.host == "https://custom.host.com"

    def test_default_values(self):
        """Omitted optional fields fall back to their declared defaults."""
        cfg = WeaveConfig(api_key="key", project="project")
        assert cfg.entity is None
        assert cfg.endpoint == "https://trace.wandb.ai"
        assert cfg.host is None

    def test_missing_required_fields(self):
        """Both api_key and project are mandatory."""
        with pytest.raises(ValidationError):
            WeaveConfig()
        with pytest.raises(ValidationError):
            WeaveConfig(api_key="key")
        with pytest.raises(ValidationError):
            WeaveConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Plain-HTTP endpoints are rejected by the endpoint validator."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """None and empty hosts pass through; a valid HTTPS host is accepted."""
        cfg = WeaveConfig(api_key="key", project="project", host=None)
        assert cfg.host is None
        cfg = WeaveConfig(api_key="key", project="project", host="")
        assert cfg.host == ""
        cfg = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
        assert cfg.host == "https://valid.host.com"

    def test_host_validation_invalid_scheme(self):
        """Non-http(s) schemes for host are rejected when a host is provided."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")

View File

@@ -1,4 +1,4 @@
"""Comprehensive tests for core.ops.weave_trace.weave_trace module."""
"""Comprehensive tests for dify_trace_weave.weave_trace module."""
from __future__ import annotations
@@ -7,9 +7,11 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
from dify_trace_weave.weave_trace import WeaveDataTrace
from weave.trace_server.trace_server_interface import TraceStatus
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -191,14 +191,14 @@ def _make_node(**overrides):
@pytest.fixture
def mock_wandb():
with patch("core.ops.weave_trace.weave_trace.wandb") as mock:
with patch("dify_trace_weave.weave_trace.wandb") as mock:
mock.login.return_value = True
yield mock
@pytest.fixture
def mock_weave():
with patch("core.ops.weave_trace.weave_trace.weave") as mock:
with patch("dify_trace_weave.weave_trace.weave") as mock:
client = MagicMock()
client.entity = "my-entity"
client.project = "my-project"
@@ -307,7 +307,7 @@ class TestGetProjectUrl:
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
# Force an error by making string formatting fail
with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger:
with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
# Simulate exception via property
original_entity = trace_instance.entity
trace_instance.entity = None
@@ -594,9 +594,9 @@ class TestWorkflowTrace:
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
return repo
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
@@ -703,8 +703,8 @@ class TestWorkflowTrace:
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
"""Raises ValueError when app_id is missing from metadata."""
monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
trace_info = _make_workflow_trace_info(
message_id=None,
@@ -802,7 +802,7 @@ class TestMessageTrace:
def test_basic_message_trace(self, trace_instance, monkeypatch):
"""message_trace creates message run and llm child run."""
monkeypatch.setattr(
"core.ops.weave_trace.weave_trace.db.session.get",
"dify_trace_weave.weave_trace.db.session.get",
lambda model, pk: None,
)
@@ -824,7 +824,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = None
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -846,7 +846,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = end_user
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -866,7 +866,7 @@ class TestMessageTrace:
"""message_trace handles when from_end_user_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -884,7 +884,7 @@ class TestMessageTrace:
"""trace_id falls back to message_id when trace_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -899,7 +899,7 @@ class TestMessageTrace:
"""message_trace handles file_list=None gracefully."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()

View File

@@ -32,9 +32,6 @@ dependencies = [
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.147.0,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"langfuse>=4.2.0,<5.0.0",
"langsmith>=0.7.31,<1.0.0",
"mlflow-skinny>=3.11.1,<4.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
@@ -44,15 +41,12 @@ dependencies = [
"opentelemetry-propagator-b3>=1.41.0,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",
"weave>=0.52.36,<1.0.0",
# Emerging: newer and fast-moving, use compatible pins
"arize-phoenix-otel~=0.15.0",
"fastopenapi[flask]~=0.7.0",
"graphon~=0.1.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
"opik~=1.11.2",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -61,8 +55,8 @@ dependencies = [
packages = []
[tool.uv.workspace]
members = ["providers/vdb/*"]
exclude = ["providers/vdb/__pycache__"]
members = ["providers/vdb/*", "providers/trace/*"]
exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"]
[tool.uv.sources]
dify-vdb-alibabacloud-mysql = { workspace = true }
@@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true }
dify-vdb-vastbase = { workspace = true }
dify-vdb-vikingdb = { workspace = true }
dify-vdb-weaviate = { workspace = true }
dify-trace-aliyun = { workspace = true }
dify-trace-arize-phoenix = { workspace = true }
dify-trace-langfuse = { workspace = true }
dify-trace-langsmith = { workspace = true }
dify-trace-mlflow = { workspace = true }
dify-trace-opik = { workspace = true }
dify-trace-tencent = { workspace = true }
dify-trace-weave = { workspace = true }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all"]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"pyarrow>=18.0.0",
@@ -266,6 +268,25 @@ vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.4.0"]
trace-all = [
"dify-trace-aliyun",
"dify-trace-arize-phoenix",
"dify-trace-langfuse",
"dify-trace-langsmith",
"dify-trace-mlflow",
"dify-trace-opik",
"dify-trace-tencent",
"dify-trace-weave",
]
trace-aliyun = ["dify-trace-aliyun"]
trace-arize-phoenix = ["dify-trace-arize-phoenix"]
trace-langfuse = ["dify-trace-langfuse"]
trace-langsmith = ["dify-trace-langsmith"]
trace-mlflow = ["dify-trace-mlflow"]
trace-opik = ["dify-trace-opik"]
trace-tencent = ["dify-trace-tencent"]
trace-weave = ["dify-trace-weave"]
[tool.pyrefly]
project-includes = ["."]
project-excludes = [".venv", "migrations/"]

View File

@@ -34,12 +34,12 @@ core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
core/ops/aliyun_trace/data_exporter/traceclient.py
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
core/ops/mlflow_trace/mlflow_trace.py
providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/ops/ops_trace_manager.py
core/ops/tencent_trace/client.py
core/ops/tencent_trace/utils.py
providers/trace/trace-tencent/src/dify_trace_tencent/client.py
providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py

View File

@@ -5,7 +5,8 @@
".venv",
"migrations/",
"core/rag",
"providers/",
"providers/vdb/",
"providers/trace/*/tests",
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [

View File

@@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
@@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param template_id: Template ID
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class CustomizedTemplateItemDict(TypedDict):
id: str
name: str
description: str
icon: dict[str, Any]
position: int
chunk_structure: str
class CustomizedTemplatesResultDict(TypedDict):
pipeline_templates: list[CustomizedTemplateItemDict]
class CustomizedTemplateDetailDict(TypedDict):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
created_by: str
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from database
@@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result
return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
).all()
recommended_pipelines_results = []
recommended_pipelines_results: list[CustomizedTemplateItemDict] = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
recommended_pipeline_result: CustomizedTemplateItemDict = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class PipelineTemplateItemDict(TypedDict):
id: str
name: str
description: str
icon: dict[str, Any]
copyright: str
privacy_policy: str
position: int
chunk_structure: str
class PipelineTemplatesResultDict(TypedDict):
pipeline_templates: list[PipelineTemplateItemDict]
class PipelineTemplateDetailDict(TypedDict):
id: str
name: str
icon_info: dict[str, Any]
description: str
chunk_structure: str
export_data: str
graph: dict[str, Any]
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
result = self.fetch_pipeline_templates_from_db(language)
return result
return self.fetch_pipeline_templates_from_db(language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
).all()
)
recommended_pipelines_results = []
recommended_pipelines_results: list[PipelineTemplateItemDict] = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
recommended_pipeline_result: PipelineTemplateItemDict = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,

View File

@@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
result: dict[str, Any] | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
return self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result
return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
return self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
return result
return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
def get_type(self) -> str:
return PipelineTemplateType.REMOTE

View File

@@ -349,7 +349,6 @@ class SummaryIndexService:
summary_record_id,
)
summary_record_in_session = DocumentSegmentSummary(
id=summary_record_id, # Use the same ID if available
dataset_id=dataset.id,
document_id=segment.document_id,
chunk_id=segment.id,
@@ -360,6 +359,9 @@ class SummaryIndexService:
status=SummaryStatus.COMPLETED,
enabled=True,
)
if summary_record_in_session is None:
raise RuntimeError("summary_record_in_session should not be None at this point")
summary_record_in_session.id = summary_record_id
session.add(summary_record_in_session)
logger.info(
"Created new summary record (id=%s) for segment %s after vectorization",

View File

@@ -0,0 +1,650 @@
"""Testcontainers integration tests for SQL-backed DocumentService paths."""
import datetime
import json
from unittest.mock import create_autospec, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.account import NoPermissionError
FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
class DocumentServiceIntegrationFactory:
    """Helpers that persist Dataset / Document / UploadFile rows for integration tests.

    Every helper commits through the testcontainers-backed session so the rows
    are immediately visible to the DocumentService code under test.
    """

    @staticmethod
    def create_dataset(
        db_session_with_containers,
        *,
        tenant_id: str | None = None,
        created_by: str | None = None,
        name: str | None = None,
    ) -> Dataset:
        """Insert and return a Dataset; tenant/creator/name are randomized unless given."""
        dataset = Dataset(
            tenant_id=tenant_id or str(uuid4()),
            name=name or f"dataset-{uuid4()}",
            data_source_type=DataSourceType.UPLOAD_FILE,
            created_by=created_by or str(uuid4()),
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()
        return dataset

    @staticmethod
    def create_document(
        db_session_with_containers,
        *,
        dataset: Dataset,
        name: str = "doc.txt",
        position: int = 1,
        tenant_id: str | None = None,
        indexing_status: str = IndexingStatus.COMPLETED,
        enabled: bool = True,
        archived: bool = False,
        is_paused: bool = False,
        need_summary: bool = False,
        doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
        batch: str | None = None,
        data_source_type: str = DataSourceType.UPLOAD_FILE,
        data_source_info: dict | None = None,
        created_by: str | None = None,
    ) -> Document:
        """Insert and return a Document attached to *dataset*.

        Status-like fields (indexing_status, enabled, archived, is_paused,
        need_summary) are assigned after construction — presumably because the
        model constructor does not accept them or they carry column defaults;
        TODO confirm against the Document model.
        """
        document = Document(
            tenant_id=tenant_id or dataset.tenant_id,
            dataset_id=dataset.id,
            position=position,
            data_source_type=data_source_type,
            data_source_info=json.dumps(data_source_info or {}),
            batch=batch or f"batch-{uuid4()}",
            name=name,
            created_from=DocumentCreatedFrom.WEB,
            created_by=created_by or dataset.created_by,
            doc_form=doc_form,
        )
        document.indexing_status = indexing_status
        document.enabled = enabled
        document.archived = archived
        document.is_paused = is_paused
        document.need_summary = need_summary
        # Completed documents get a deterministic completion timestamp so
        # time-based assertions stay stable across runs.
        if indexing_status == IndexingStatus.COMPLETED:
            document.completed_at = FIXED_UPLOAD_CREATED_AT
        db_session_with_containers.add(document)
        db_session_with_containers.commit()
        return document

    @staticmethod
    def create_upload_file(
        db_session_with_containers,
        *,
        tenant_id: str,
        created_by: str,
        file_id: str | None = None,
        name: str = "source.txt",
    ) -> UploadFile:
        """Insert and return an UploadFile; an explicit *file_id* overrides the generated id."""
        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=StorageType.LOCAL,
            key=f"uploads/{uuid4()}",
            name=name,
            size=128,
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=created_by,
            created_at=FIXED_UPLOAD_CREATED_AT,
            used=False,
        )
        if file_id:
            upload_file.id = file_id
        db_session_with_containers.add(upload_file)
        db_session_with_containers.commit()
        return upload_file
@pytest.fixture
def current_user_mock():
    """Patch services.dataset_service.current_user with an autospecced Account."""
    fake_user = create_autospec(Account, instance=True)
    fake_user.id = str(uuid4())
    fake_user.current_tenant_id = str(uuid4())
    fake_user.current_role = None
    with patch("services.dataset_service.current_user", fake_user):
        yield fake_user
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
    """A None document_id yields None rather than raising."""
    ds = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    assert DocumentService.get_document(ds.id, None) is None
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
    """Looking up an existing document by (dataset_id, document_id) returns it."""
    ds = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    doc = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=ds)
    fetched = DocumentService.get_document(ds.id, doc.id)
    assert fetched is not None
    assert fetched.id == doc.id
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
    """An empty id list produces an empty result list."""
    ds = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    assert DocumentService.get_documents_by_ids(ds.id, []) == []
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
    """Fetching two ids returns exactly those two documents."""
    ds = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    first = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=ds, name="a.txt")
    second = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=ds,
        name="b.txt",
        position=2,
    )
    fetched = DocumentService.get_documents_by_ids(ds.id, [first.id, second.id])
    assert {doc.id for doc in fetched} == {first.id, second.id}
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
    """With no document ids to update, the reported count is zero."""
    ds = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    assert DocumentService.update_documents_need_summary(ds.id, []) == 0
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
    """Bulk need_summary updates skip QA-form documents.

    One paragraph-form and one QA-form document are created with
    need_summary=True; clearing the flag for both ids must update only the
    paragraph document.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    paragraph_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        need_summary=True,
    )
    qa_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        need_summary=True,
        doc_form=IndexStructureType.QA_INDEX,
    )
    updated_count = DocumentService.update_documents_need_summary(
        dataset.id,
        [paragraph_doc.id, qa_doc.id],
        need_summary=False,
    )
    # Expire cached ORM state so the re-fetch reflects the service's update.
    db_session_with_containers.expire_all()
    refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id)
    refreshed_qa = db_session_with_containers.get(Document, qa_doc.id)
    assert updated_count == 1
    assert refreshed_paragraph is not None
    assert refreshed_qa is not None
    assert refreshed_paragraph.need_summary is False
    assert refreshed_qa.need_summary is True
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
    """The download URL is delegated to the signed-URL helper for the upload file."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    stored_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    linked_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": stored_file.id},
    )
    with patch(
        "services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url"
    ) as signed_url_mock:
        assert DocumentService.get_document_download_url(linked_document) == "signed-url"
    signed_url_mock.assert_called_once_with(upload_file_id=stored_file.id, as_attachment=True)
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
    """Non-upload-file data sources are rejected with the invalid-source message."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    crawled_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_type=DataSourceType.WEBSITE_CRAWL,
        data_source_info={"url": "https://example.com"},
    )
    with pytest.raises(NotFound, match="invalid source"):
        DocumentService._get_upload_file_id_for_upload_file_document(
            crawled_document,
            invalid_source_message="invalid source",
            missing_file_message="missing file",
        )
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
    """A document without an upload_file_id fails with the missing-file message."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    fileless_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={},
    )
    with pytest.raises(NotFound, match="missing file"):
        DocumentService._get_upload_file_id_for_upload_file_document(
            fileless_document,
            invalid_source_message="invalid source",
            missing_file_message="missing file",
        )
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
    """A non-string upload file id is returned in its string form."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    numeric_id_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": 99},
    )
    extracted_id = DocumentService._get_upload_file_id_for_upload_file_document(
        numeric_id_document,
        invalid_source_message="invalid source",
        missing_file_message="missing file",
    )
    assert extracted_id == "99"
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
    """A dangling upload file id surfaces as NotFound."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    orphaned_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": "missing-file"},
    )
    with (
        patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}),
        pytest.raises(NotFound, match="Uploaded file not found"),
    ):
        DocumentService._get_upload_file_for_upload_file_document(orphaned_document)
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
    """The upload file referenced by the document's data_source_info is resolved."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    stored_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    linked_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": stored_file.id},
    )
    resolved = DocumentService._get_upload_file_for_upload_file_document(linked_document)
    assert resolved.id == stored_file.id
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
    """Unknown document ids abort ZIP preparation with NotFound."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    unknown_document_id = str(uuid4())
    with pytest.raises(NotFound, match="Document not found"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[unknown_document_id],
            tenant_id=dataset.tenant_id,
        )
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
    """Documents owned by a different tenant are refused with Forbidden."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    stored_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    foreign_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        tenant_id=str(uuid4()),
        data_source_info={"upload_file_id": stored_file.id},
    )
    with pytest.raises(Forbidden, match="No permission"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[foreign_document.id],
            tenant_id=dataset.tenant_id,
        )
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
    """Documents whose upload file cannot be resolved block the ZIP download."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    broken_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": str(uuid4())},
    )
    with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[broken_document.id],
            tenant_id=dataset.tenant_id,
        )
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
    """The helper returns a mapping keyed by document id to each upload file.

    Two documents in the same dataset/tenant, each pointing at its own
    upload file, must resolve to their respective files.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )
    mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
        dataset_id=dataset.id,
        document_ids=[document_a.id, document_b.id],
        tenant_id=dataset.tenant_id,
    )
    assert mapping[document_a.id].id == upload_file_a.id
    assert mapping[document_b.id].id == upload_file_b.id
def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(
    current_user_mock, flask_app_with_containers
):
    """Preparing a ZIP for a nonexistent dataset raises NotFound."""
    with flask_app_with_containers.app_context(), pytest.raises(NotFound, match="Dataset not found"):
        DocumentService.prepare_document_batch_download_zip(
            dataset_id=str(uuid4()),
            document_ids=[str(uuid4())],
            tenant_id=current_user_mock.current_tenant_id,
            current_user=current_user_mock,
        )
def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
    db_session_with_containers,
    current_user_mock,
):
    """A NoPermissionError from the dataset permission check maps to Forbidden."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(
        db_session_with_containers,
        tenant_id=current_user_mock.current_tenant_id,
        created_by=current_user_mock.id,
    )
    denied_check = patch(
        "services.dataset_service.DatasetService.check_dataset_permission",
        side_effect=NoPermissionError("denied"),
    )
    with denied_check, pytest.raises(Forbidden, match="denied"):
        DocumentService.prepare_document_batch_download_zip(
            dataset_id=dataset.id,
            document_ids=[],
            tenant_id=current_user_mock.current_tenant_id,
            current_user=current_user_mock,
        )
def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
    db_session_with_containers,
    current_user_mock,
):
    """Upload files come back in the order the document ids were requested.

    Document ids are passed as [b, a]; the returned upload files must follow
    that request order, and the suggested download name must be a .zip file.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(
        db_session_with_containers,
        tenant_id=current_user_mock.current_tenant_id,
        created_by=current_user_mock.id,
    )
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )
    # Request b before a to prove the result order tracks the input order.
    upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
        dataset_id=dataset.id,
        document_ids=[document_b.id, document_a.id],
        tenant_id=current_user_mock.current_tenant_id,
        current_user=current_user_mock,
    )
    assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id]
    assert download_name.endswith(".zip")
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
    """Only enabled documents come back; disabled ones are filtered out."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    enabled_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        enabled=True,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        enabled=False,
    )
    fetched_ids = [doc.id for doc in DocumentService.get_document_by_dataset_id(dataset.id)]
    assert fetched_ids == [enabled_doc.id]
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
    """Only completed, enabled, unarchived documents count as working."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    working_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        indexing_status=IndexingStatus.COMPLETED,
        enabled=True,
        archived=False,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        indexing_status=IndexingStatus.ERROR,
    )
    working_ids = [doc.id for doc in DocumentService.get_working_documents_by_dataset_id(dataset.id)]
    assert working_ids == [working_doc.id]
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
    """Both error and paused documents are reported; completed ones are not."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    errored = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        indexing_status=IndexingStatus.ERROR,
    )
    paused = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        indexing_status=IndexingStatus.PAUSED,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=3,
        indexing_status=IndexingStatus.COMPLETED,
    )
    reported = DocumentService.get_error_documents_by_dataset_id(dataset.id)
    assert {doc.id for doc in reported} == {errored.id, paused.id}
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
    """Batch lookups only return documents in the current user's tenant.

    Two documents share the same batch id but live in different tenants;
    with current_user patched to the dataset's tenant, only the matching
    document is returned.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    batch = f"batch-{uuid4()}"
    matching_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        batch=batch,
    )
    # Same batch id, different tenant — must be filtered out.
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        tenant_id=str(uuid4()),
        batch=batch,
    )
    with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
        current_user.current_tenant_id = dataset.tenant_id
        result = DocumentService.get_batch_documents(dataset.id, batch)
    assert [document.id for document in result] == [matching_document.id]
def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
    """Looking up a file detail by upload file id returns the stored record."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    stored_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    detail = DocumentService.get_document_file_detail(stored_file.id)
    assert detail is not None
    assert detail.id == stored_file.id
def test_delete_document_emits_signal_and_commits(db_session_with_containers):
    """Deleting a document removes the row and emits document_was_deleted.

    The signal must carry the document id plus dataset_id, doc_form and the
    linked upload file id as keyword arguments.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file.id},
    )
    with patch("services.dataset_service.document_was_deleted.send") as signal_send:
        DocumentService.delete_document(document)
    # Row is gone from the database after the service commits.
    assert db_session_with_containers.get(Document, document.id) is None
    signal_send.assert_called_once_with(
        document.id,
        dataset_id=document.dataset_id,
        doc_form=document.doc_form,
        file_id=upload_file.id,
    )
def test_delete_documents_ignores_empty_input(db_session_with_containers):
    """Deleting an empty id list never dispatches the cleanup task."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    with patch("services.dataset_service.batch_clean_document_task.delay") as cleanup_delay:
        DocumentService.delete_documents(dataset, [])
    cleanup_delay.assert_not_called()
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
    """Bulk delete removes the rows and enqueues one batch cleanup task.

    The task's positional args carry the deleted document ids, the dataset
    id, and (at index 3) the linked upload file ids.
    """
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
    db_session_with_containers.commit()
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )
    with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
        DocumentService.delete_documents(dataset, [document_a.id, document_b.id])
    assert db_session_with_containers.get(Document, document_a.id) is None
    assert db_session_with_containers.get(Document, document_b.id) is None
    delay.assert_called_once()
    args = delay.call_args.args
    assert args[0] == [document_a.id, document_b.id]
    assert args[1] == dataset.id
    # args[2] is not asserted here — presumably the doc_form; confirm in the task signature.
    assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
    """The next position is one past the highest existing position."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
    next_position = DocumentService.get_documents_position(dataset.id)
    assert next_position == 4
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
    """An empty dataset starts positioning at one."""
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    first_position = DocumentService.get_documents_position(dataset.id)
    assert first_position == 1

Some files were not shown because too many files have changed in this diff Show More