mirror of
https://fastgit.cc/https://github.com/anomalyco/opencode
synced 2026-04-21 05:10:58 +08:00
zen: routing logic
This commit is contained in:
@@ -45,6 +45,7 @@ import { LiteData } from "@opencode-ai/console-core/lite.js"
|
||||
import { Resource } from "@opencode-ai/console-resource"
|
||||
import { i18n, type Key } from "~/i18n"
|
||||
import { localeFromRequest } from "~/lib/language"
|
||||
import { createModelTpmLimiter } from "./modelTpmLimiter"
|
||||
|
||||
type ZenData = Awaited<ReturnType<typeof ZenData.list>>
|
||||
type RetryOptions = {
|
||||
@@ -121,6 +122,8 @@ export async function handler(
|
||||
const authInfo = await authenticate(modelInfo, zenApiKey)
|
||||
const billingSource = validateBilling(authInfo, modelInfo)
|
||||
logger.metric({ source: billingSource })
|
||||
const modelTpmLimiter = createModelTpmLimiter(modelInfo.providers)
|
||||
const modelTpmLimits = await modelTpmLimiter?.check()
|
||||
|
||||
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
|
||||
const providerInfo = selectProvider(
|
||||
@@ -133,6 +136,7 @@ export async function handler(
|
||||
trialProviders,
|
||||
retry,
|
||||
stickyProvider,
|
||||
modelTpmLimits,
|
||||
)
|
||||
validateModelSettings(billingSource, authInfo)
|
||||
updateProviderKey(authInfo, providerInfo)
|
||||
@@ -229,6 +233,7 @@ export async function handler(
|
||||
const usageInfo = providerInfo.normalizeUsage(json.usage)
|
||||
const costInfo = calculateCost(modelInfo, usageInfo)
|
||||
await trialLimiter?.track(usageInfo)
|
||||
await modelTpmLimiter?.track(providerInfo.id, providerInfo.model, usageInfo)
|
||||
await trackUsage(sessionId, billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
|
||||
await reload(billingSource, authInfo, costInfo)
|
||||
json.cost = calculateOccurredCost(billingSource, costInfo)
|
||||
@@ -278,6 +283,7 @@ export async function handler(
|
||||
const usageInfo = providerInfo.normalizeUsage(usage)
|
||||
const costInfo = calculateCost(modelInfo, usageInfo)
|
||||
await trialLimiter?.track(usageInfo)
|
||||
await modelTpmLimiter?.track(providerInfo.id, providerInfo.model, usageInfo)
|
||||
await trackUsage(sessionId, billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
|
||||
await reload(billingSource, authInfo, costInfo)
|
||||
const cost = calculateOccurredCost(billingSource, costInfo)
|
||||
@@ -433,12 +439,16 @@ export async function handler(
|
||||
trialProviders: string[] | undefined,
|
||||
retry: RetryOptions,
|
||||
stickyProvider: string | undefined,
|
||||
modelTpmLimits: Record<string, number> | undefined,
|
||||
) {
|
||||
const modelProvider = (() => {
|
||||
// Byok is top priority b/c if user set their own API key, we should use it
|
||||
// instead of using the sticky provider for the same session
|
||||
if (authInfo?.provider?.credentials) {
|
||||
return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
|
||||
}
|
||||
|
||||
// Always use the same provider for the same session
|
||||
if (stickyProvider) {
|
||||
const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider)
|
||||
if (provider) return provider
|
||||
@@ -451,10 +461,20 @@ export async function handler(
|
||||
}
|
||||
|
||||
if (retry.retryCount !== MAX_FAILOVER_RETRIES) {
|
||||
const providers = modelInfo.providers
|
||||
const allProviders = modelInfo.providers
|
||||
.filter((provider) => !provider.disabled)
|
||||
.filter((provider) => provider.weight !== 0)
|
||||
.filter((provider) => !retry.excludeProviders.includes(provider.id))
|
||||
.flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
|
||||
.filter((provider) => {
|
||||
if (!provider.tpmLimit) return true
|
||||
const usage = modelTpmLimits?.[`${provider.id}/${provider.model}`] ?? 0
|
||||
return usage < provider.tpmLimit * 1_000_000
|
||||
})
|
||||
|
||||
const topPriority = Math.min(...allProviders.map((p) => p.priority))
|
||||
const providers = allProviders
|
||||
.filter((p) => p.priority <= topPriority)
|
||||
.flatMap((provider) => Array<typeof provider>(provider.weight).fill(provider))
|
||||
|
||||
// Use the last 4 characters of session ID to select a provider
|
||||
const identifier = sessionId.length ? sessionId : ip
|
||||
|
||||
49
packages/console/app/src/routes/zen/util/modelTpmLimiter.ts
Normal file
49
packages/console/app/src/routes/zen/util/modelTpmLimiter.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { and, Database, eq, inArray, sql } from "@opencode-ai/console-core/drizzle/index.js"
|
||||
import { ModelRateLimitTable } from "@opencode-ai/console-core/schema/ip.sql.js"
|
||||
import { UsageInfo } from "./provider/provider"
|
||||
|
||||
export function createModelTpmLimiter(providers: { id: string; model: string; tpmLimit?: number }[]) {
|
||||
const keys = providers.filter((p) => p.tpmLimit).map((p) => `${p.id}/${p.model}`)
|
||||
if (keys.length === 0) return
|
||||
|
||||
const yyyyMMddHHmm = new Date(Date.now())
|
||||
.toISOString()
|
||||
.replace(/[^0-9]/g, "")
|
||||
.substring(0, 12)
|
||||
|
||||
return {
|
||||
check: async () => {
|
||||
const data = await Database.use((tx) =>
|
||||
tx
|
||||
.select()
|
||||
.from(ModelRateLimitTable)
|
||||
.where(and(inArray(ModelRateLimitTable.key, keys), eq(ModelRateLimitTable.interval, yyyyMMddHHmm))),
|
||||
)
|
||||
|
||||
// convert to map of model to count
|
||||
return data.reduce(
|
||||
(acc, curr) => {
|
||||
acc[curr.key] = curr.count
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, number>,
|
||||
)
|
||||
},
|
||||
track: async (id: string, model: string, usageInfo: UsageInfo) => {
|
||||
const usage =
|
||||
usageInfo.inputTokens +
|
||||
usageInfo.outputTokens +
|
||||
(usageInfo.reasoningTokens ?? 0) +
|
||||
(usageInfo.cacheReadTokens ?? 0) +
|
||||
(usageInfo.cacheWrite5mTokens ?? 0) +
|
||||
(usageInfo.cacheWrite1hTokens ?? 0)
|
||||
if (usage <= 0) return
|
||||
await Database.use((tx) =>
|
||||
tx
|
||||
.insert(ModelRateLimitTable)
|
||||
.values({ key: `${id}/${model}`, interval: yyyyMMddHHmm, count: usage })
|
||||
.onDuplicateKeyUpdate({ set: { count: sql`${ModelRateLimitTable.count} + ${usage}` } }),
|
||||
)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
CREATE TABLE `model_rate_limit` (
|
||||
`key` varchar(255) NOT NULL,
|
||||
`interval` varchar(40) NOT NULL,
|
||||
`count` int NOT NULL,
|
||||
CONSTRAINT PRIMARY KEY(`key`,`interval`)
|
||||
);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,8 @@ export namespace ZenData {
|
||||
z.object({
|
||||
id: z.string(),
|
||||
model: z.string(),
|
||||
priority: z.number().optional(),
|
||||
tpmLimit: z.number().optional(),
|
||||
weight: z.number().optional(),
|
||||
disabled: z.boolean().optional(),
|
||||
storeModel: z.string().optional(),
|
||||
@@ -123,10 +125,16 @@ export namespace ZenData {
|
||||
),
|
||||
models: (() => {
|
||||
const normalize = (model: z.infer<typeof ModelSchema>) => {
|
||||
const composite = model.providers.find((p) => compositeProviders[p.id].length > 1)
|
||||
const providers = model.providers.map((p) => ({
|
||||
...p,
|
||||
priority: p.priority ?? Infinity,
|
||||
weight: p.weight ?? 1,
|
||||
}))
|
||||
const composite = providers.find((p) => compositeProviders[p.id].length > 1)
|
||||
if (!composite)
|
||||
return {
|
||||
trialProvider: model.trialProvider ? [model.trialProvider] : undefined,
|
||||
providers,
|
||||
}
|
||||
|
||||
const weightMulti = compositeProviders[composite.id].length
|
||||
@@ -137,17 +145,16 @@ export namespace ZenData {
|
||||
if (model.trialProvider === composite.id) return compositeProviders[composite.id].map((p) => p.id)
|
||||
return [model.trialProvider]
|
||||
})(),
|
||||
providers: model.providers.flatMap((p) =>
|
||||
providers: providers.flatMap((p) =>
|
||||
p.id === composite.id
|
||||
? compositeProviders[p.id].map((sub) => ({
|
||||
...p,
|
||||
id: sub.id,
|
||||
weight: p.weight ?? 1,
|
||||
}))
|
||||
: [
|
||||
{
|
||||
...p,
|
||||
weight: (p.weight ?? 1) * weightMulti,
|
||||
weight: p.weight * weightMulti,
|
||||
},
|
||||
],
|
||||
),
|
||||
|
||||
@@ -30,3 +30,13 @@ export const KeyRateLimitTable = mysqlTable(
|
||||
},
|
||||
(table) => [primaryKey({ columns: [table.key, table.interval] })],
|
||||
)
|
||||
|
||||
export const ModelRateLimitTable = mysqlTable(
|
||||
"model_rate_limit",
|
||||
{
|
||||
key: varchar("key", { length: 255 }).notNull(),
|
||||
interval: varchar("interval", { length: 40 }).notNull(),
|
||||
count: int("count").notNull(),
|
||||
},
|
||||
(table) => [primaryKey({ columns: [table.key, table.interval] })],
|
||||
)
|
||||
|
||||
10
packages/shared/sst-env.d.ts
vendored
Normal file
10
packages/shared/sst-env.d.ts
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
/* This file is auto-generated by SST. Do not edit. */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
/* deno-fmt-ignore-file */
|
||||
/* biome-ignore-all lint: auto-generated */
|
||||
|
||||
/// <reference path="../../sst-env.d.ts" />
|
||||
|
||||
import "sst"
|
||||
export {}
|
||||
Reference in New Issue
Block a user