diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index f2ffa3b170..a6e6b1bae7 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -32,22 +32,33 @@ class AdvancedPromptTemplateService: def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == AppMode.COMPLETION: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt - ) + match app_mode: + case AppMode.CHAT: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + case _: + pass + case AppMode.COMPLETION: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + case _: + pass + case _: + pass # default return empty dict return {} @@ -73,25 +84,38 @@ class AdvancedPromptTemplateService: def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) - elif app_mode == AppMode.COMPLETION: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), - has_context, - baichuan_context_prompt, - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) + match app_mode: + case AppMode.CHAT: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) + case _: + pass + case AppMode.COMPLETION: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case _: + pass + case _: + pass # default return empty dict return {} diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3bc30cb323..2013c869af 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -7,11 +7,12 @@ from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: - if app_mode == AppMode.CHAT: - return ChatAppConfigManager.config_validate(tenant_id, config) - elif app_mode == AppMode.AGENT_CHAT: - return AgentChatAppConfigManager.config_validate(tenant_id, config) - elif app_mode == AppMode.COMPLETION: - return CompletionAppConfigManager.config_validate(tenant_id, config) - else: - raise ValueError(f"Invalid app mode: {app_mode}") + match app_mode: + case AppMode.CHAT: + return ChatAppConfigManager.config_validate(tenant_id, config) + case AppMode.AGENT_CHAT: + return AgentChatAppConfigManager.config_validate(tenant_id, config) + case AppMode.COMPLETION: + return CompletionAppConfigManager.config_validate(tenant_id, config) + case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c1ad3f33ad..1582bcd46c 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -170,34 +170,38 @@ class WorkflowConverter: graph = self._append_node(graph, llm_node) - if new_app_mode == AppMode.WORKFLOW: - # convert to end node by app mode - end_node = self._convert_to_end_node() - graph = self._append_node(graph, end_node) - else: - answer_node = self._convert_to_answer_node() - graph = self._append_node(graph, answer_node) - app_model_config_dict = app_config.app_model_config_dict - # features - if new_app_mode == AppMode.ADVANCED_CHAT: - features = { - "opening_statement": app_model_config_dict.get("opening_statement"), - "suggested_questions": app_model_config_dict.get("suggested_questions"), - "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), - "speech_to_text": app_model_config_dict.get("speech_to_text"), - "text_to_speech": app_model_config_dict.get("text_to_speech"), - "file_upload": app_model_config_dict.get("file_upload"), - "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), - "retriever_resource": app_model_config_dict.get("retriever_resource"), - } - else: - features = { - "text_to_speech": app_model_config_dict.get("text_to_speech"), - "file_upload": app_model_config_dict.get("file_upload"), - "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), - } + match new_app_mode: + case AppMode.WORKFLOW: + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + case AppMode.ADVANCED_CHAT: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + case _: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } # create workflow record workflow = Workflow( @@ -220,19 +224,23 @@ class WorkflowConverter: def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode_enum = AppMode.value_of(app_model.mode) app_config: EasyUIBasedAppConfig - if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config - ) - elif app_mode_enum == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode_enum == AppMode.COMPLETION: - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config - ) - else: - raise ValueError("Invalid app mode") + effective_mode = ( + AppMode.AGENT_CHAT if app_model.is_agent and app_mode_enum != AppMode.AGENT_CHAT else app_mode_enum + ) + match effective_mode: + case AppMode.AGENT_CHAT: + app_model.mode = AppMode.AGENT_CHAT + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + case AppMode.CHAT: + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + case AppMode.COMPLETION: + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + case _: + raise ValueError("Invalid app mode") return app_config diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8f365c7c51..662e0410f9 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1417,16 +1417,17 @@ class WorkflowService: self._validate_human_input_node_data(node_data) def validate_features_structure(self, app_model: App, features: dict): - if app_model.mode == AppMode.ADVANCED_CHAT: - return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - elif app_model.mode == AppMode.WORKFLOW: - return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - else: - raise ValueError(f"Invalid app mode: {app_model.mode}") + match app_model.mode: + case AppMode.ADVANCED_CHAT: + return AdvancedChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + case AppMode.WORKFLOW: + return WorkflowAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + case _: + raise ValueError(f"Invalid app mode: {app_model.mode}") def _validate_human_input_node_data(self, node_data: dict) -> None: """