feat(effect-zod): transform support + walk memoization + flattened checks (#23203)

This commit is contained in:
Kit Langton
2026-04-17 19:55:55 -04:00
committed by GitHub
parent 280b9d4c80
commit f3d1fd9ce8
2 changed files with 213 additions and 15 deletions

View File

@@ -1,4 +1,4 @@
import { Schema, SchemaAST } from "effect"
import { Effect, Option, Schema, SchemaAST } from "effect"
import z from "zod"
/**
@@ -8,33 +8,85 @@ import z from "zod"
*/
export const ZodOverride: unique symbol = Symbol.for("effect-zod/override")
// AST nodes are immutable and frequently shared across schemas (e.g. a single
// Schema.Class embedded in multiple parents). Memoizing by node identity
// avoids rebuilding equivalent Zod subtrees and keeps derived children stable
// by reference across callers.
const walkCache = new WeakMap<SchemaAST.AST, z.ZodTypeAny>()
// Shared empty ParseOptions for the rare callers that need one — avoids
// allocating a fresh object per parse inside refinements and transforms.
const EMPTY_PARSE_OPTIONS = {} as SchemaAST.ParseOptions
export function zod<S extends Schema.Top>(schema: S): z.ZodType<Schema.Schema.Type<S>> {
return walk(schema.ast) as z.ZodType<Schema.Schema.Type<S>>
}
function walk(ast: SchemaAST.AST): z.ZodTypeAny {
const cached = walkCache.get(ast)
if (cached) return cached
const result = walkUncached(ast)
walkCache.set(ast, result)
return result
}
function walkUncached(ast: SchemaAST.AST): z.ZodTypeAny {
const override = (ast.annotations as any)?.[ZodOverride] as z.ZodTypeAny | undefined
if (override) return override
let out = body(ast)
for (const check of ast.checks ?? []) {
out = applyCheck(out, check, ast)
}
// Schema.Class wraps its fields in a Declaration AST plus an encoding that
// constructs the class instance. For the Zod derivation we want the plain
// field shape (the decoded/consumer view), not the class instance — so
// Declarations fall through to body(), not encoded(). User-level
// Schema.decodeTo / Schema.transform attach encoding to non-Declaration
// nodes, where we do apply the transform.
const hasTransform = ast.encoding?.length && ast._tag !== "Declaration"
const base = hasTransform ? encoded(ast) : body(ast)
const out = ast.checks?.length ? applyChecks(base, ast.checks, ast) : base
const desc = SchemaAST.resolveDescription(ast)
const ref = SchemaAST.resolveIdentifier(ast)
const next = desc ? out.describe(desc) : out
return ref ? next.meta({ ref }) : next
const described = desc ? out.describe(desc) : out
return ref ? described.meta({ ref }) : described
}
function applyCheck(out: z.ZodTypeAny, check: SchemaAST.Check<any>, ast: SchemaAST.AST): z.ZodTypeAny {
if (check._tag === "FilterGroup") {
return check.checks.reduce((acc, sub) => applyCheck(acc, sub, ast), out)
// Walk the encoded side and apply each link's decode to produce the decoded
// shape. A node `Target` produced by `from.decodeTo(Target)` carries
// `Target.encoding = [Link(from, transformation)]`. Chained decodeTo calls
// nest the encoding via `Link.to` so walking it recursively threads all
// prior transforms — typical encoding.length is 1.
function encoded(ast: SchemaAST.AST): z.ZodTypeAny {
const encoding = ast.encoding!
return encoding.reduce<z.ZodTypeAny>((acc, link) => acc.transform((v) => decode(link.transformation, v)), walk(encoding[0].to))
}
// Transformations built via pure `SchemaGetter.transform(fn)` (the common
// decodeTo case) resolve synchronously, so running with no services is safe.
// Effectful / middleware-based transforms will surface as Effect defects.
function decode(transformation: SchemaAST.Link["transformation"], value: unknown): unknown {
const exit = Effect.runSyncExit(
(transformation.decode as any).run(Option.some(value), EMPTY_PARSE_OPTIONS) as Effect.Effect<Option.Option<unknown>>,
)
if (exit._tag === "Failure") throw new Error(`effect-zod: transform failed: ${String(exit.cause)}`)
return Option.getOrElse(exit.value, () => value)
}
// Flatten FilterGroups and any nested variants into a linear list of Filters
// so we can run all of them inside a single Zod .superRefine wrapper instead
// of stacking N wrapper layers (one per check).
function applyChecks(out: z.ZodTypeAny, checks: SchemaAST.Checks, ast: SchemaAST.AST): z.ZodTypeAny {
const filters: SchemaAST.Filter<unknown>[] = []
const collect = (c: SchemaAST.Check<unknown>) => {
if (c._tag === "FilterGroup") c.checks.forEach(collect)
else filters.push(c)
}
checks.forEach(collect)
return out.superRefine((value, ctx) => {
const issue = check.run(value, ast, {} as any)
if (!issue) return
const message = issueMessage(issue) ?? (check.annotations as any)?.message ?? "Validation failed"
ctx.addIssue({ code: "custom", message })
for (const filter of filters) {
const issue = filter.run(value, ast, EMPTY_PARSE_OPTIONS)
if (!issue) continue
const message = issueMessage(issue) ?? (filter.annotations as any)?.message ?? "Validation failed"
ctx.addIssue({ code: "custom", message })
}
})
}

View File

@@ -1,5 +1,5 @@
import { describe, expect, test } from "bun:test"
import { Schema } from "effect"
import { Schema, SchemaGetter } from "effect"
import z from "zod"
import { zod, ZodOverride } from "../../src/util/effect-zod"
@@ -332,4 +332,150 @@ describe("util.effect-zod", () => {
expect(schema.parse({ id: "x" })).toEqual({ id: "x" })
})
})
describe("transforms (Schema.decodeTo)", () => {
test("Number -> pseudo-Duration (seconds) applies the decode function", () => {
// Models the account/account.ts DurationFromSeconds pattern.
const SecondsToMs = Schema.Number.pipe(
Schema.decodeTo(Schema.Number, {
decode: SchemaGetter.transform((n: number) => n * 1000),
encode: SchemaGetter.transform((ms: number) => ms / 1000),
}),
)
const schema = zod(SecondsToMs)
expect(schema.parse(3)).toBe(3000)
expect(schema.parse(0)).toBe(0)
})
test("String -> Number via parseInt decode", () => {
const ParsedInt = Schema.String.pipe(
Schema.decodeTo(Schema.Number, {
decode: SchemaGetter.transform((s: string) => Number.parseInt(s, 10)),
encode: SchemaGetter.transform((n: number) => String(n)),
}),
)
const schema = zod(ParsedInt)
expect(schema.parse("42")).toBe(42)
expect(schema.parse("0")).toBe(0)
})
test("transform inside a struct field applies per-field", () => {
const Field = Schema.Number.pipe(
Schema.decodeTo(Schema.Number, {
decode: SchemaGetter.transform((n: number) => n + 1),
encode: SchemaGetter.transform((n: number) => n - 1),
}),
)
const schema = zod(
Schema.Struct({
plain: Schema.Number,
bumped: Field,
}),
)
expect(schema.parse({ plain: 5, bumped: 10 })).toEqual({ plain: 5, bumped: 11 })
})
test("chained decodeTo composes transforms in order", () => {
// String -> Number (parseInt) -> Number (doubled).
// Exercises the encoded() reduce, not just a single link.
const Chained = Schema.String.pipe(
Schema.decodeTo(Schema.Number, {
decode: SchemaGetter.transform((s: string) => Number.parseInt(s, 10)),
encode: SchemaGetter.transform((n: number) => String(n)),
}),
Schema.decodeTo(Schema.Number, {
decode: SchemaGetter.transform((n: number) => n * 2),
encode: SchemaGetter.transform((n: number) => n / 2),
}),
)
const schema = zod(Chained)
expect(schema.parse("21")).toBe(42)
expect(schema.parse("0")).toBe(0)
})
test("Schema.Class is unaffected by transform walker (returns plain object, not instance)", () => {
// Schema.Class uses Declaration + encoding under the hood to construct
// class instances. The walker must NOT apply that transform, or zod
// parsing would return class instances instead of plain objects.
class Method extends Schema.Class<Method>("TxTestMethod")({
type: Schema.String,
value: Schema.Number,
}) {}
const schema = zod(Method)
const parsed = schema.parse({ type: "oauth", value: 1 })
expect(parsed).toEqual({ type: "oauth", value: 1 })
// Guardrail: ensure we didn't get back a Method instance.
expect(parsed).not.toBeInstanceOf(Method)
})
})
describe("optimizations", () => {
test("walk() memoizes by AST identity — same AST node returns same Zod", () => {
const shared = Schema.Struct({ id: Schema.String, name: Schema.String })
const left = zod(shared)
const right = zod(shared)
expect(left).toBe(right)
})
test("nested reuse of the same AST reuses the cached Zod child", () => {
// Two different parents embed the same inner schema. The inner zod
// child should be identical by reference inside both parents.
class Inner extends Schema.Class<Inner>("MemoTestInner")({
value: Schema.String,
}) {}
class OuterA extends Schema.Class<OuterA>("MemoTestOuterA")({
inner: Inner,
}) {}
class OuterB extends Schema.Class<OuterB>("MemoTestOuterB")({
inner: Inner,
}) {}
const shapeA = (zod(OuterA) as any).shape ?? (zod(OuterA) as any)._def?.shape?.()
const shapeB = (zod(OuterB) as any).shape ?? (zod(OuterB) as any)._def?.shape?.()
expect(shapeA.inner).toBe(shapeB.inner)
})
test("multiple checks run in a single refinement layer (all fire on one value)", () => {
// Three checks attached to the same schema. All three must run and
// report — asserting that no check silently got dropped when we
// flattened into one superRefine.
const positive = Schema.makeFilter((n: number) => (n > 0 ? undefined : "not positive"))
const even = Schema.makeFilter((n: number) => (n % 2 === 0 ? undefined : "not even"))
const under100 = Schema.makeFilter((n: number) => (n < 100 ? undefined : "too big"))
const schema = zod(Schema.Number.check(positive).check(even).check(under100))
const neg = schema.safeParse(-3)
expect(neg.success).toBe(false)
expect(neg.error!.issues.map((i) => i.message)).toEqual(expect.arrayContaining(["not positive", "not even"]))
const big = schema.safeParse(101)
expect(big.success).toBe(false)
expect(big.error!.issues.map((i) => i.message)).toContain("too big")
// Passing value satisfies all three
expect(schema.parse(42)).toBe(42)
})
test("FilterGroup flattens into the single refinement layer alongside its siblings", () => {
const positive = Schema.makeFilter((n: number) => (n > 0 ? undefined : "not positive"))
const even = Schema.makeFilter((n: number) => (n % 2 === 0 ? undefined : "not even"))
const group = Schema.makeFilterGroup([positive, even])
const under100 = Schema.makeFilter((n: number) => (n < 100 ? undefined : "too big"))
const schema = zod(Schema.Number.check(group).check(under100))
const bad = schema.safeParse(-3)
expect(bad.success).toBe(false)
expect(bad.error!.issues.map((i) => i.message)).toEqual(expect.arrayContaining(["not positive", "not even"]))
})
})
})