feat: support disabling image generation globally

- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally.
- Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled.
- Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled.
- Enhanced configuration diff logging to track changes for the `disable-image-generation` flag.
- Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
Luis Pater
2026-04-30 03:42:27 +08:00
parent 359ec30d0c
commit e3e60f914b
11 changed files with 284 additions and 126 deletions

View File

@@ -90,6 +90,10 @@ max-retry-interval: 30
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
disable-cooling: false
# When true, disable the built-in image_generation tool globally.
# The server will stop injecting image_generation and will also remove it from request payload tools arrays.
disable-image-generation: false
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
# When > 0, overrides the default worker count (16).
# auth-auto-refresh-workers: 16

View File

@@ -1013,6 +1013,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
}
if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
}
applySignatureCacheConfig(oldCfg, cfg)
if s.handlers != nil && s.handlers.AuthManager != nil {

View File

@@ -610,6 +610,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false
cfg.DisableImageGeneration = false
cfg.Pprof.Enable = false
cfg.Pprof.Addr = DefaultPprofAddr
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient

View File

@@ -9,6 +9,12 @@ type SDKConfig struct {
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// DisableImageGeneration disables the built-in image_generation tool when true.
// When enabled, the server will avoid injecting image_generation into request payloads,
// will remove any existing image_generation tool entries from tools arrays, and will
// return 404 for /v1/images/generations and /v1/images/edits.
DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"`
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
// Default is false for safety; when false, /v1internal:* requests are rejected.
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`

View File

@@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)

View File

@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if cfg == nil || len(payload) == 0 {
return payload
}
rules := cfg.Payload
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
return payload
}
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return payload
}
candidates := payloadModelCandidates(model, requestedModel)
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
rules := cfg.Payload
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
if hasPayloadRules {
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model != "" || requestedModel != "" {
candidates := payloadModelCandidates(model, requestedModel)
source := original
if len(source) == 0 {
source = payload
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
if cfg.DisableImageGeneration {
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
}
return out
}
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
return r + "." + p
}
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
if len(payload) == 0 {
return payload
}
toolType = strings.TrimSpace(toolType)
if toolType == "" {
return payload
}
toolsPath := buildPayloadPath(root, "tools")
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
}
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
tools := gjson.GetBytes(payload, toolsPath)
if !tools.Exists() || !tools.IsArray() {
return payload
}
removed := false
filtered := []byte(`[]`)
for _, tool := range tools.Array() {
if tool.Get("type").String() == toolType {
removed = true
continue
}
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
if errSet != nil {
continue
}
filtered = updated
}
if !removed {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
if errSet != nil {
return payload
}
return updated
}
func payloadRawValue(value any) ([]byte, bool) {
if value == nil {
return nil, false

View File

@@ -0,0 +1,50 @@
package helps
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/tidwall/gjson"
)
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
cfg := &config.Config{
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
}
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
tools := gjson.GetBytes(out, "tools")
if !tools.Exists() || !tools.IsArray() {
t.Fatalf("expected tools array, got %v", tools.Type)
}
arr := tools.Array()
if len(arr) != 1 {
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
}
if got := arr[0].Get("type").String(); got != "function" {
t.Fatalf("expected remaining tool type=function, got %q", got)
}
}
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
cfg := &config.Config{
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
}
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
tools := gjson.GetBytes(out, "request.tools")
if !tools.Exists() || !tools.IsArray() {
t.Fatalf("expected request.tools array, got %v", tools.Type)
}
arr := tools.Array()
if len(arr) != 1 {
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
}
if got := arr[0].Get("type").String(); got != "web_search" {
t.Fatalf("expected remaining tool type=web_search, got %q", got)
}
}

View File

@@ -42,6 +42,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.DisableCooling != newCfg.DisableCooling {
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
}
if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration {
changes = append(changes, fmt.Sprintf("disable-image-generation: %t -> %t", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration))
}
if oldCfg.RequestLog != newCfg.RequestLog {
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
}

View File

@@ -279,6 +279,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
APIKeys: []string{" key-1 ", "key-2"},
ForceModelPrefix: true,
NonStreamKeepAliveInterval: 5,
DisableImageGeneration: true,
},
}
@@ -287,6 +288,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "logging-to-file: false -> true")
expectContains(t, details, "usage-statistics-enabled: false -> true")
expectContains(t, details, "disable-cooling: false -> true")
expectContains(t, details, "disable-image-generation: false -> true")
expectContains(t, details, "request-log: false -> true")
expectContains(t, details, "request-retry: 1 -> 2")
expectContains(t, details, "max-retry-credentials: 1 -> 3")
@@ -403,9 +405,10 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
SecretKey: "",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: true,
ProxyURL: "http://new-proxy",
APIKeys: []string{"keyB"},
RequestLog: true,
ProxyURL: "http://new-proxy",
APIKeys: []string{"keyB"},
DisableImageGeneration: true,
},
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
OpenAICompatibility: []config.OpenAICompatibility{
@@ -431,6 +434,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "logging-to-file: false -> true")
expectContains(t, changes, "usage-statistics-enabled: false -> true")
expectContains(t, changes, "disable-cooling: false -> true")
expectContains(t, changes, "disable-image-generation: false -> true")
expectContains(t, changes, "request-retry: 1 -> 2")
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
expectContains(t, changes, "max-retry-interval: 1 -> 3")

View File

@@ -198,6 +198,11 @@ func parseBoolField(raw string, fallback bool) bool {
}
func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
c.AbortWithStatus(http.StatusNotFound)
return
}
rawJSON, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
@@ -281,6 +286,11 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
}
func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) {
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
c.AbortWithStatus(http.StatusNotFound)
return
}
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
if strings.HasPrefix(contentType, "application/json") {
h.imagesEditsFromJSON(c)

View File

@@ -10,6 +10,8 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
@@ -93,3 +95,27 @@ func TestImagesEditsMultipartRejectsUnsupportedModel(t *testing.T) {
assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini")
}
func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) {
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
handler := NewOpenAIAPIHandler(base)
body := strings.NewReader(`{"prompt":"draw a square"}`)
resp := performImagesEndpointRequest(t, imagesGenerationsPath, "application/json", body, handler.ImagesGenerations)
if resp.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
}
}
func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) {
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
handler := NewOpenAIAPIHandler(base)
body := strings.NewReader(`{"prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`)
resp := performImagesEndpointRequest(t, imagesEditsPath, "application/json", body, handler.ImagesEdits)
if resp.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
}
}