diff --git a/packages/console/app/src/routes/zen/util/handler.ts b/packages/console/app/src/routes/zen/util/handler.ts index 3e191918e5..8c391d590a 100644 --- a/packages/console/app/src/routes/zen/util/handler.ts +++ b/packages/console/app/src/routes/zen/util/handler.ts @@ -106,7 +106,7 @@ export async function handler( const zenData = ZenData.list(opts.modelList) const modelInfo = validateModel(zenData, model) const dataDumper = createDataDumper(sessionId, requestId, projectId) - const trialLimiter = createTrialLimiter(modelInfo.trialProviders, ip) + const trialLimiter = createTrialLimiter(modelInfo.trialProvider, ip) const trialProviders = await trialLimiter?.check() const rateLimiter = createRateLimiter( modelInfo.id, @@ -392,7 +392,7 @@ export async function handler( function validateModel(zenData: ZenData, reqModel: string) { if (!(reqModel in zenData.models)) throw new ModelError(t("zen.api.error.modelNotSupported", { model: reqModel })) - const modelId = reqModel as keyof typeof zenData.models + const modelId = reqModel const modelData = Array.isArray(zenData.models[modelId]) ? zenData.models[modelId].find((model) => opts.format === model.formatFilter) : zenData.models[modelId] diff --git a/packages/console/core/src/model.ts b/packages/console/core/src/model.ts index 3b24394316..b4149373fe 100644 --- a/packages/console/core/src/model.ts +++ b/packages/console/core/src/model.ts @@ -26,7 +26,7 @@ export namespace ZenData { allowAnonymous: z.boolean().optional(), byokProvider: z.enum(["openai", "anthropic", "google"]).optional(), stickyProvider: z.enum(["strict", "prefer"]).optional(), - trialProviders: z.array(z.string()).optional(), + trialProvider: z.string().optional(), trialEnded: z.boolean().optional(), fallbackProvider: z.string().optional(), rateLimit: z.number().optional(), @@ -45,7 +45,7 @@ export namespace ZenData { const ProviderSchema = z.object({ api: z.string(), - apiKey: z.string(), + apiKey: z.union([z.string(), z.record(z.string(), z.string())]), format: FormatSchema.optional(), headerMappings: z.record(z.string(), z.string()).optional(), payloadModifier: z.record(z.string(), z.any()).optional(), @@ -54,7 +54,10 @@ export namespace ZenData { }) const ModelsSchema = z.object({ - models: z.record(z.string(), z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))])), + zenModels: z.record( + z.string(), + z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]), + ), liteModels: z.record( z.string(), z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]), @@ -99,10 +102,66 @@ export namespace ZenData { Resource.ZEN_MODELS29.value + Resource.ZEN_MODELS30.value, ) - const { models, liteModels, providers } = ModelsSchema.parse(json) + const { zenModels, liteModels, providers } = ModelsSchema.parse(json) + const compositeProviders = Object.fromEntries( + Object.entries(providers).map(([id, provider]) => [ + id, + typeof provider.apiKey === "string" + ? [{ id: id, key: provider.apiKey }] + : Object.entries(provider.apiKey).map(([kid, key]) => ({ + id: `${id}.${kid}`, + key, + })), + ]), + ) return { - models: modelList === "lite" ? liteModels : models, - providers, + providers: Object.fromEntries( + Object.entries(providers).flatMap(([providerId, provider]) => + compositeProviders[providerId].map((p) => [p.id, { ...provider, apiKey: p.key }]), + ), + ), + models: (() => { + const normalize = (model: z.infer) => { + const composite = model.providers.find((p) => compositeProviders[p.id].length > 1) + if (!composite) + return { + trialProvider: model.trialProvider ? [model.trialProvider] : undefined, + } + + const weightMulti = compositeProviders[composite.id].length + + return { + trialProvider: (() => { + if (!model.trialProvider) return undefined + if (model.trialProvider === composite.id) return compositeProviders[composite.id].map((p) => p.id) + return [model.trialProvider] + })(), + providers: model.providers.flatMap((p) => + p.id === composite.id + ? compositeProviders[p.id].map((sub) => ({ + ...p, + id: sub.id, + weight: p.weight ?? 1, + })) + : [ + { + ...p, + weight: (p.weight ?? 1) * weightMulti, + }, + ], + ), + } + } + + return Object.fromEntries( + Object.entries(modelList === "lite" ? liteModels : zenModels).map(([modelId, model]) => { + const n = Array.isArray(model) + ? model.map((m) => ({ ...m, ...normalize(m) })) + : { ...model, ...normalize(model) } + return [modelId, n] + }), + ) + })(), } }) }