diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go index 64b41232f..081547c0f 100644 --- a/sdk/api/handlers/openai/openai_images_handlers.go +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -24,6 +24,8 @@ import ( const ( defaultImagesMainModel = "gpt-5.4-mini" defaultImagesToolModel = "gpt-image-2" + imagesGenerationsPath = "/v1/images/generations" + imagesEditsPath = "/v1/images/edits" ) type imageCallResult struct { @@ -99,6 +101,28 @@ func (a *sseFrameAccumulator) Flush() [][]byte { return frames } +func isSupportedImagesModel(model string) bool { + baseModel := strings.TrimSpace(model) + if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 { + baseModel = strings.TrimSpace(baseModel[idx+1:]) + } + return baseModel == defaultImagesToolModel +} + +func rejectUnsupportedImagesModel(c *gin.Context, model string) bool { + if isSupportedImagesModel(model) { + return false + } + + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel), + Type: "invalid_request_error", + }, + }) + return true +} + func mimeTypeFromOutputFormat(outputFormat string) string { if outputFormat == "" { return "image/png" @@ -194,6 +218,14 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { return } + imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + if rejectUnsupportedImagesModel(c, imageModel) { + return + } + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) if prompt == "" { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -205,10 +237,6 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { return } - imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) - if imageModel == "" { - imageModel = defaultImagesToolModel - } responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) if responseFormat == "" { responseFormat = "b64_json" @@ -283,6 +311,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { return } + imageModel := strings.TrimSpace(c.PostForm("model")) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + if rejectUnsupportedImagesModel(c, imageModel) { + return + } + prompt := strings.TrimSpace(c.PostForm("prompt")) if prompt == "" { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -340,10 +376,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { maskDataURL = &dataURL } - imageModel := strings.TrimSpace(c.PostForm("model")) - if imageModel == "" { - imageModel = defaultImagesToolModel - } responseFormat := strings.TrimSpace(c.PostForm("response_format")) if responseFormat == "" { responseFormat = "b64_json" @@ -412,6 +444,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { return } + imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + if rejectUnsupportedImagesModel(c, imageModel) { + return + } + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) if prompt == "" { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -460,10 +500,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { return } - imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) - if imageModel == "" { - imageModel = defaultImagesToolModel - } responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) if responseFormat == "" { responseFormat = "b64_json" diff --git a/sdk/api/handlers/openai/openai_images_handlers_test.go b/sdk/api/handlers/openai/openai_images_handlers_test.go new file mode 100644 index 000000000..679bec6a2 --- /dev/null +++ b/sdk/api/handlers/openai/openai_images_handlers_test.go @@ -0,0 +1,95 @@ +package openai + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +func performImagesEndpointRequest(t *testing.T, endpointPath string, contentType string, body io.Reader, handler gin.HandlerFunc) *httptest.ResponseRecorder { + t.Helper() + + gin.SetMode(gin.TestMode) + router := gin.New() + router.POST(endpointPath, handler) + + req := httptest.NewRequest(http.MethodPost, endpointPath, body) + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + return resp +} + +func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseRecorder, model string) { + t.Helper() + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + + message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String() + expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + "." + if message != expectedMessage { + t.Fatalf("error message = %q, want %q", message, expectedMessage) + } + if errorType := gjson.GetBytes(resp.Body.Bytes(), "error.type").String(); errorType != "invalid_request_error" { + t.Fatalf("error type = %q, want invalid_request_error", errorType) + } +} + +func TestImagesModelValidationAllowsGPTImage2WithOptionalPrefix(t *testing.T) { + for _, model := range []string{"gpt-image-2", "codex/gpt-image-2"} { + if !isSupportedImagesModel(model) { + t.Fatalf("expected %s to be supported", model) + } + } + if isSupportedImagesModel("gpt-5.4-mini") { + t.Fatal("expected gpt-5.4-mini to be rejected") + } +} + +func TestImagesGenerationsRejectsUnsupportedModel(t *testing.T) { + handler := &OpenAIAPIHandler{} + body := strings.NewReader(`{"model":"gpt-5.4-mini","prompt":"draw a square"}`) + + resp := performImagesEndpointRequest(t, imagesGenerationsPath, "application/json", body, handler.ImagesGenerations) + + assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini") +} + +func TestImagesEditsJSONRejectsUnsupportedModel(t *testing.T) { + handler := &OpenAIAPIHandler{} + body := strings.NewReader(`{"model":"gpt-5.4-mini","prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`) + + resp := performImagesEndpointRequest(t, imagesEditsPath, "application/json", body, handler.ImagesEdits) + + assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini") +} + +func TestImagesEditsMultipartRejectsUnsupportedModel(t *testing.T) { + handler := &OpenAIAPIHandler{} + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if err := writer.WriteField("model", "gpt-5.4-mini"); err != nil { + t.Fatalf("write model field: %v", err) + } + if err := writer.WriteField("prompt", "edit this"); err != nil { + t.Fatalf("write prompt field: %v", err) + } + if errClose := writer.Close(); errClose != nil { + t.Fatalf("close multipart writer: %v", errClose) + } + + resp := performImagesEndpointRequest(t, imagesEditsPath, writer.FormDataContentType(), &body, handler.ImagesEdits) + + assertUnsupportedImagesModelResponse(t, resp, "gpt-5.4-mini") +}