mirror of
https://mirror.skon.top/github.com/router-for-me/CLIProxyAPI
synced 2026-04-30 16:20:23 +08:00
feat: support disabling image generation globally
- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally. - Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled. - Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled. - Enhanced configuration diff logging to track changes for the `disable-image-generation` flag. - Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
@@ -90,6 +90,10 @@ max-retry-interval: 30
|
||||
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||
disable-cooling: false
|
||||
|
||||
# When true, disable the built-in image_generation tool globally.
|
||||
# The server will stop injecting image_generation and will also remove it from request payload tools arrays.
|
||||
disable-image-generation: false
|
||||
|
||||
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
||||
# When > 0, overrides the default worker count (16).
|
||||
# auth-auto-refresh-workers: 16
|
||||
|
||||
@@ -1013,6 +1013,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
|
||||
log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
|
||||
}
|
||||
|
||||
applySignatureCacheConfig(oldCfg, cfg)
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
|
||||
@@ -610,6 +610,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.DisableImageGeneration = false
|
||||
cfg.Pprof.Enable = false
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
|
||||
@@ -9,6 +9,12 @@ type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// DisableImageGeneration disables the built-in image_generation tool when true.
|
||||
// When enabled, the server will avoid injecting image_generation into request payloads,
|
||||
// will remove any existing image_generation tool entries from tools arrays, and will
|
||||
// return 404 for /v1/images/generations and /v1/images/edits.
|
||||
DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"`
|
||||
|
||||
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
|
||||
|
||||
@@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
|
||||
@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
rules := cfg.Payload
|
||||
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
|
||||
rules := cfg.Payload
|
||||
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
|
||||
if hasPayloadRules {
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model != "" || requestedModel != "" {
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
|
||||
if cfg.DisableImageGeneration {
|
||||
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
|
||||
return r + "." + p
|
||||
}
|
||||
|
||||
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
||||
if len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
toolType = strings.TrimSpace(toolType)
|
||||
if toolType == "" {
|
||||
return payload
|
||||
}
|
||||
toolsPath := buildPayloadPath(root, "tools")
|
||||
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
|
||||
}
|
||||
|
||||
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
|
||||
tools := gjson.GetBytes(payload, toolsPath)
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return payload
|
||||
}
|
||||
removed := false
|
||||
filtered := []byte(`[]`)
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() == toolType {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
filtered = updated
|
||||
}
|
||||
if !removed {
|
||||
return payload
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
|
||||
if errSet != nil {
|
||||
return payload
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func payloadRawValue(value any) ([]byte, bool) {
|
||||
if value == nil {
|
||||
return nil, false
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "function" {
|
||||
t.Fatalf("expected remaining tool type=function, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "request.tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected request.tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "web_search" {
|
||||
t.Fatalf("expected remaining tool type=web_search, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration {
|
||||
changes = append(changes, fmt.Sprintf("disable-image-generation: %t -> %t", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
|
||||
@@ -279,6 +279,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
NonStreamKeepAliveInterval: 5,
|
||||
DisableImageGeneration: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -287,6 +288,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "disable-image-generation: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-credentials: 1 -> 3")
|
||||
@@ -403,9 +405,10 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
DisableImageGeneration: true,
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
@@ -431,6 +434,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "disable-image-generation: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
|
||||
@@ -198,6 +198,11 @@ func parseBoolField(raw string, fallback bool) bool {
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
rawJSON, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -281,6 +286,11 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) {
|
||||
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
h.imagesEditsFromJSON(c)
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -93,3 +95,27 @@ func TestImagesEditsMultipartRejectsUnsupportedModel(t *testing.T) {
|
||||
|
||||
assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini")
|
||||
}
|
||||
|
||||
func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) {
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
|
||||
handler := NewOpenAIAPIHandler(base)
|
||||
body := strings.NewReader(`{"prompt":"draw a square"}`)
|
||||
|
||||
resp := performImagesEndpointRequest(t, imagesGenerationsPath, "application/json", body, handler.ImagesGenerations)
|
||||
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) {
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
|
||||
handler := NewOpenAIAPIHandler(base)
|
||||
body := strings.NewReader(`{"prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`)
|
||||
|
||||
resp := performImagesEndpointRequest(t, imagesEditsPath, "application/json", body, handler.ImagesEdits)
|
||||
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user