diff --git a/src/agents/failover-error.test.ts b/src/agents/failover-error.test.ts index 4b03c5aa12a..ebc3f032d1f 100644 --- a/src/agents/failover-error.test.ts +++ b/src/agents/failover-error.test.ts @@ -117,6 +117,13 @@ describe("failover-error", () => { }, }), ).toBeNull(); + expect( + resolveFailoverReasonFromError({ + status: 422, + message: "check open ai req parameter error", + cause: new Error("No response body"), + }), + ).toBeNull(); // Transient server errors (500/502/503/504) should trigger failover as timeout. expect(resolveFailoverReasonFromError({ status: 500 })).toBe("timeout"); expect(resolveFailoverReasonFromError({ status: 502 })).toBe("timeout"); @@ -263,6 +270,27 @@ describe("failover-error", () => { ).toBe("auth"); }); + it("preserves parent provider context for wrapped billing signals", () => { + expect( + resolveFailoverReasonFromError({ + provider: "openrouter", + status: 401, + cause: { message: "Key limit exceeded" }, + }), + ).toBe("billing"); + }); + + it("revisits shared nested errors when a later wrapper adds provider context", () => { + const shared = { message: "Key limit exceeded" }; + + expect( + resolveFailoverReasonFromError({ + cause: { cause: shared }, + error: { provider: "openrouter", cause: shared }, + }), + ).toBe("billing"); + }); + it("classifies generic model-does-not-exist messages as model_not_found", () => { expect( resolveFailoverReasonFromError({ @@ -787,4 +815,13 @@ describe("failover-error", () => { expect(described.message).toBe("123"); expect(described.reason).toBeUndefined(); }); + + it("does not recurse forever on mixed cause/error cycles", () => { + const first: { message: string; cause?: unknown } = { message: "wrapper" }; + const second: { message: string; error?: unknown } = { message: "nested" }; + first.cause = second; + second.error = first; + + expect(resolveFailoverReasonFromError(first)).toBeNull(); + }); }); diff --git a/src/agents/failover-error.ts b/src/agents/failover-error.ts index fece393489b..08eaeba4d28 100644 --- a/src/agents/failover-error.ts +++ b/src/agents/failover-error.ts @@ -240,6 +240,20 @@ function normalizeDirectErrorSignal(err: unknown): FailoverSignal { }; } +type FailoverSignalContext = Pick; + +function withInheritedProviderContext( + context: FailoverSignalContext, + signal: FailoverSignal, +): FailoverSignal { + return { + provider: signal.provider ?? context.provider, + status: signal.status, + code: signal.code, + message: signal.message, + }; +} + function getNestedErrorCandidates(err: unknown): unknown[] { if (!err || typeof err !== "object") { return []; @@ -250,56 +264,112 @@ function getNestedErrorCandidates(err: unknown): unknown[] { ); } -function resolveFailoverClassificationFromError(err: unknown): FailoverClassification | null { - if (isFailoverError(err)) { - return { - kind: "reason", - reason: err.reason, - }; - } +function shouldInspectNestedClassification( + classification: FailoverClassification | null, + messageClassification: FailoverClassification | null, +): boolean { + return ( + !classification || classification.kind === "context_overflow" || messageClassification === null + ); +} - const directSignal = normalizeDirectErrorSignal(err); - const messageClassification = directSignal.message - ? classifyFailoverSignal({ - message: directSignal.message, - provider: directSignal.provider, - }) - : null; - const classification = classifyFailoverSignal(directSignal); - const nestedCandidates = getNestedErrorCandidates(err).filter((candidate) => candidate !== err); - if ( - !classification || - classification.kind === "context_overflow" || - messageClassification === null - ) { - // Let wrapped causes override parent timeout/overflow/format guesses when - // the nested error carries a more specific failover signal. - for (const candidate of nestedCandidates) { - const causeClassification = resolveFailoverClassificationFromError(candidate); - if (causeClassification) { - return causeClassification; - } - if ( - classification?.kind === "reason" && - classification.reason === "format" && - isUnclassifiedNoBodyHttpSignal(normalizeDirectErrorSignal(candidate)) - ) { - return null; - } +function isFormatClassification(classification: FailoverClassification | null): boolean { + return classification?.kind === "reason" && classification.reason === "format"; +} + +function isNestedNoBodySignal(candidate: unknown, inheritedStatus: number | undefined): boolean { + const candidateSignal = normalizeDirectErrorSignal(candidate); + return isUnclassifiedNoBodyHttpSignal({ + ...candidateSignal, + status: candidateSignal.status ?? inheritedStatus, + }); +} + +function resolveNestedFailoverOverride( + err: unknown, + nestedContext: FailoverSignalContext, + seen: WeakSet, + classification: FailoverClassification | null, +): FailoverClassification | null | undefined { + for (const candidate of getNestedErrorCandidates(err)) { + const nestedClassification = resolveFailoverClassificationFromError( + candidate, + nestedContext, + seen, + ); + if (nestedClassification) { + return nestedClassification; + } + if ( + isFormatClassification(classification) && + isNestedNoBodySignal(candidate, nestedContext.status) + ) { + return null; } } + return undefined; +} - if (classification) { - return classification; +function resolveFailoverClassificationFromError( + err: unknown, + context: FailoverSignalContext = {}, + seen: WeakSet = new WeakSet(), +): FailoverClassification | null { + const isObject = Boolean(err && typeof err === "object"); + const objectErr = isObject ? err : undefined; + if (err && typeof err === "object") { + if (seen.has(err)) { + return null; + } + seen.add(err); } + try { + if (isFailoverError(err)) { + return { + kind: "reason", + reason: err.reason, + }; + } - if (isTimeoutError(err)) { - return { - kind: "reason", - reason: "timeout", - }; + const directSignal = withInheritedProviderContext(context, normalizeDirectErrorSignal(err)); + const messageClassification = directSignal.message + ? classifyFailoverSignal({ + message: directSignal.message, + provider: directSignal.provider, + }) + : null; + const classification = classifyFailoverSignal(directSignal); + if (shouldInspectNestedClassification(classification, messageClassification)) { + const nestedOverride = resolveNestedFailoverOverride( + err, + { + status: directSignal.status, + provider: directSignal.provider, + }, + seen, + classification, + ); + if (nestedOverride !== undefined) { + return nestedOverride; + } + } + + if (classification) { + return classification; + } + + if (isTimeoutError(err)) { + return { + kind: "reason", + reason: "timeout", + }; + } + return null; + } finally { + if (objectErr) { + seen.delete(objectErr); + } } - return null; } export function resolveFailoverReasonFromError(err: unknown): FailoverReason | null {