From af21dc7df875aec8437332ebb1f99bcef893a258 Mon Sep 17 00:00:00 2001
From: NVIDIAN
Date: Thu, 16 Apr 2026 22:02:04 -0700
Subject: [PATCH] refactor(api): migrate dataset document response schemas to BaseModel (#35298)

Co-authored-by: ai-hpc
Co-authored-by: Asuka Minato
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 .../console/datasets/datasets_document.py | 166 +++++++++++++-----
 .../datasets/test_datasets_document.py     |   9 +-
 2 files changed, 130 insertions(+), 45 deletions(-)

diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index b61e39716f..3372a967d9 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -3,18 +3,19 @@ import logging
 from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from contextlib import ExitStack
+from datetime import datetime
 from typing import Any, Literal, cast
 
 import sqlalchemy as sa
 from flask import request, send_file
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
 from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
-from controllers.common.schema import get_or_create_model, register_schema_models
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from core.errors.error import (
     LLMBadRequestError,
@@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
-from fields.dataset_fields import dataset_fields
+from fields.base import ResponseModel
 from fields.document_fields import (
-    dataset_and_document_fields,
     document_fields,
-    document_metadata_fields,
     document_status_fields,
     document_with_segments_fields,
 )
@@ -72,27 +71,100 @@ from ..wraps import (
 logger = logging.getLogger(__name__)
 
 
-# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = get_or_create_model("Dataset", dataset_fields)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
 
 
-document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
-document_fields_copy = document_fields.copy()
-document_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_model = get_or_create_model("Document", document_fields_copy)
+def _normalize_enum(value: Any) -> Any:
+    if isinstance(value, str) or value is None:
+        return value
+    return getattr(value, "value", value)
 
 
-document_with_segments_fields_copy = document_with_segments_fields.copy()
-document_with_segments_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
-dataset_and_document_fields_copy = dataset_and_document_fields.copy()
-dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) -dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +class DatasetResponse(ResponseModel): + id: str + name: str + description: str | None = None + permission: str | None = None + data_source_type: str | None = None + indexing_technique: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("data_source_type", "indexing_technique", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentMetadataResponse(ResponseModel): + id: str + name: str + type: str + value: str | None = None + + +class DocumentResponse(ResponseModel): + id: str + position: int | None = None + data_source_type: str | None = None + data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict") + data_source_detail_dict: Any = None + dataset_process_rule_id: str | None = None + name: str + created_from: str | None = None + created_by: str | None = None + created_at: int | None = None + tokens: int | None = None + indexing_status: str | None = None + error: str | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + archived: bool | None = None + display_status: str | None = None + word_count: int | None = None + hit_count: int | None = None + doc_form: str | None = None + doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details") + summary_index_status: str | None = None + need_summary: bool | None = None + + @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("doc_metadata", mode="before") + @classmethod + def _normalize_doc_metadata(cls, value: Any) -> list[Any]: + if value is None: + return [] + return value + + @field_validator("created_at", "disabled_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentWithSegmentsResponse(DocumentResponse): + process_rule_dict: Any = None + completed_segments: int | None = None + total_segments: int | None = None + + +class DatasetAndDocumentResponse(ResponseModel): + dataset: DatasetResponse + documents: list[DocumentResponse] + batch: str class DocumentRetryPayload(BaseModel): @@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel): document_list: list[str] +class DocumentMetadataUpdatePayload(BaseModel): + doc_type: str | None = None + doc_metadata: Any = None + + class DocumentDatasetListParam(BaseModel): page: int = Field(1, title="Page", description="Page number.") limit: int = Field(20, title="Limit", description="Page size.") @@ -124,7 +201,13 @@ register_schema_models( DocumentRetryPayload, DocumentRenamePayload, GenerateSummaryPayload, + DocumentMetadataUpdatePayload, DocumentBatchDownloadZipPayload, + DatasetResponse, + DocumentMetadataResponse, + DocumentResponse, + DocumentWithSegmentsResponse, + DatasetAndDocumentResponse, ) @@ -357,10 +440,10 
@@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) + @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"dataset": dataset, "documents": documents, "batch": batch} + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @setup_required @login_required @@ -426,12 +511,13 @@ class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) + @console_ns.response( + 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__] + ) @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self): @@ -479,9 +565,9 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = {"dataset": dataset, "documents": documents, "batch": batch} - - return response + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/datasets//documents//indexing-estimate") @@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource): @console_ns.doc("update_document_metadata") @console_ns.doc(description="Update document metadata") @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @console_ns.expect( - console_ns.model( - "UpdateDocumentMetadataRequest", - { - "doc_type": fields.String(description="Document type"), - "doc_metadata": fields.Raw(description="Document metadata"), - }, - ) - ) + @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__]) @console_ns.response(200, "Document metadata updated successfully") @console_ns.response(404, "Document not found") @console_ns.response(403, "Permission denied") @@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = request.get_json() + req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) - doc_type = req_data.get("doc_type") - doc_metadata = req_data.get("doc_metadata") + doc_type = req_data.doc_type + doc_metadata = req_data.doc_metadata # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - 
@marshal_with(document_model) + @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return document + return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json") @console_ns.route("/datasets//documents//website-sync") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f..d9b02ac453 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -215,17 +216,23 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1")
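
Usage note (illustrative, not part of the patch): the sketch below shows the serialization
pattern the new response models rely on, namely Pydantic v2 model_validate(..., from_attributes=True)
combined with validation_alias and mode="before" validators, applied to a fake ORM row.
DocumentResponseSketch and row are hypothetical names used only for this example, and
fields.base.ResponseModel is assumed here to behave like a plain pydantic BaseModel.

    # Minimal, self-contained sketch of the response-model pattern introduced above.
    from datetime import datetime
    from types import SimpleNamespace
    from typing import Any

    from pydantic import BaseModel, Field, field_validator


    class DocumentResponseSketch(BaseModel):
        id: str
        name: str
        created_at: int | None = None
        # Read from the ORM attribute `doc_metadata_details`, expose as `doc_metadata`.
        doc_metadata: list[Any] = Field(default_factory=list, validation_alias="doc_metadata_details")

        @field_validator("created_at", mode="before")
        @classmethod
        def _to_timestamp(cls, value: datetime | int | None) -> int | None:
            # datetime columns are emitted as Unix timestamps, matching the
            # timestamp-style output the old flask_restx fields produced.
            return int(value.timestamp()) if isinstance(value, datetime) else value

        @field_validator("doc_metadata", mode="before")
        @classmethod
        def _default_metadata(cls, value: Any) -> Any:
            # Documents without metadata expose None; serialize that as an empty list.
            return [] if value is None else value


    # A fake row standing in for a SQLAlchemy Document instance.
    row = SimpleNamespace(
        id="doc-1",
        name="Example",
        created_at=datetime(2024, 1, 1),
        doc_metadata_details=None,
    )

    payload = DocumentResponseSketch.model_validate(row, from_attributes=True).model_dump(mode="json")
    print(payload)  # {'id': 'doc-1', 'name': 'Example', 'created_at': <unix seconds>, 'doc_metadata': []}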