diff --git a/config.example.yaml b/config.example.yaml index 24e3d99c8..772a6416e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -90,6 +90,10 @@ max-retry-interval: 30 # When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states). disable-cooling: false +# When true, disable the built-in image_generation tool globally. +# The server will stop injecting image_generation and will also remove it from request payload tools arrays. +disable-image-generation: false + # Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh). # When > 0, overrides the default worker count (16). # auth-auto-refresh-workers: 16 diff --git a/internal/api/server.go b/internal/api/server.go index f817ac309..c414e10a1 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1013,6 +1013,10 @@ func (s *Server) UpdateClients(cfg *config.Config) { auth.SetQuotaCooldownDisabled(cfg.DisableCooling) } + if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration { + log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration) + } + applySignatureCacheConfig(oldCfg, cfg) if s.handlers != nil && s.handlers.AuthManager != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 1ee7aed53..c30593f67 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -610,6 +610,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.ErrorLogsMaxFiles = 10 cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false + cfg.DisableImageGeneration = false cfg.Pprof.Enable = false cfg.Pprof.Addr = DefaultPprofAddr cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index aa27526d1..752f53aa9 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -9,6 +9,12 @@ type SDKConfig struct { // ProxyURL is the URL of an optional proxy server to use for outbound requests. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // DisableImageGeneration disables the built-in image_generation tool when true. + // When enabled, the server will avoid injecting image_generation into request payloads, + // will remove any existing image_generation tool entries from tools arrays, and will + // return 404 for /v1/images/generations and /v1/images/edits. + DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"` + // EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled. // Default is false for safety; when false, /v1internal:* requests are rejected. EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"` diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 2a01c7ac0..1948beac4 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body, _ = sjson.DeleteBytes(body, "safety_identifier") body, _ = sjson.DeleteBytes(body, "stream_options") body = normalizeCodexInstructions(body) - body = ensureImageGenerationTool(body, baseModel, auth) + if e.cfg == nil || !e.cfg.DisableImageGeneration { + body = ensureImageGenerationTool(body, baseModel, auth) + } url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.DeleteBytes(body, "stream") body = normalizeCodexInstructions(body) - body = ensureImageGenerationTool(body, baseModel, auth) + if e.cfg == nil || !e.cfg.DisableImageGeneration { + body = ensureImageGenerationTool(body, baseModel, auth) + } url := strings.TrimSuffix(baseURL, "/") + "/responses/compact" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body, _ = sjson.DeleteBytes(body, "stream_options") body, _ = sjson.SetBytes(body, "model", baseModel) body = normalizeCodexInstructions(body) - body = ensureImageGenerationTool(body, baseModel, auth) + if e.cfg == nil || !e.cfg.DisableImageGeneration { + body = ensureImageGenerationTool(body, baseModel, auth) + } url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) diff --git a/internal/runtime/executor/helps/payload_helpers.go b/internal/runtime/executor/helps/payload_helpers.go index 73514c2dd..b868d445a 100644 --- a/internal/runtime/executor/helps/payload_helpers.go +++ b/internal/runtime/executor/helps/payload_helpers.go @@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string if cfg == nil || len(payload) == 0 { return payload } - rules := cfg.Payload - if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 { - return payload - } - model = strings.TrimSpace(model) - requestedModel = strings.TrimSpace(requestedModel) - if model == "" && requestedModel == "" { - return payload - } - candidates := payloadModelCandidates(model, requestedModel) out := payload - source := original - if len(source) == 0 { - source = payload - } - appliedDefaults := make(map[string]struct{}) - // Apply default rules: first write wins per field across all matching rules. - for i := range rules.Default { - rule := &rules.Default[i] - if !payloadModelRulesMatch(rule.Models, protocol, candidates) { - continue - } - for path, value := range rule.Params { - fullPath := buildPayloadPath(root, path) - if fullPath == "" { - continue + + rules := cfg.Payload + hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0 + if hasPayloadRules { + model = strings.TrimSpace(model) + requestedModel = strings.TrimSpace(requestedModel) + if model != "" || requestedModel != "" { + candidates := payloadModelCandidates(model, requestedModel) + source := original + if len(source) == 0 { + source = payload } - if gjson.GetBytes(source, fullPath).Exists() { - continue + appliedDefaults := make(map[string]struct{}) + // Apply default rules: first write wins per field across all matching rules. + for i := range rules.Default { + rule := &rules.Default[i] + if !payloadModelRulesMatch(rule.Models, protocol, candidates) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + if gjson.GetBytes(source, fullPath).Exists() { + continue + } + if _, ok := appliedDefaults[fullPath]; ok { + continue + } + updated, errSet := sjson.SetBytes(out, fullPath, value) + if errSet != nil { + continue + } + out = updated + appliedDefaults[fullPath] = struct{}{} + } } - if _, ok := appliedDefaults[fullPath]; ok { - continue + // Apply default raw rules: first write wins per field across all matching rules. + for i := range rules.DefaultRaw { + rule := &rules.DefaultRaw[i] + if !payloadModelRulesMatch(rule.Models, protocol, candidates) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + if gjson.GetBytes(source, fullPath).Exists() { + continue + } + if _, ok := appliedDefaults[fullPath]; ok { + continue + } + rawValue, ok := payloadRawValue(value) + if !ok { + continue + } + updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) + if errSet != nil { + continue + } + out = updated + appliedDefaults[fullPath] = struct{}{} + } } - updated, errSet := sjson.SetBytes(out, fullPath, value) - if errSet != nil { - continue + // Apply override rules: last write wins per field across all matching rules. + for i := range rules.Override { + rule := &rules.Override[i] + if !payloadModelRulesMatch(rule.Models, protocol, candidates) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + updated, errSet := sjson.SetBytes(out, fullPath, value) + if errSet != nil { + continue + } + out = updated + } + } + // Apply override raw rules: last write wins per field across all matching rules. + for i := range rules.OverrideRaw { + rule := &rules.OverrideRaw[i] + if !payloadModelRulesMatch(rule.Models, protocol, candidates) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + rawValue, ok := payloadRawValue(value) + if !ok { + continue + } + updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) + if errSet != nil { + continue + } + out = updated + } + } + // Apply filter rules: remove matching paths from payload. + for i := range rules.Filter { + rule := &rules.Filter[i] + if !payloadModelRulesMatch(rule.Models, protocol, candidates) { + continue + } + for _, path := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + updated, errDel := sjson.DeleteBytes(out, fullPath) + if errDel != nil { + continue + } + out = updated + } } - out = updated - appliedDefaults[fullPath] = struct{}{} } } - // Apply default raw rules: first write wins per field across all matching rules. - for i := range rules.DefaultRaw { - rule := &rules.DefaultRaw[i] - if !payloadModelRulesMatch(rule.Models, protocol, candidates) { - continue - } - for path, value := range rule.Params { - fullPath := buildPayloadPath(root, path) - if fullPath == "" { - continue - } - if gjson.GetBytes(source, fullPath).Exists() { - continue - } - if _, ok := appliedDefaults[fullPath]; ok { - continue - } - rawValue, ok := payloadRawValue(value) - if !ok { - continue - } - updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) - if errSet != nil { - continue - } - out = updated - appliedDefaults[fullPath] = struct{}{} - } - } - // Apply override rules: last write wins per field across all matching rules. - for i := range rules.Override { - rule := &rules.Override[i] - if !payloadModelRulesMatch(rule.Models, protocol, candidates) { - continue - } - for path, value := range rule.Params { - fullPath := buildPayloadPath(root, path) - if fullPath == "" { - continue - } - updated, errSet := sjson.SetBytes(out, fullPath, value) - if errSet != nil { - continue - } - out = updated - } - } - // Apply override raw rules: last write wins per field across all matching rules. - for i := range rules.OverrideRaw { - rule := &rules.OverrideRaw[i] - if !payloadModelRulesMatch(rule.Models, protocol, candidates) { - continue - } - for path, value := range rule.Params { - fullPath := buildPayloadPath(root, path) - if fullPath == "" { - continue - } - rawValue, ok := payloadRawValue(value) - if !ok { - continue - } - updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) - if errSet != nil { - continue - } - out = updated - } - } - // Apply filter rules: remove matching paths from payload. - for i := range rules.Filter { - rule := &rules.Filter[i] - if !payloadModelRulesMatch(rule.Models, protocol, candidates) { - continue - } - for _, path := range rule.Params { - fullPath := buildPayloadPath(root, path) - if fullPath == "" { - continue - } - updated, errDel := sjson.DeleteBytes(out, fullPath) - if errDel != nil { - continue - } - out = updated - } + + if cfg.DisableImageGeneration { + out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation") } return out } @@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string { return r + "." + p } +func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte { + if len(payload) == 0 { + return payload + } + toolType = strings.TrimSpace(toolType) + if toolType == "" { + return payload + } + toolsPath := buildPayloadPath(root, "tools") + return removeToolTypeFromToolsArray(payload, toolsPath, toolType) +} + +func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte { + tools := gjson.GetBytes(payload, toolsPath) + if !tools.Exists() || !tools.IsArray() { + return payload + } + removed := false + filtered := []byte(`[]`) + for _, tool := range tools.Array() { + if tool.Get("type").String() == toolType { + removed = true + continue + } + updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw)) + if errSet != nil { + continue + } + filtered = updated + } + if !removed { + return payload + } + updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered) + if errSet != nil { + return payload + } + return updated +} + func payloadRawValue(value any) ([]byte, bool) { if value == nil { return nil, false diff --git a/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go new file mode 100644 index 000000000..143393dce --- /dev/null +++ b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go @@ -0,0 +1,50 @@ +package helps + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/tidwall/gjson" +) + +func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) { + cfg := &config.Config{ + SDKConfig: config.SDKConfig{DisableImageGeneration: true}, + } + payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`) + + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "") + + tools := gjson.GetBytes(out, "tools") + if !tools.Exists() || !tools.IsArray() { + t.Fatalf("expected tools array, got %v", tools.Type) + } + arr := tools.Array() + if len(arr) != 1 { + t.Fatalf("expected 1 tool after removal, got %d", len(arr)) + } + if got := arr[0].Get("type").String(); got != "function" { + t.Fatalf("expected remaining tool type=function, got %q", got) + } +} + +func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) { + cfg := &config.Config{ + SDKConfig: config.SDKConfig{DisableImageGeneration: true}, + } + payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`) + + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "") + + tools := gjson.GetBytes(out, "request.tools") + if !tools.Exists() || !tools.IsArray() { + t.Fatalf("expected request.tools array, got %v", tools.Type) + } + arr := tools.Array() + if len(arr) != 1 { + t.Fatalf("expected 1 tool after removal, got %d", len(arr)) + } + if got := arr[0].Get("type").String(); got != "web_search" { + t.Fatalf("expected remaining tool type=web_search, got %q", got) + } +} diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 11f9093e8..15ab5d31f 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -42,6 +42,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.DisableCooling != newCfg.DisableCooling { changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling)) } + if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration { + changes = append(changes, fmt.Sprintf("disable-image-generation: %t -> %t", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration)) + } if oldCfg.RequestLog != newCfg.RequestLog { changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog)) } diff --git a/internal/watcher/diff/config_diff_test.go b/internal/watcher/diff/config_diff_test.go index 2d45aa574..6cfda7b19 100644 --- a/internal/watcher/diff/config_diff_test.go +++ b/internal/watcher/diff/config_diff_test.go @@ -279,6 +279,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { APIKeys: []string{" key-1 ", "key-2"}, ForceModelPrefix: true, NonStreamKeepAliveInterval: 5, + DisableImageGeneration: true, }, } @@ -287,6 +288,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { expectContains(t, details, "logging-to-file: false -> true") expectContains(t, details, "usage-statistics-enabled: false -> true") expectContains(t, details, "disable-cooling: false -> true") + expectContains(t, details, "disable-image-generation: false -> true") expectContains(t, details, "request-log: false -> true") expectContains(t, details, "request-retry: 1 -> 2") expectContains(t, details, "max-retry-credentials: 1 -> 3") @@ -403,9 +405,10 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { SecretKey: "", }, SDKConfig: sdkconfig.SDKConfig{ - RequestLog: true, - ProxyURL: "http://new-proxy", - APIKeys: []string{"keyB"}, + RequestLog: true, + ProxyURL: "http://new-proxy", + APIKeys: []string{"keyB"}, + DisableImageGeneration: true, }, OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}}, OpenAICompatibility: []config.OpenAICompatibility{ @@ -431,6 +434,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { expectContains(t, changes, "logging-to-file: false -> true") expectContains(t, changes, "usage-statistics-enabled: false -> true") expectContains(t, changes, "disable-cooling: false -> true") + expectContains(t, changes, "disable-image-generation: false -> true") expectContains(t, changes, "request-retry: 1 -> 2") expectContains(t, changes, "max-retry-credentials: 1 -> 3") expectContains(t, changes, "max-retry-interval: 1 -> 3") diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go index 081547c0f..162bf41eb 100644 --- a/sdk/api/handlers/openai/openai_images_handlers.go +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -198,6 +198,11 @@ func parseBoolField(raw string, fallback bool) bool { } func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { + if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration { + c.AbortWithStatus(http.StatusNotFound) + return + } + rawJSON, err := c.GetRawData() if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -281,6 +286,11 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { } func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) { + if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration { + c.AbortWithStatus(http.StatusNotFound) + return + } + contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type"))) if strings.HasPrefix(contentType, "application/json") { h.imagesEditsFromJSON(c) diff --git a/sdk/api/handlers/openai/openai_images_handlers_test.go b/sdk/api/handlers/openai/openai_images_handlers_test.go index 679bec6a2..7604c5d45 100644 --- a/sdk/api/handlers/openai/openai_images_handlers_test.go +++ b/sdk/api/handlers/openai/openai_images_handlers_test.go @@ -10,6 +10,8 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/tidwall/gjson" ) @@ -93,3 +95,27 @@ func TestImagesEditsMultipartRejectsUnsupportedModel(t *testing.T) { assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini") } + +func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) { + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil) + handler := NewOpenAIAPIHandler(base) + body := strings.NewReader(`{"prompt":"draw a square"}`) + + resp := performImagesEndpointRequest(t, imagesGenerationsPath, "application/json", body, handler.ImagesGenerations) + + if resp.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String()) + } +} + +func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) { + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil) + handler := NewOpenAIAPIHandler(base) + body := strings.NewReader(`{"prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`) + + resp := performImagesEndpointRequest(t, imagesEditsPath, "application/json", body, handler.ImagesEdits) + + if resp.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String()) + } +}