From 719ec4f2927a1be91564d15386cbed225d2d12e6 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Tue, 28 Apr 2026 03:54:44 +0100 Subject: [PATCH] refactor: share OpenAI-compatible image provider --- .../.generated/plugin-sdk-api-baseline.sha256 | 4 +- docs/plugins/sdk-migration.md | 2 +- docs/plugins/sdk-subpaths.md | 2 +- .../image-generation-provider.test.ts | 8 + .../deepinfra/image-generation-provider.ts | 185 +++--------- .../litellm/image-generation-provider.test.ts | 11 + .../litellm/image-generation-provider.ts | 141 +++------ .../openai/image-generation-provider.ts | 34 +-- .../xai/image-generation-provider.test.ts | 53 ++++ extensions/xai/image-generation-provider.ts | 146 +++------- .../openai-compatible-image-provider.test.ts | 251 ++++++++++++++++ .../openai-compatible-image-provider.ts | 274 ++++++++++++++++++ src/plugin-sdk/image-generation.ts | 8 + 13 files changed, 741 insertions(+), 378 deletions(-) create mode 100644 src/image-generation/openai-compatible-image-provider.test.ts create mode 100644 src/image-generation/openai-compatible-image-provider.ts diff --git a/docs/.generated/plugin-sdk-api-baseline.sha256 b/docs/.generated/plugin-sdk-api-baseline.sha256 index 34b1cc96718..77c81248ed9 100644 --- a/docs/.generated/plugin-sdk-api-baseline.sha256 +++ b/docs/.generated/plugin-sdk-api-baseline.sha256 @@ -1,2 +1,2 @@ -343a555f212dd5ebf26dccbefff1cb4b56a08e4dcc2c801ac7ab5fb98973192a plugin-sdk-api-baseline.json -02aaccbe13f261de2d41fcb4270fc9ae70b931966089e56d21a4ebc8e80c8821 plugin-sdk-api-baseline.jsonl +5a77a53e9d48f4b683838a3993e583db8037edf836e8e13a209cc0ad1c89d809 plugin-sdk-api-baseline.json +d31c5887a379c48a714e91d00a7d7fde7dff353319aee38a98babfdb753f49fb plugin-sdk-api-baseline.jsonl diff --git a/docs/plugins/sdk-migration.md b/docs/plugins/sdk-migration.md index db61c32c6fa..de64c2dd8e2 100644 --- a/docs/plugins/sdk-migration.md +++ b/docs/plugins/sdk-migration.md @@ -486,7 +486,7 @@ releases. | `plugin-sdk/speech-core` | Shared speech core | Speech provider types, registry, directives, normalization | | `plugin-sdk/realtime-transcription` | Realtime transcription helpers | Provider types, registry helpers, and shared WebSocket session helper | | `plugin-sdk/realtime-voice` | Realtime voice helpers | Provider types, registry/resolution helpers, and bridge session helpers | - | `plugin-sdk/image-generation` | Image-generation helpers | Image generation provider types plus image asset/data URL helpers | + | `plugin-sdk/image-generation` | Image-generation helpers | Image generation provider types plus image asset/data URL helpers and the OpenAI-compatible image provider builder | | `plugin-sdk/image-generation-core` | Shared image-generation core | Image-generation types, failover, auth, and registry helpers | | `plugin-sdk/music-generation` | Music-generation helpers | Music-generation provider/request/result types | | `plugin-sdk/music-generation-core` | Shared music-generation core | Music-generation types, failover helpers, provider lookup, and model-ref parsing | diff --git a/docs/plugins/sdk-subpaths.md b/docs/plugins/sdk-subpaths.md index f7b93493164..e06062d54e5 100644 --- a/docs/plugins/sdk-subpaths.md +++ b/docs/plugins/sdk-subpaths.md @@ -262,7 +262,7 @@ For the plugin authoring guide, see [Plugin SDK overview](/plugins/sdk-overview) | `plugin-sdk/speech-core` | Shared speech provider types, registry, directive, normalization, and speech helper exports | | `plugin-sdk/realtime-transcription` | Realtime transcription provider types, registry helpers, and shared WebSocket session helper | | `plugin-sdk/realtime-voice` | Realtime voice provider types and registry helpers | - | `plugin-sdk/image-generation` | Image generation provider types plus image asset/data URL helpers | + | `plugin-sdk/image-generation` | Image generation provider types plus image asset/data URL helpers and the OpenAI-compatible image provider builder | | `plugin-sdk/image-generation-core` | Shared image-generation types, failover, auth, and registry helpers | | `plugin-sdk/music-generation` | Music generation provider/request/result types | | `plugin-sdk/music-generation-core` | Shared music-generation types, failover helpers, provider lookup, and model-ref parsing | diff --git a/extensions/deepinfra/image-generation-provider.test.ts b/extensions/deepinfra/image-generation-provider.test.ts index d2e08d2236d..1cd108eacbc 100644 --- a/extensions/deepinfra/image-generation-provider.test.ts +++ b/extensions/deepinfra/image-generation-provider.test.ts @@ -7,11 +7,17 @@ const { postMultipartRequestMock, resolveApiKeyForProviderMock, resolveProviderHttpRequestConfigMock, + createProviderOperationDeadlineMock, + resolveProviderOperationTimeoutMsMock, } = vi.hoisted(() => ({ assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), postJsonRequestMock: vi.fn(), postMultipartRequestMock: vi.fn(), resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "deepinfra-key" })), + createProviderOperationDeadlineMock: vi.fn((params: Record) => params), + resolveProviderOperationTimeoutMsMock: vi.fn( + (params: Record) => params.defaultTimeoutMs, + ), resolveProviderHttpRequestConfigMock: vi.fn((params: Record) => ({ baseUrl: params.baseUrl ?? params.defaultBaseUrl ?? "https://api.deepinfra.com/v1/openai", allowPrivateNetwork: false, @@ -26,9 +32,11 @@ vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, + createProviderOperationDeadline: createProviderOperationDeadlineMock, postJsonRequest: postJsonRequestMock, postMultipartRequest: postMultipartRequestMock, resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, + resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock, sanitizeConfiguredModelProviderRequest: vi.fn((request) => request), })); diff --git a/extensions/deepinfra/image-generation-provider.ts b/extensions/deepinfra/image-generation-provider.ts index 7ac4de1c9de..c4caa0ecf49 100644 --- a/extensions/deepinfra/image-generation-provider.ts +++ b/extensions/deepinfra/image-generation-provider.ts @@ -1,18 +1,8 @@ -import type { OpenClawConfig } from "openclaw/plugin-sdk/config-types"; -import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation"; import { + createOpenAiCompatibleImageGenerationProvider, imageSourceUploadFileName, - parseOpenAiCompatibleImageResponse, + type ImageGenerationProvider, } from "openclaw/plugin-sdk/image-generation"; -import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; -import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; -import { - assertOkOrThrowHttpError, - postJsonRequest, - postMultipartRequest, - resolveProviderHttpRequestConfig, - sanitizeConfiguredModelProviderRequest, -} from "openclaw/plugin-sdk/provider-http"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; import { DEEPINFRA_BASE_URL, @@ -26,35 +16,12 @@ import { const DEEPINFRA_IMAGE_SIZES = ["512x512", "1024x1024", "1024x1792", "1792x1024"] as const; const MAX_DEEPINFRA_INPUT_IMAGES = 1; -type DeepInfraProviderConfig = NonNullable< - NonNullable["providers"] ->[string]; - -type DeepInfraImageApiResponse = { - data?: Array<{ - b64_json?: string; - revised_prompt?: string; - url?: string; - }>; -}; - -function resolveDeepInfraProviderConfig( - cfg: OpenClawConfig | undefined, -): DeepInfraProviderConfig | undefined { - return cfg?.models?.providers?.deepinfra; -} - export function buildDeepInfraImageGenerationProvider(): ImageGenerationProvider { - return { + return createOpenAiCompatibleImageGenerationProvider({ id: "deepinfra", label: "DeepInfra", defaultModel: DEFAULT_DEEPINFRA_IMAGE_MODEL, models: [...DEEPINFRA_IMAGE_MODELS], - isConfigured: ({ agentDir }) => - isProviderApiKeyConfigured({ - provider: "deepinfra", - agentDir, - }), capabilities: { generate: { maxCount: 4, @@ -74,111 +41,49 @@ export function buildDeepInfraImageGenerationProvider(): ImageGenerationProvider sizes: [...DEEPINFRA_IMAGE_SIZES], }, }, - async generateImage(req) { - const inputImages = req.inputImages ?? []; - const isEdit = inputImages.length > 0; - if (inputImages.length > MAX_DEEPINFRA_INPUT_IMAGES) { - throw new Error("DeepInfra image editing supports one reference image."); + defaultBaseUrl: DEEPINFRA_BASE_URL, + normalizeModel: normalizeDeepInfraModelRef, + resolveBaseUrl: ({ providerConfig }) => + normalizeDeepInfraBaseUrl(providerConfig?.baseUrl, DEEPINFRA_BASE_URL), + resolveAllowPrivateNetwork: () => false, + useConfiguredRequest: true, + resolveCount: ({ req, mode }) => (mode === "edit" ? 1 : (req.count ?? 1)), + buildGenerateRequest: ({ req, model, count }) => ({ + kind: "json", + body: { + model, + prompt: req.prompt, + n: count, + size: normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE, + response_format: "b64_json", + }, + }), + buildEditRequest: ({ req, inputImages, model, count }) => { + const image = inputImages[0]; + if (!image) { + throw new Error("DeepInfra image edit missing reference image."); } - const auth = await resolveApiKeyForProvider({ - provider: "deepinfra", - cfg: req.cfg, - agentDir: req.agentDir, - store: req.authStore, - }); - if (!auth.apiKey) { - throw new Error("DeepInfra API key missing"); - } - - const providerConfig = resolveDeepInfraProviderConfig(req.cfg); - const resolvedBaseUrl = normalizeDeepInfraBaseUrl( - providerConfig?.baseUrl, - DEEPINFRA_BASE_URL, + const form = new FormData(); + form.set("model", model); + form.set("prompt", req.prompt); + form.set("n", String(count)); + form.set("size", normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE); + form.set("response_format", "b64_json"); + const mimeType = normalizeOptionalString(image.mimeType) ?? "image/png"; + form.append( + "image", + new Blob([new Uint8Array(image.buffer)], { type: mimeType }), + imageSourceUploadFileName({ image, index: 0 }), ); - const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } = - resolveProviderHttpRequestConfig({ - baseUrl: resolvedBaseUrl, - defaultBaseUrl: DEEPINFRA_BASE_URL, - allowPrivateNetwork: false, - request: sanitizeConfiguredModelProviderRequest(providerConfig?.request), - defaultHeaders: { - Authorization: `Bearer ${auth.apiKey}`, - }, - provider: "deepinfra", - capability: "image", - transport: "http", - }); - - const model = normalizeDeepInfraModelRef(req.model, DEFAULT_DEEPINFRA_IMAGE_MODEL); - const count = isEdit ? 1 : (req.count ?? 1); - const size = normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE; - const endpoint = isEdit ? "images/edits" : "images/generations"; - const request = isEdit - ? (() => { - const form = new FormData(); - form.set("model", model); - form.set("prompt", req.prompt); - form.set("n", String(count)); - form.set("size", size); - form.set("response_format", "b64_json"); - const image = inputImages[0]; - if (!image) { - throw new Error("DeepInfra image edit missing reference image."); - } - const mimeType = normalizeOptionalString(image.mimeType) ?? "image/png"; - form.append( - "image", - new Blob([new Uint8Array(image.buffer)], { type: mimeType }), - imageSourceUploadFileName({ image, index: 0 }), - ); - const multipartHeaders = new Headers(headers); - multipartHeaders.delete("Content-Type"); - return postMultipartRequest({ - url: `${baseUrl}/${endpoint}`, - headers: multipartHeaders, - body: form, - timeoutMs: req.timeoutMs, - fetchFn: fetch, - allowPrivateNetwork, - dispatcherPolicy, - }); - })() - : postJsonRequest({ - url: `${baseUrl}/${endpoint}`, - headers: new Headers({ - ...Object.fromEntries(headers.entries()), - "Content-Type": "application/json", - }), - body: { - model, - prompt: req.prompt, - n: count, - size, - response_format: "b64_json", - }, - timeoutMs: req.timeoutMs, - fetchFn: fetch, - allowPrivateNetwork, - dispatcherPolicy, - }); - - const { response, release } = await request; - try { - await assertOkOrThrowHttpError( - response, - isEdit ? "DeepInfra image edit failed" : "DeepInfra image generation failed", - ); - const images = parseOpenAiCompatibleImageResponse( - (await response.json()) as DeepInfraImageApiResponse, - { defaultMimeType: "image/jpeg", sniffMimeType: true }, - ); - if (images.length === 0) { - throw new Error("DeepInfra image response did not include generated image data"); - } - return { images, model }; - } finally { - await release(); - } + return { kind: "multipart", form }; }, - }; + response: { defaultMimeType: "image/jpeg", sniffMimeType: true }, + tooManyInputImagesError: "DeepInfra image editing supports one reference image.", + missingApiKeyError: "DeepInfra API key missing", + emptyResponseError: "DeepInfra image response did not include generated image data", + failureLabels: { + generate: "DeepInfra image generation failed", + edit: "DeepInfra image edit failed", + }, + }); } diff --git a/extensions/litellm/image-generation-provider.test.ts b/extensions/litellm/image-generation-provider.test.ts index 5cd835de33e..8ebb3ff18fe 100644 --- a/extensions/litellm/image-generation-provider.test.ts +++ b/extensions/litellm/image-generation-provider.test.ts @@ -4,19 +4,27 @@ import { buildLitellmImageGenerationProvider } from "./image-generation-provider const { resolveApiKeyForProviderMock, postJsonRequestMock, + postMultipartRequestMock, assertOkOrThrowHttpErrorMock, + createProviderOperationDeadlineMock, resolveProviderHttpRequestConfigMock, + resolveProviderOperationTimeoutMsMock, sanitizeConfiguredModelProviderRequestMock, } = vi.hoisted(() => ({ resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "litellm-key" })), postJsonRequestMock: vi.fn(), + postMultipartRequestMock: vi.fn(), assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + createProviderOperationDeadlineMock: vi.fn((params: Record) => params), resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ baseUrl: params.baseUrl ?? params.defaultBaseUrl, allowPrivateNetwork: Boolean(params.allowPrivateNetwork ?? params.request?.allowPrivateNetwork), headers: new Headers(params.defaultHeaders), dispatcherPolicy: undefined as unknown, })), + resolveProviderOperationTimeoutMsMock: vi.fn( + (params: Record) => params.defaultTimeoutMs, + ), sanitizeConfiguredModelProviderRequestMock: vi.fn((request) => request), })); @@ -26,8 +34,11 @@ vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, + createProviderOperationDeadline: createProviderOperationDeadlineMock, postJsonRequest: postJsonRequestMock, + postMultipartRequest: postMultipartRequestMock, resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, + resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock, sanitizeConfiguredModelProviderRequest: sanitizeConfiguredModelProviderRequestMock, })); diff --git a/extensions/litellm/image-generation-provider.ts b/extensions/litellm/image-generation-provider.ts index c12685799b8..e60b7ff7921 100644 --- a/extensions/litellm/image-generation-provider.ts +++ b/extensions/litellm/image-generation-provider.ts @@ -1,17 +1,10 @@ import type { OpenClawConfig } from "openclaw/plugin-sdk/config-types"; -import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation"; import { - parseOpenAiCompatibleImageResponse, + createOpenAiCompatibleImageGenerationProvider, + type ImageGenerationProvider, + type ImageGenerationSourceImage, toImageDataUrl, } from "openclaw/plugin-sdk/image-generation"; -import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; -import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; -import { - assertOkOrThrowHttpError, - postJsonRequest, - resolveProviderHttpRequestConfig, - sanitizeConfiguredModelProviderRequest, -} from "openclaw/plugin-sdk/provider-http"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; import { LITELLM_BASE_URL } from "./onboard.js"; @@ -46,6 +39,10 @@ function resolveConfiguredLitellmBaseUrl(cfg: OpenClawConfig | undefined): strin return normalizeOptionalString(resolveLitellmProviderConfig(cfg)?.baseUrl) ?? LITELLM_BASE_URL; } +function imageToDataUrl(image: ImageGenerationSourceImage): string { + return toImageDataUrl({ buffer: image.buffer, mimeType: image.mimeType }); +} + // LiteLLM's default proxy is loopback. Auto-enable private-network access only // for loopback-style hosts; LAN/custom private endpoints should use the // explicit models.providers.litellm.request.allowPrivateNetwork opt-in. @@ -85,24 +82,12 @@ function shouldAutoAllowPrivateLitellmEndpoint(baseUrl: string): boolean { } } -type LitellmImageApiResponse = { - data?: Array<{ - b64_json?: string; - revised_prompt?: string; - }>; -}; - export function buildLitellmImageGenerationProvider(): ImageGenerationProvider { - return { + return createOpenAiCompatibleImageGenerationProvider({ id: "litellm", label: "LiteLLM", defaultModel: DEFAULT_LITELLM_IMAGE_MODEL, models: [DEFAULT_LITELLM_IMAGE_MODEL], - isConfigured: ({ agentDir }) => - isProviderApiKeyConfigured({ - provider: "litellm", - agentDir, - }), capabilities: { generate: { maxCount: 4, @@ -122,84 +107,36 @@ export function buildLitellmImageGenerationProvider(): ImageGenerationProvider { sizes: [...LITELLM_SUPPORTED_SIZES], }, }, - async generateImage(req) { - const inputImages = req.inputImages ?? []; - const isEdit = inputImages.length > 0; - const auth = await resolveApiKeyForProvider({ - provider: "litellm", - cfg: req.cfg, - agentDir: req.agentDir, - store: req.authStore, - }); - if (!auth.apiKey) { - throw new Error("LiteLLM API key missing"); - } - const providerConfig = resolveLitellmProviderConfig(req.cfg); - const resolvedBaseUrl = resolveConfiguredLitellmBaseUrl(req.cfg); - const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } = - resolveProviderHttpRequestConfig({ - baseUrl: resolvedBaseUrl, - defaultBaseUrl: LITELLM_BASE_URL, - allowPrivateNetwork: shouldAutoAllowPrivateLitellmEndpoint(resolvedBaseUrl) - ? true - : undefined, - request: sanitizeConfiguredModelProviderRequest(providerConfig?.request), - defaultHeaders: { - Authorization: `Bearer ${auth.apiKey}`, - }, - provider: "litellm", - capability: "image", - transport: "http", - }); - - const model = req.model || DEFAULT_LITELLM_IMAGE_MODEL; - const count = req.count ?? 1; - const size = req.size ?? DEFAULT_SIZE; - - const jsonHeaders = new Headers(headers); - jsonHeaders.set("Content-Type", "application/json"); - const endpoint = isEdit ? "images/edits" : "images/generations"; - const body = isEdit - ? { - model, - prompt: req.prompt, - n: count, - size, - images: inputImages.map((image) => ({ - image_url: toImageDataUrl(image), - })), - } - : { - model, - prompt: req.prompt, - n: count, - size, - }; - const { response, release } = await postJsonRequest({ - url: `${baseUrl}/${endpoint}`, - headers: jsonHeaders, - body, - timeoutMs: req.timeoutMs, - fetchFn: fetch, - allowPrivateNetwork, - dispatcherPolicy, - }); - try { - await assertOkOrThrowHttpError( - response, - isEdit ? "LiteLLM image edit failed" : "LiteLLM image generation failed", - ); - - const data = (await response.json()) as LitellmImageApiResponse; - const images = parseOpenAiCompatibleImageResponse(data); - - return { - images, - model, - }; - } finally { - await release(); - } + defaultBaseUrl: LITELLM_BASE_URL, + resolveBaseUrl: ({ req }) => resolveConfiguredLitellmBaseUrl(req.cfg), + resolveAllowPrivateNetwork: ({ baseUrl }) => + shouldAutoAllowPrivateLitellmEndpoint(baseUrl) ? true : undefined, + useConfiguredRequest: true, + buildGenerateRequest: ({ req, model, count }) => ({ + kind: "json", + body: { + model, + prompt: req.prompt, + n: count, + size: req.size ?? DEFAULT_SIZE, + }, + }), + buildEditRequest: ({ req, inputImages, model, count }) => ({ + kind: "json", + body: { + model, + prompt: req.prompt, + n: count, + size: req.size ?? DEFAULT_SIZE, + images: inputImages.map((image) => ({ + image_url: imageToDataUrl(image), + })), + }, + }), + missingApiKeyError: "LiteLLM API key missing", + failureLabels: { + generate: "LiteLLM image generation failed", + edit: "LiteLLM image edit failed", }, - }; + }); } diff --git a/extensions/openai/image-generation-provider.ts b/extensions/openai/image-generation-provider.ts index 02cd072c0db..690cd67881d 100644 --- a/extensions/openai/image-generation-provider.ts +++ b/extensions/openai/image-generation-provider.ts @@ -4,7 +4,10 @@ import type { ImageGenerationOutputFormat, ImageGenerationProvider, ImageGenerationResult, - ImageGenerationSourceImage, +} from "openclaw/plugin-sdk/image-generation"; +import { + parseOpenAiCompatibleImageResponse, + toImageDataUrl, } from "openclaw/plugin-sdk/image-generation"; import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core"; import { resolveClosestSize } from "openclaw/plugin-sdk/media-generation-runtime"; @@ -388,11 +391,6 @@ function inferImageUploadFileName(params: { return `image-${params.index + 1}.${ext}`; } -function toOpenAIDataUrl(image: ImageGenerationSourceImage): string { - const mimeType = image.mimeType?.trim() || DEFAULT_OUTPUT_MIME; - return `data:${mimeType};base64,${Buffer.from(image.buffer).toString("base64")}`; -} - async function readResponseBodyText(response: Response): Promise { if (!response.body) { const text = await response.text(); @@ -643,7 +641,7 @@ async function generateOpenAICodexImage(params: { { type: "input_text", text: req.prompt }, ...inputImages.map((image) => ({ type: "input_image", - image_url: toOpenAIDataUrl(image), + image_url: toImageDataUrl({ buffer: image.buffer, mimeType: image.mimeType }), detail: "auto", })), ]; @@ -876,21 +874,13 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProvider { const data = (await response.json()) as OpenAIImageApiResponse; const output = resolveOutputMime(req.outputFormat); - const images = (data.data ?? []) - .map((entry, index) => { - if (!entry.b64_json) { - return null; - } - return Object.assign( - { - buffer: Buffer.from(entry.b64_json, `base64`), - mimeType: output.mimeType, - fileName: `image-${index + 1}.${output.extension}`, - }, - entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {}, - ); - }) - .filter((entry): entry is NonNullable => entry !== null); + const images = parseOpenAiCompatibleImageResponse(data, { + defaultMimeType: output.mimeType, + }).map((image, index) => + Object.assign(image, { + fileName: `image-${index + 1}.${output.extension}`, + }), + ); return { images, diff --git a/extensions/xai/image-generation-provider.test.ts b/extensions/xai/image-generation-provider.test.ts index e1dfb5f0878..a574f3b4589 100644 --- a/extensions/xai/image-generation-provider.test.ts +++ b/extensions/xai/image-generation-provider.test.ts @@ -4,13 +4,16 @@ import { buildXaiImageGenerationProvider } from "./image-generation-provider.js" const { resolveApiKeyForProviderMock, postJsonRequestMock, + postMultipartRequestMock, assertOkOrThrowHttpErrorMock, resolveProviderHttpRequestConfigMock, createProviderOperationDeadlineMock, resolveProviderOperationTimeoutMsMock, + sanitizeConfiguredModelProviderRequestMock, } = vi.hoisted(() => ({ resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "xai-key" })), postJsonRequestMock: vi.fn(), + postMultipartRequestMock: vi.fn(), assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), resolveProviderHttpRequestConfigMock: vi.fn((params: Record) => ({ baseUrl: params.baseUrl ?? params.defaultBaseUrl ?? "https://api.x.ai/v1", @@ -25,6 +28,7 @@ const { resolveProviderOperationTimeoutMsMock: vi.fn( (params: Record) => params.defaultTimeoutMs ?? 60000, ), + sanitizeConfiguredModelProviderRequestMock: vi.fn((request) => request), })); vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ @@ -35,8 +39,10 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, createProviderOperationDeadline: createProviderOperationDeadlineMock, postJsonRequest: postJsonRequestMock, + postMultipartRequest: postMultipartRequestMock, resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock, + sanitizeConfiguredModelProviderRequest: sanitizeConfiguredModelProviderRequestMock, })); vi.mock("openclaw/plugin-sdk/text-runtime", () => ({ @@ -54,6 +60,7 @@ describe("xai image generation provider", () => { resolveProviderHttpRequestConfigMock.mockClear(); createProviderOperationDeadlineMock.mockClear(); resolveProviderOperationTimeoutMsMock.mockClear(); + sanitizeConfiguredModelProviderRequestMock.mockClear(); }); it("builds provider with correct models, default, and capabilities", () => { @@ -174,4 +181,50 @@ describe("xai image generation provider", () => { }), ); }); + + it("uses the plural xAI images payload for multiple edit inputs", async () => { + postJsonRequestMock.mockResolvedValue({ + response: { + json: async () => ({ + data: [ + { + b64_json: Buffer.from("edited").toString("base64"), + mime_type: "image/png", + }, + ], + }), + }, + release: vi.fn(async () => {}), + }); + + const provider = buildXaiImageGenerationProvider(); + await provider.generateImage({ + provider: "xai", + model: "grok-imagine-image", + prompt: "Combine the references", + inputImages: [ + { buffer: Buffer.from("first"), mimeType: "image/png" }, + { buffer: Buffer.from("second"), mimeType: "image/jpeg" }, + ], + cfg: {}, + } as any); + + expect(postJsonRequestMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: expect.stringContaining("/images/edits"), + body: expect.objectContaining({ + images: [ + { + url: expect.stringContaining("data:image/png;base64,"), + type: "image_url", + }, + { + url: expect.stringContaining("data:image/jpeg;base64,"), + type: "image_url", + }, + ], + }), + }), + ); + }); }); diff --git a/extensions/xai/image-generation-provider.ts b/extensions/xai/image-generation-provider.ts index 96969c4e40c..c98ec614083 100644 --- a/extensions/xai/image-generation-provider.ts +++ b/extensions/xai/image-generation-provider.ts @@ -1,21 +1,12 @@ import type { ImageGenerationProvider, ImageGenerationRequest, - ImageGenerationResult, + ImageGenerationSourceImage, } from "openclaw/plugin-sdk/image-generation"; import { - parseOpenAiCompatibleImageResponse, + createOpenAiCompatibleImageGenerationProvider, toImageDataUrl, } from "openclaw/plugin-sdk/image-generation"; -import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; -import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; -import { - assertOkOrThrowHttpError, - createProviderOperationDeadline, - postJsonRequest, - resolveProviderHttpRequestConfig, - resolveProviderOperationTimeoutMs, -} from "openclaw/plugin-sdk/provider-http"; import { normalizeOptionalLowercaseString, normalizeOptionalString, @@ -26,16 +17,8 @@ const DEFAULT_TIMEOUT_MS = 60_000; const XAI_SUPPORTED_ASPECT_RATIOS = ["1:1", "16:9", "9:16", "4:3", "3:4", "2:3", "3:2"] as const; -type XaiImageApiResponse = { - data?: Array<{ - b64_json?: string; - mime_type?: string; - revised_prompt?: string; - }>; -}; - function resolveImageForEdit( - input: { url?: string; buffer?: Buffer; mimeType?: string } | undefined, + input: (ImageGenerationSourceImage & { url?: string }) | undefined, ): string { if (!input) { throw new Error("xAI image edit requires an input image."); @@ -50,44 +33,42 @@ function resolveImageForEdit( return toImageDataUrl({ buffer: input.buffer, mimeType: input.mimeType }); } -function isEdit(req: ImageGenerationRequest): boolean { - return (req.inputImages?.length ?? 0) > 0; -} - function resolveXaiImageBaseUrl(req: ImageGenerationRequest): string { return normalizeOptionalString(req.cfg?.models?.providers?.xai?.baseUrl) ?? XAI_BASE_URL; } -function buildBody(req: ImageGenerationRequest, edit: boolean): Record { - const model = normalizeOptionalString(req.model) ?? XAI_DEFAULT_IMAGE_MODEL; - const count = req.count ?? 1; +function buildBody(params: { + req: ImageGenerationRequest; + inputImages: ImageGenerationSourceImage[]; + model: string; + count: number; +}): Record { const body: Record = { - model, - prompt: req.prompt, - n: Math.min(count, 4), + model: params.model, + prompt: params.req.prompt, + n: Math.min(params.count, 4), response_format: "b64_json" as const, }; - const aspect = normalizeOptionalString(req.aspectRatio); + const aspect = normalizeOptionalString(params.req.aspectRatio); if (aspect && (XAI_SUPPORTED_ASPECT_RATIOS as readonly string[]).includes(aspect)) { body.aspect_ratio = aspect; } - const resolution = normalizeOptionalLowercaseString(req.resolution); + const resolution = normalizeOptionalLowercaseString(params.req.resolution); if (resolution) { body.resolution = resolution; } - if (edit) { - const inputImages = req.inputImages ?? []; - if (inputImages.length > 1) { - body.images = inputImages.map((input) => ({ + if (params.inputImages.length > 0) { + if (params.inputImages.length > 1) { + body.images = params.inputImages.map((input) => ({ url: resolveImageForEdit(input), type: "image_url", })); } else { body.image = { - url: resolveImageForEdit(inputImages[0]), + url: resolveImageForEdit(params.inputImages[0]), type: "image_url", }; } @@ -97,16 +78,11 @@ function buildBody(req: ImageGenerationRequest, edit: boolean): Record - isProviderApiKeyConfigured({ - provider: "xai", - agentDir, - }), capabilities: { generate: { maxCount: 4, @@ -127,72 +103,22 @@ export function buildXaiImageGenerationProvider(): ImageGenerationProvider { resolutions: ["1K", "2K"], }, }, - async generateImage(req: ImageGenerationRequest): Promise { - const edit = isEdit(req); - const auth = await resolveApiKeyForProvider({ - provider: "xai", - cfg: req.cfg, - agentDir: req.agentDir, - store: req.authStore, - }); - if (!auth.apiKey) { - throw new Error("xAI API key missing"); - } - - const fetchFn = fetch; - const deadline = createProviderOperationDeadline({ - timeoutMs: req.timeoutMs, - label: edit ? "xAI image edit" : "xAI image generation", - }); - const { - baseUrl: resolvedBaseUrl, - allowPrivateNetwork, - headers, - dispatcherPolicy, - } = resolveProviderHttpRequestConfig({ - baseUrl: resolveXaiImageBaseUrl(req), - defaultBaseUrl: XAI_BASE_URL, - allowPrivateNetwork: false, - defaultHeaders: { - Authorization: `Bearer ${auth.apiKey}`, - "Content-Type": "application/json", - }, - provider: "xai", - capability: "image", - transport: "http", - }); - - const body = buildBody(req, edit); - const endpoint = edit ? "/images/edits" : "/images/generations"; - const { response, release } = await postJsonRequest({ - url: `${resolvedBaseUrl}${endpoint}`, - headers, - body, - timeoutMs: resolveProviderOperationTimeoutMs({ - deadline, - defaultTimeoutMs: DEFAULT_TIMEOUT_MS, - }), - fetchFn, - allowPrivateNetwork, - dispatcherPolicy, - }); - - try { - await assertOkOrThrowHttpError( - response, - edit ? "xAI image edit failed" : "xAI image generation failed", - ); - - const payload = (await response.json()) as XaiImageApiResponse; - const images = parseOpenAiCompatibleImageResponse(payload); - - return { - images, - model: normalizeOptionalString(req.model) ?? XAI_DEFAULT_IMAGE_MODEL, - }; - } finally { - await release(); - } + defaultBaseUrl: XAI_BASE_URL, + resolveBaseUrl: ({ req }) => resolveXaiImageBaseUrl(req), + resolveAllowPrivateNetwork: () => false, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + buildGenerateRequest: ({ req, inputImages, model, count }) => ({ + kind: "json", + body: buildBody({ req, inputImages, model, count }), + }), + buildEditRequest: ({ req, inputImages, model, count }) => ({ + kind: "json", + body: buildBody({ req, inputImages, model, count }), + }), + missingApiKeyError: "xAI API key missing", + failureLabels: { + generate: "xAI image generation failed", + edit: "xAI image edit failed", }, - }; + }); } diff --git a/src/image-generation/openai-compatible-image-provider.test.ts b/src/image-generation/openai-compatible-image-provider.test.ts new file mode 100644 index 00000000000..90d2a80ba3e --- /dev/null +++ b/src/image-generation/openai-compatible-image-provider.test.ts @@ -0,0 +1,251 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + createOpenAiCompatibleImageGenerationProvider, + type OpenAiCompatibleImageProviderOptions, +} from "./openai-compatible-image-provider.js"; + +const { + assertOkOrThrowHttpErrorMock, + createProviderOperationDeadlineMock, + isProviderApiKeyConfiguredMock, + postJsonRequestMock, + postMultipartRequestMock, + resolveApiKeyForProviderMock, + resolveProviderHttpRequestConfigMock, + resolveProviderOperationTimeoutMsMock, + sanitizeConfiguredModelProviderRequestMock, +} = vi.hoisted(() => ({ + assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + createProviderOperationDeadlineMock: vi.fn((params: Record) => ({ + timeoutMs: params.timeoutMs, + label: params.label, + })), + isProviderApiKeyConfiguredMock: vi.fn(() => true), + postJsonRequestMock: vi.fn(), + postMultipartRequestMock: vi.fn(), + resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })), + resolveProviderHttpRequestConfigMock: vi.fn((params: Record) => { + const request = + typeof params.request === "object" && params.request !== null + ? (params.request as Record) + : undefined; + return { + baseUrl: params.baseUrl ?? params.defaultBaseUrl, + allowPrivateNetwork: Boolean(params.allowPrivateNetwork ?? request?.allowPrivateNetwork), + headers: new Headers(params.defaultHeaders as HeadersInit | undefined), + dispatcherPolicy: request ? { request } : undefined, + }; + }), + resolveProviderOperationTimeoutMsMock: vi.fn( + (params: Record) => params.defaultTimeoutMs, + ), + sanitizeConfiguredModelProviderRequestMock: vi.fn((request) => request), +})); + +vi.mock("openclaw/plugin-sdk/provider-auth", () => ({ + isProviderApiKeyConfigured: isProviderApiKeyConfiguredMock, +})); + +vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ + resolveApiKeyForProvider: resolveApiKeyForProviderMock, +})); + +vi.mock("openclaw/plugin-sdk/provider-http", () => ({ + assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, + createProviderOperationDeadline: createProviderOperationDeadlineMock, + postJsonRequest: postJsonRequestMock, + postMultipartRequest: postMultipartRequestMock, + resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, + resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock, + sanitizeConfiguredModelProviderRequest: sanitizeConfiguredModelProviderRequestMock, +})); + +function createProvider(overrides: Partial = {}) { + return createOpenAiCompatibleImageGenerationProvider({ + id: "sample", + label: "Sample", + defaultModel: "sample-image", + models: ["sample-image"], + defaultBaseUrl: "https://sample.example/v1", + capabilities: { + generate: { maxCount: 4, supportsSize: true }, + edit: { enabled: true, maxCount: 1, maxInputImages: 1, supportsSize: true }, + geometry: { sizes: ["1024x1024"] }, + }, + useConfiguredRequest: true, + buildGenerateRequest: ({ req, model, count }) => ({ + kind: "json", + body: { + model, + prompt: req.prompt, + n: count, + size: req.size ?? "1024x1024", + response_format: "b64_json", + }, + }), + buildEditRequest: ({ req, inputImages, model, count }) => { + const form = new FormData(); + form.set("model", model); + form.set("prompt", req.prompt); + form.set("n", String(count)); + form.append( + "image", + new Blob([new Uint8Array(inputImages[0]?.buffer ?? Buffer.alloc(0))], { + type: inputImages[0]?.mimeType ?? "image/png", + }), + inputImages[0]?.fileName ?? "image.png", + ); + return { kind: "multipart", form }; + }, + ...overrides, + }); +} + +function mockGeneratedResponse() { + const release = vi.fn(async () => {}); + const payload = { + data: [ + { + b64_json: Buffer.from("image-bytes").toString("base64"), + revised_prompt: "revised", + }, + ], + }; + postJsonRequestMock.mockResolvedValue({ response: { json: async () => payload }, release }); + postMultipartRequestMock.mockResolvedValue({ response: { json: async () => payload }, release }); + return release; +} + +describe("OpenAI-compatible image provider helper", () => { + afterEach(() => { + assertOkOrThrowHttpErrorMock.mockClear(); + createProviderOperationDeadlineMock.mockClear(); + isProviderApiKeyConfiguredMock.mockClear(); + postJsonRequestMock.mockReset(); + postMultipartRequestMock.mockReset(); + resolveApiKeyForProviderMock.mockReset(); + resolveApiKeyForProviderMock.mockResolvedValue({ apiKey: "provider-key" }); + resolveProviderHttpRequestConfigMock.mockClear(); + resolveProviderOperationTimeoutMsMock.mockClear(); + sanitizeConfiguredModelProviderRequestMock.mockClear(); + }); + + it("builds provider metadata and delegates configuration checks", () => { + const provider = createProvider(); + + expect(provider.id).toBe("sample"); + expect(provider.label).toBe("Sample"); + expect(provider.defaultModel).toBe("sample-image"); + expect(provider.isConfigured?.({ agentDir: "/tmp/agent" })).toBe(true); + expect(isProviderApiKeyConfiguredMock).toHaveBeenCalledWith({ + provider: "sample", + agentDir: "/tmp/agent", + }); + }); + + it("posts JSON generation requests and parses OpenAI-compatible image data", async () => { + const release = mockGeneratedResponse(); + const provider = createProvider(); + + const result = await provider.generateImage({ + provider: "sample", + model: "custom-image", + prompt: "draw a square", + count: 2, + size: "512x512", + cfg: { + models: { + providers: { + sample: { + baseUrl: "https://sample.example/v1/", + request: { allowPrivateNetwork: true }, + }, + }, + }, + }, + } as never); + + expect(resolveApiKeyForProviderMock).toHaveBeenCalledWith( + expect.objectContaining({ provider: "sample" }), + ); + expect(sanitizeConfiguredModelProviderRequestMock).toHaveBeenCalledWith({ + allowPrivateNetwork: true, + }); + expect(postJsonRequestMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://sample.example/v1/images/generations", + allowPrivateNetwork: true, + dispatcherPolicy: { request: { allowPrivateNetwork: true } }, + body: { + model: "custom-image", + prompt: "draw a square", + n: 2, + size: "512x512", + response_format: "b64_json", + }, + }), + ); + const headers = postJsonRequestMock.mock.calls[0]?.[0].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/json"); + expect(result).toMatchObject({ + model: "custom-image", + images: [{ mimeType: "image/png", fileName: "image-1.png", revisedPrompt: "revised" }], + }); + expect(release).toHaveBeenCalledOnce(); + }); + + it("posts multipart edit requests without forwarding a content-type header", async () => { + mockGeneratedResponse(); + const provider = createProvider(); + + await provider.generateImage({ + provider: "sample", + model: "sample-image", + prompt: "edit it", + inputImages: [{ buffer: Buffer.from("source"), mimeType: "image/png" }], + cfg: {} as never, + }); + + expect(postMultipartRequestMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://sample.example/v1/images/edits", + body: expect.any(FormData), + }), + ); + const headers = postMultipartRequestMock.mock.calls[0]?.[0].headers as Headers; + expect(headers.has("Content-Type")).toBe(false); + }); + + it("honors default operation timeouts and empty-response errors", async () => { + postJsonRequestMock.mockResolvedValue({ + response: { json: async () => ({ data: [] }) }, + release: vi.fn(async () => {}), + }); + const provider = createProvider({ + defaultTimeoutMs: 60_000, + emptyResponseError: "Sample response missing image data", + }); + + await expect( + provider.generateImage({ + provider: "sample", + model: "sample-image", + prompt: "empty", + timeoutMs: 123, + cfg: {} as never, + }), + ).rejects.toThrow("Sample response missing image data"); + + expect(createProviderOperationDeadlineMock).toHaveBeenCalledWith({ + timeoutMs: 123, + label: "Sample image generation", + }); + expect(resolveProviderOperationTimeoutMsMock).toHaveBeenCalledWith({ + deadline: { timeoutMs: 123, label: "Sample image generation" }, + defaultTimeoutMs: 60_000, + }); + expect(postJsonRequestMock).toHaveBeenCalledWith( + expect.objectContaining({ timeoutMs: 60_000 }), + ); + }); +}); diff --git a/src/image-generation/openai-compatible-image-provider.ts b/src/image-generation/openai-compatible-image-provider.ts new file mode 100644 index 00000000000..89fef5bc27b --- /dev/null +++ b/src/image-generation/openai-compatible-image-provider.ts @@ -0,0 +1,274 @@ +import type { OpenClawConfig } from "openclaw/plugin-sdk/config-types"; +import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; +import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { + assertOkOrThrowHttpError, + createProviderOperationDeadline, + postJsonRequest, + postMultipartRequest, + resolveProviderHttpRequestConfig, + resolveProviderOperationTimeoutMs, + sanitizeConfiguredModelProviderRequest, +} from "openclaw/plugin-sdk/provider-http"; +import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; +import { + parseOpenAiCompatibleImageResponse, + type OpenAiCompatibleImageResponsePayload, +} from "./image-assets.js"; +import type { + ImageGenerationProvider, + ImageGenerationProviderCapabilities, + ImageGenerationRequest, + ImageGenerationResult, + ImageGenerationSourceImage, +} from "./types.js"; + +type ModelProviderConfig = NonNullable["providers"]>[string]; + +export type OpenAiCompatibleImageRequestMode = "generate" | "edit"; + +export type OpenAiCompatibleImageProviderRequestParams = { + req: ImageGenerationRequest; + inputImages: ImageGenerationSourceImage[]; + model: string; + count: number; + mode: OpenAiCompatibleImageRequestMode; +}; + +export type OpenAiCompatibleImageProviderRequestBody = + | { kind: "json"; body: Record } + | { kind: "multipart"; form: FormData }; + +export type OpenAiCompatibleImageProviderOptions = { + id: string; + label: string; + defaultModel: string; + models: readonly string[]; + capabilities: ImageGenerationProviderCapabilities; + defaultBaseUrl: string; + providerConfigKey?: string; + normalizeModel?: (model: string | undefined, fallback: string) => string; + resolveBaseUrl?: (params: { + req: ImageGenerationRequest; + providerConfig?: ModelProviderConfig; + defaultBaseUrl: string; + }) => string; + resolveAllowPrivateNetwork?: (params: { + baseUrl: string; + req: ImageGenerationRequest; + providerConfig?: ModelProviderConfig; + }) => boolean | undefined; + useConfiguredRequest?: boolean; + defaultTimeoutMs?: number; + resolveCount?: (params: { + req: ImageGenerationRequest; + mode: OpenAiCompatibleImageRequestMode; + }) => number; + buildGenerateRequest: ( + params: OpenAiCompatibleImageProviderRequestParams & { mode: "generate" }, + ) => OpenAiCompatibleImageProviderRequestBody; + buildEditRequest: ( + params: OpenAiCompatibleImageProviderRequestParams & { mode: "edit" }, + ) => OpenAiCompatibleImageProviderRequestBody; + response?: { + defaultMimeType?: string; + fileNamePrefix?: string; + sniffMimeType?: boolean; + }; + missingApiKeyError?: string; + tooManyInputImagesError?: string; + missingInputImageError?: string; + emptyResponseError?: string; + failureLabels?: { + generate?: string; + edit?: string; + }; +}; + +function readProviderConfig( + cfg: OpenClawConfig | undefined, + providerConfigKey: string, +): ModelProviderConfig | undefined { + return cfg?.models?.providers?.[providerConfigKey]; +} + +function resolveDefaultModel(model: string | undefined, fallback: string): string { + return normalizeOptionalString(model) ?? fallback; +} + +function trimTrailingSlash(value: string): string { + return value.replace(/\/+$/u, ""); +} + +function appendImagesPath(baseUrl: string, mode: OpenAiCompatibleImageRequestMode): string { + return `${trimTrailingSlash(baseUrl)}/images/${mode === "edit" ? "edits" : "generations"}`; +} + +function resolveRequestTimeoutMs(params: { + options: OpenAiCompatibleImageProviderOptions; + req: ImageGenerationRequest; + mode: OpenAiCompatibleImageRequestMode; +}): number | undefined { + if (params.options.defaultTimeoutMs === undefined) { + return params.req.timeoutMs; + } + const label = + params.mode === "edit" + ? (params.options.failureLabels?.edit ?? `${params.options.label} image edit`) + : (params.options.failureLabels?.generate ?? `${params.options.label} image generation`); + const deadline = createProviderOperationDeadline({ + timeoutMs: params.req.timeoutMs, + label, + }); + return resolveProviderOperationTimeoutMs({ + deadline, + defaultTimeoutMs: params.options.defaultTimeoutMs, + }); +} + +export function createOpenAiCompatibleImageGenerationProvider( + options: OpenAiCompatibleImageProviderOptions, +): ImageGenerationProvider { + const providerConfigKey = options.providerConfigKey ?? options.id; + const normalizeModel = options.normalizeModel ?? resolveDefaultModel; + const resolveCount = + options.resolveCount ?? + (({ req }) => { + return req.count ?? 1; + }); + + return { + id: options.id, + label: options.label, + defaultModel: options.defaultModel, + models: [...options.models], + isConfigured: ({ agentDir }) => + isProviderApiKeyConfigured({ + provider: options.id, + agentDir, + }), + capabilities: options.capabilities, + async generateImage(req): Promise { + const inputImages = req.inputImages ?? []; + const mode: OpenAiCompatibleImageRequestMode = inputImages.length > 0 ? "edit" : "generate"; + const maxInputImages = options.capabilities.edit.maxInputImages; + if (mode === "edit" && !options.capabilities.edit.enabled) { + throw new Error(`${options.label} image editing is not supported.`); + } + if (mode === "edit" && maxInputImages !== undefined && inputImages.length > maxInputImages) { + throw new Error( + options.tooManyInputImagesError ?? + `${options.label} image editing supports up to ${maxInputImages} reference image${ + maxInputImages === 1 ? "" : "s" + }.`, + ); + } + if (mode === "edit" && inputImages.length === 0) { + throw new Error( + options.missingInputImageError ?? `${options.label} image edit missing reference image.`, + ); + } + + const auth = await resolveApiKeyForProvider({ + provider: options.id, + cfg: req.cfg, + agentDir: req.agentDir, + store: req.authStore, + }); + if (!auth.apiKey) { + throw new Error(options.missingApiKeyError ?? `${options.label} API key missing`); + } + + const providerConfig = readProviderConfig(req.cfg, providerConfigKey); + const resolvedBaseUrl = + options.resolveBaseUrl?.({ + req, + providerConfig, + defaultBaseUrl: options.defaultBaseUrl, + }) ?? + normalizeOptionalString(providerConfig?.baseUrl) ?? + options.defaultBaseUrl; + const allowPrivateNetwork = options.resolveAllowPrivateNetwork?.({ + baseUrl: resolvedBaseUrl, + req, + providerConfig, + }); + const { + baseUrl, + allowPrivateNetwork: resolvedAllowPrivateNetwork, + headers, + dispatcherPolicy, + } = resolveProviderHttpRequestConfig({ + baseUrl: resolvedBaseUrl, + defaultBaseUrl: options.defaultBaseUrl, + allowPrivateNetwork, + request: options.useConfiguredRequest + ? sanitizeConfiguredModelProviderRequest(providerConfig?.request) + : undefined, + defaultHeaders: { + Authorization: `Bearer ${auth.apiKey}`, + }, + provider: options.id, + capability: "image", + transport: "http", + }); + + const model = normalizeModel(req.model, options.defaultModel); + const count = resolveCount({ req, mode }); + const requestParams = { req, inputImages, model, count, mode }; + const requestBody = + mode === "edit" + ? options.buildEditRequest({ ...requestParams, mode }) + : options.buildGenerateRequest({ ...requestParams, mode }); + const timeoutMs = resolveRequestTimeoutMs({ options, req, mode }); + const request = + requestBody.kind === "multipart" + ? postMultipartRequest({ + url: appendImagesPath(baseUrl, mode), + headers: (() => { + const multipartHeaders = new Headers(headers); + multipartHeaders.delete("Content-Type"); + return multipartHeaders; + })(), + body: requestBody.form, + timeoutMs, + fetchFn: fetch, + allowPrivateNetwork: resolvedAllowPrivateNetwork, + dispatcherPolicy, + }) + : postJsonRequest({ + url: appendImagesPath(baseUrl, mode), + headers: (() => { + const jsonHeaders = new Headers(headers); + jsonHeaders.set("Content-Type", "application/json"); + return jsonHeaders; + })(), + body: requestBody.body, + timeoutMs, + fetchFn: fetch, + allowPrivateNetwork: resolvedAllowPrivateNetwork, + dispatcherPolicy, + }); + + const { response, release } = await request; + try { + await assertOkOrThrowHttpError( + response, + mode === "edit" + ? (options.failureLabels?.edit ?? `${options.label} image edit failed`) + : (options.failureLabels?.generate ?? `${options.label} image generation failed`), + ); + const images = parseOpenAiCompatibleImageResponse( + (await response.json()) as OpenAiCompatibleImageResponsePayload, + options.response, + ); + if (options.emptyResponseError && images.length === 0) { + throw new Error(options.emptyResponseError); + } + return { images, model }; + } finally { + await release(); + } + }, + }; +} diff --git a/src/plugin-sdk/image-generation.ts b/src/plugin-sdk/image-generation.ts index 0c1391e63d2..1f08cc0db40 100644 --- a/src/plugin-sdk/image-generation.ts +++ b/src/plugin-sdk/image-generation.ts @@ -1,5 +1,13 @@ // Public image-generation helpers and types for provider plugins. +export { + createOpenAiCompatibleImageGenerationProvider, + type OpenAiCompatibleImageProviderOptions, + type OpenAiCompatibleImageProviderRequestBody, + type OpenAiCompatibleImageProviderRequestParams, + type OpenAiCompatibleImageRequestMode, +} from "../image-generation/openai-compatible-image-provider.js"; + export { generatedImageAssetFromBase64, generatedImageAssetFromDataUrl,