diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7b2e5d8d5..131ac3ea6 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -8,7 +8,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -1463,182 +1462,6 @@ func countCacheControls(payload []byte) int { return count } -func parsePayloadObject(payload []byte) (map[string]any, bool) { - if len(payload) == 0 { - return nil, false - } - var root map[string]any - if err := json.Unmarshal(payload, &root); err != nil { - return nil, false - } - return root, true -} - -func marshalPayloadObject(original []byte, root map[string]any) []byte { - if root == nil { - return original - } - out, err := json.Marshal(root) - if err != nil { - return original - } - return out -} - -func asObject(v any) (map[string]any, bool) { - obj, ok := v.(map[string]any) - return obj, ok -} - -func asArray(v any) ([]any, bool) { - arr, ok := v.([]any) - return arr, ok -} - -func countCacheControlsMap(root map[string]any) int { - count := 0 - - if system, ok := asArray(root["system"]); ok { - for _, item := range system { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - - if tools, ok := asArray(root["tools"]); ok { - for _, item := range tools { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - - if messages, ok := asArray(root["messages"]); ok { - for _, msg := range messages { - msgObj, ok := asObject(msg) - if !ok { - continue - } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - } - - return count -} - -func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool { - ccRaw, exists := obj["cache_control"] - if !exists { - return false - } - cc, ok := asObject(ccRaw) - if !ok { - *seen5m = true - return false - } - ttlRaw, ttlExists := cc["ttl"] - ttl, ttlIsString := ttlRaw.(string) - if !ttlExists || !ttlIsString || ttl != "1h" { - *seen5m = true - return false - } - if *seen5m { - delete(cc, "ttl") - return true - } - return false -} - -func findLastCacheControlIndex(arr []any) int { - last := -1 - for idx, item := range arr { - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - last = idx - } - } - return last -} - -func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) { - for idx, item := range arr { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists && idx != preserveIdx { - delete(obj, "cache_control") - *excess-- - } - } -} - -func stripAllCacheControl(arr []any, excess *int) { - for _, item := range arr { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - delete(obj, "cache_control") - *excess-- - } - } -} - -func stripMessageCacheControl(messages []any, excess *int) { - for _, msg := range messages { - if *excess <= 0 { - return - } - msgObj, ok := asObject(msg) - if !ok { - continue - } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - delete(obj, "cache_control") - *excess-- - } - } - } -} - // normalizeCacheControlTTL ensures cache_control TTL values don't violate the // prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not // appear after a 5m-TTL block anywhere in the evaluation order. @@ -1651,58 +1474,75 @@ func stripMessageCacheControl(messages []any, excess *int) { // Strategy: walk all cache_control blocks in evaluation order. Once a 5m block // is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m). func normalizeCacheControlTTL(payload []byte) []byte { - root, ok := parsePayloadObject(payload) - if !ok { + if len(payload) == 0 || !gjson.ValidBytes(payload) { return payload } + original := payload seen5m := false modified := false - if tools, ok := asArray(root["tools"]); ok { - for _, tool := range tools { - if obj, ok := asObject(tool); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } + processBlock := func(path string, obj gjson.Result) { + cc := obj.Get("cache_control") + if !cc.Exists() { + return } + if !cc.IsObject() { + seen5m = true + return + } + ttl := cc.Get("ttl") + if ttl.Type != gjson.String || ttl.String() != "1h" { + seen5m = true + return + } + if !seen5m { + return + } + ttlPath := path + ".cache_control.ttl" + updated, errDel := sjson.DeleteBytes(payload, ttlPath) + if errDel != nil { + return + } + payload = updated + modified = true } - if system, ok := asArray(root["system"]); ok { - for _, item := range system { - if obj, ok := asObject(item); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } - } + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(idx, item gjson.Result) bool { + processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item) + return true + }) } - if messages, ok := asArray(root["messages"]); ok { - for _, msg := range messages { - msgObj, ok := asObject(msg) - if !ok { - continue + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(idx, item gjson.Result) bool { + processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item) + return true + }) + } + + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(msgIdx, msg gjson.Result) bool { + content := msg.Get("content") + if !content.IsArray() { + return true } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if obj, ok := asObject(item); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } - } - } + content.ForEach(func(itemIdx, item gjson.Result) bool { + processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item) + return true + }) + return true + }) } if !modified { - return payload + return original } - return marshalPayloadObject(payload, root) + return payload } // enforceCacheControlLimit removes excess cache_control blocks from a payload @@ -1722,64 +1562,166 @@ func normalizeCacheControlTTL(payload []byte) []byte { // Phase 4: remaining system blocks (last system). // Phase 5: remaining tool blocks (last tool). func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte { - root, ok := parsePayloadObject(payload) - if !ok { + if len(payload) == 0 || !gjson.ValidBytes(payload) { return payload } - total := countCacheControlsMap(root) + total := countCacheControls(payload) if total <= maxBlocks { return payload } excess := total - maxBlocks - var system []any - if arr, ok := asArray(root["system"]); ok { - system = arr - } - var tools []any - if arr, ok := asArray(root["tools"]); ok { - tools = arr - } - var messages []any - if arr, ok := asArray(root["messages"]); ok { - messages = arr - } - - if len(system) > 0 { - stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess) + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + lastIdx := -1 + system.ForEach(func(idx, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + lastIdx = int(idx.Int()) + } + return true + }) + if lastIdx >= 0 { + system.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + i := int(idx.Int()) + if i == lastIdx { + return true + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("system.%d.cache_control", i) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + } } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(tools) > 0 { - stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess) + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + lastIdx := -1 + tools.ForEach(func(idx, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + lastIdx = int(idx.Int()) + } + return true + }) + if lastIdx >= 0 { + tools.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + i := int(idx.Int()) + if i == lastIdx { + return true + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("tools.%d.cache_control", i) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + } } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(messages) > 0 { - stripMessageCacheControl(messages, &excess) + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(msgIdx, msg gjson.Result) bool { + if excess <= 0 { + return false + } + content := msg.Get("content") + if !content.IsArray() { + return true + } + content.ForEach(func(itemIdx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + return true + }) } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(system) > 0 { - stripAllCacheControl(system, &excess) + system = gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("system.%d.cache_control", int(idx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(tools) > 0 { - stripAllCacheControl(tools, &excess) + tools = gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) } - return marshalPayloadObject(payload, root) + return payload } // injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching. diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 74cec0a35..c6220fe9d 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -965,6 +965,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing. } } +func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) { + payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`) + + out := normalizeCacheControlTTL(payload) + + if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() { + t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block") + } + + outStr := string(out) + idxModel := strings.Index(outStr, `"model"`) + idxMessages := strings.Index(outStr, `"messages"`) + idxTools := strings.Index(outStr, `"tools"`) + idxSystem := strings.Index(outStr, `"system"`) + if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 { + t.Fatalf("failed to locate top-level keys in output: %s", outStr) + } + if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) { + t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out) + } +} + func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { payload := []byte(`{ "tools": [ @@ -994,6 +1016,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) } } +func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) { + payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed first (non-last tool)") + } + + outStr := string(out) + idxModel := strings.Index(outStr, `"model"`) + idxMessages := strings.Index(outStr, `"messages"`) + idxTools := strings.Index(outStr, `"tools"`) + idxSystem := strings.Index(outStr, `"system"`) + if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 { + t.Fatalf("failed to locate top-level keys in output: %s", outStr) + } + if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) { + t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out) + } +} + func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) { payload := []byte(`{ "tools": [