diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c55014a368..cb6cdb309a 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,5 +1,13 @@ -from flask_restx import Namespace, fields +from __future__ import annotations +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from graphon.variables.types import SegmentType +from pydantic import field_validator + +from fields.base import ResponseModel from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -29,6 +37,74 @@ conversation_variable_infinite_scroll_pagination_fields = { } +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type().value) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type().value) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields)