From 8f0e66b72e6738970f7fe4c251f6e42e5ecbeae1 Mon Sep 17 00:00:00 2001 From: Kai Wang Date: Fri, 3 Apr 2026 17:11:41 +0800 Subject: [PATCH] fix: repair websocket custom tool calls --- ...nai_responses_websocket_toolcall_repair.go | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go index 530aca967..1a5772ec7 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go @@ -266,15 +266,15 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) - switch itemType { - case "function_call_output": + switch { + case isResponsesToolCallOutputType(itemType): callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { continue } outputPresent[callID] = struct{}{} outputCache.record(sessionKey, callID, item) - case "function_call": + case isResponsesToolCallType(itemType): callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { continue @@ -293,7 +293,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) - if itemType == "function_call_output" { + if isResponsesToolCallOutputType(itemType) { callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { // Upstream rejects tool outputs without a call_id; drop it. @@ -325,7 +325,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa // Drop orphaned function_call_output items; upstream rejects transcripts with missing calls. continue } - if itemType != "function_call" { + if !isResponsesToolCallType(itemType) { filtered = append(filtered, item) continue } @@ -376,7 +376,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO return } for _, item := range output.Array() { - if strings.TrimSpace(item.Get("type").String()) != "function_call" { + if !isResponsesToolCallType(item.Get("type").String()) { continue } callID := strings.TrimSpace(item.Get("call_id").String()) @@ -390,7 +390,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO if !item.Exists() || !item.IsObject() { return } - if strings.TrimSpace(item.Get("type").String()) != "function_call" { + if !isResponsesToolCallType(item.Get("type").String()) { return } callID := strings.TrimSpace(item.Get("call_id").String()) @@ -400,3 +400,21 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO cache.record(sessionKey, callID, json.RawMessage(item.Raw)) } } + +func isResponsesToolCallType(itemType string) bool { + switch strings.TrimSpace(itemType) { + case "function_call", "custom_tool_call": + return true + default: + return false + } +} + +func isResponsesToolCallOutputType(itemType string) bool { + switch strings.TrimSpace(itemType) { + case "function_call_output", "custom_tool_call_output": + return true + default: + return false + } +}