From 38181e425f7bcc8fcc0054dd7c19338b8cd5a4ff Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Sat, 27 Jun 2026 01:45:02 +0800 Subject: [PATCH 1/2] feat: add container memory budget --- .agents/languages/cpp.md | 7 + .agents/languages/csharp.md | 2 + .agents/languages/dart.md | 1 + .agents/languages/go.md | 8 + .agents/languages/java.md | 9 + .agents/languages/javascript.md | 7 + .agents/languages/python.md | 6 + .agents/languages/rust.md | 10 + .agents/languages/swift.md | 2 + AGENTS.md | 16 +- cpp/fory/serialization/BUILD | 10 + cpp/fory/serialization/CMakeLists.txt | 5 + .../serialization/collection_serializer.h | 221 +++++++++--- cpp/fory/serialization/config.h | 4 + .../container_memory_budget_test.cc | 265 +++++++++++++++ cpp/fory/serialization/context.cc | 40 +++ cpp/fory/serialization/context.h | 108 ++++++ cpp/fory/serialization/fory.h | 70 ++-- cpp/fory/serialization/map_serializer.h | 31 ++ cpp/fory/serialization/struct_serializer.h | 20 +- cpp/fory/serialization/union_serializer.h | 12 +- .../src/Fory.Generator/ForyModelGenerator.cs | 2 + csharp/src/Fory/CollectionSerializers.cs | 21 +- csharp/src/Fory/Config.cs | 32 ++ csharp/src/Fory/DictionarySerializers.cs | 2 + csharp/src/Fory/Fory.cs | 3 + csharp/src/Fory/NullableKeyDictionary.cs | 2 + .../Fory/PrimitiveDictionarySerializers.cs | 2 + csharp/src/Fory/ReadContext.cs | 140 ++++++++ .../Fory.Tests/ContainerMemoryBudgetTests.cs | 188 +++++++++++ dart/packages/fory/lib/src/config.dart | 11 + .../fory/lib/src/context/read_context.dart | 95 ++++++ dart/packages/fory/lib/src/fory.dart | 10 + .../serializer/collection_serializers.dart | 134 +++++--- .../lib/src/serializer/map_serializers.dart | 118 +++---- .../test/container_memory_budget_test.dart | 280 +++++++++++++++ docs/guide/cpp/configuration.md | 26 ++ docs/guide/csharp/configuration.md | 17 + docs/guide/dart/configuration.md | 25 ++ docs/guide/go/configuration.md | 23 ++ docs/guide/java/configuration.md | 5 + docs/guide/javascript/configuration.md | 
24 ++ docs/guide/python/configuration.md | 8 + docs/guide/rust/configuration.md | 32 ++ docs/guide/swift/configuration.md | 14 +- docs/security/deserialization.md | 48 ++- .../xlang_implementation_guide.md | 25 +- go/fory/README.md | 4 + go/fory/array.go | 5 +- go/fory/codegen/decoder.go | 37 ++ go/fory/codegen/generator.go | 31 ++ go/fory/container_memory_budget_test.go | 207 ++++++++++++ go/fory/field_serializer.go | 5 +- go/fory/fory.go | 22 ++ go/fory/map.go | 3 + go/fory/map_primitive.go | 160 +++++---- go/fory/reader.go | 268 +++++++++++++-- go/fory/set.go | 6 + go/fory/slice.go | 8 + go/fory/slice_dyn.go | 54 +-- go/fory/slice_primitive.go | 3 + go/fory/slice_primitive_list.go | 55 ++- go/fory/stream.go | 11 + go/fory/tests/structs_fory_gen.go | 72 +++- go/fory/type_resolver.go | 4 +- .../src/main/java/org/apache/fory/Fory.java | 27 +- .../java/org/apache/fory/config/Config.java | 9 + .../org/apache/fory/config/ForyBuilder.java | 17 + .../org/apache/fory/context/ReadContext.java | 85 ++++- .../org/apache/fory/memory/MemoryBuffer.java | 7 + .../fory/serializer/ArraySerializers.java | 40 +-- .../CompatibleCollectionArrayReader.java | 28 +- .../collection/ChildContainerSerializers.java | 6 +- .../collection/CollectionLikeSerializer.java | 6 +- .../collection/CollectionSerializers.java | 37 +- .../GuavaCollectionSerializers.java | 14 +- .../ImmutableCollectionSerializers.java | 6 +- .../collection/MapLikeSerializer.java | 6 +- .../serializer/collection/MapSerializers.java | 16 +- .../collection/SubListSerializers.java | 2 +- .../org/apache/fory/memory/MemoryBuffer.java | 5 + .../java/org/apache/fory/ForyTestBase.java | 2 +- .../fory/io/MemoryBufferObjectInputTest.java | 2 +- .../fory/io/MemoryBufferObjectOutputTest.java | 2 +- .../fory/resolver/ClassResolverTest.java | 6 +- .../fory/serializer/ArraySerializersTest.java | 16 +- .../serializer/CompatibleSerializerTest.java | 20 +- .../serializer/ContainerMemoryBudgetTest.java | 318 ++++++++++++++++++ 
.../serializer/ExceptionSerializersTest.java | 4 +- .../serializer/PrimitiveSerializersTest.java | 4 +- .../ChildContainerSerializersTest.java | 2 +- .../collection/CollectionSerializersTest.java | 2 +- javascript/packages/core/lib/context.ts | 76 +++++ javascript/packages/core/lib/fory.ts | 12 + .../packages/core/lib/gen/collection.ts | 30 ++ javascript/packages/core/lib/gen/map.ts | 2 + javascript/packages/core/lib/type.ts | 1 + javascript/test/containerMemoryBudget.test.ts | 225 +++++++++++++ .../serializer/kotlin/CollectionSerializer.kt | 10 +- .../kotlin/CollectionSerializerTest.kt | 18 + python/pyfory/_fory.py | 14 + python/pyfory/collection.pxi | 64 +++- python/pyfory/collection.py | 6 + python/pyfory/context.pxi | 81 +++++ python/pyfory/context.py | 63 ++++ python/pyfory/serialization.pyx | 36 ++ python/pyfory/serializer.py | 1 + .../tests/test_container_memory_budget.py | 220 ++++++++++++ rust/fory-core/src/config.rs | 10 + rust/fory-core/src/context.rs | 143 ++++++++ rust/fory-core/src/fory.rs | 29 +- rust/fory-core/src/serializer/codec.rs | 7 +- rust/fory-core/src/serializer/collection.rs | 18 +- rust/fory-core/src/serializer/map.rs | 12 +- rust/tests/tests/mod.rs | 1 + .../tests/test_container_memory_budget.rs | 244 ++++++++++++++ .../scala/CollectionSerializer.scala | 7 +- .../fory/serializer/scala/MapSerializer.scala | 7 +- .../scala/XlangCollectionSerializer.scala | 12 +- .../scala/CollectionSerializerTest.scala | 32 ++ .../scala/ScalaXlangSerializerTest.scala | 20 ++ swift/Sources/Fory/AnySerializer.swift | 12 +- .../Sources/Fory/CollectionSerializers.swift | 56 ++- swift/Sources/Fory/FieldCodecs.swift | 52 +-- swift/Sources/Fory/Fory.swift | 12 +- swift/Sources/Fory/ReadContext.swift | 168 +++++++++ .../ContainerMemoryBudgetTests.swift | 232 +++++++++++++ swift/Tests/ForyTests/ForySwiftTests.swift | 5 + 128 files changed, 5187 insertions(+), 536 deletions(-) create mode 100644 cpp/fory/serialization/container_memory_budget_test.cc create mode 
100644 csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs create mode 100644 dart/packages/fory/test/container_memory_budget_test.dart create mode 100644 go/fory/container_memory_budget_test.go create mode 100644 java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java create mode 100644 javascript/test/containerMemoryBudget.test.ts create mode 100644 python/pyfory/tests/test_container_memory_budget.py create mode 100644 rust/tests/tests/test_container_memory_budget.rs create mode 100644 swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 8a28fe0d03..f640aa8552 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -17,6 +17,13 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - Do not redesign alias-based or low-level public type shapes to add convenience methods unless the user explicitly asks for that API change. - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container budgets are owned by `ReadContext` and initialized by the root + `Fory::deserialize` overload. Keep `max_container_memory_bytes` as `-1 / auto` or a positive + explicit limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed + `128 MiB`. Reserve estimated container-owned memory before allocation but preserve existing + byte-availability checks and their non-empty metadata ordering. 
Skip only dedicated string, + binary, primitive vector, and primitive dense-array owners; general `std::vector<T>` for + non-primitive `T` is inline container storage and must be charged. ## Key Paths diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 098f9a50fe..8785b440e1 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -12,6 +12,8 @@ Load this file when changing `csharp/` or C# xlang behavior. - Generated C# gRPC service companions are compiler-owned files that depend on application-provided gRPC packages, not `csharp/src/Fory`. Keep gRPC package references out of the Fory runtime package. - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, so auto uses known input length; generated serializers may call `ReadContext`'s generated-code reservation helpers, but should not expose or depend on serializer helper classes such as `CollectionCodec`. +- For C# container budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List<T>`/`T[]` value paths and the 4-byte reference fallback for reference paths. Dedicated string, binary, and primitive dense-array serializers stay skipped and rely on byte availability checks. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. 
## Commands diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index d4f6bebbbc..9db2504693 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -14,6 +14,7 @@ Load this file when changing `dart/`. - Keep root numeric wrapper defaults separate from generated field metadata. Root wrapper resolution belongs in the builtin resolver, while annotations and generated metadata choose fixed, tagged, or declared-field encodings. - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. +- Root deserialization container memory budgets are owned by `ReadContext`; `maxContainerMemoryBytes` defaults to `-1 / auto`, positive explicit values win, and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are memory-backed. Charge Dart lists, sets, maps, object/reference arrays, compatible list-to-array inline storage, and compatible array-to-list materialization before allocation. Skip only dedicated string, binary, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, per-element accounting, or extra hot-path allocations for this budget. - Do not add parallel header-low/header-high slot caches or multi-slot recent caches in TypeMeta hot paths to chase benchmark gaps. Header-cache hits must use the concrete checked cache owner directly; if a hit hint is needed, cache one TypeInfo/TypeMeta object and compare the validated header identity on that object, not separate low/high header fields or benchmark-pattern state. 
- If Dart TypeMeta cache ownership changes, keep the invariant in a source comment near the hit path: a checked metadata-cache hit skips the body and must not grow low-bit sentinels, accepted-header fields, parallel header slots, or benchmark-pattern state. - Dart expected-type TypeDef reads should compare the expected `TypeInfo` object's cached local TypeDef header before consulting the parsed-metadata map. A match is a direct local-schema hit: skip the remote body, add the expected type to the per-read shared type table, and do not publish to `ParsedTypeMetaCache`, record a remote schema version, or parse/hash the body. diff --git a/.agents/languages/go.md b/.agents/languages/go.md index 94d47fe94c..949dd5030e 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -7,6 +7,14 @@ Load this file when changing `go/fory/` or Go xlang behavior. - Run Go commands from within `go/fory/`. - Changes under `go/` must pass formatting and tests. - The Go implementation focuses on reflection-based and codegen-based serialization. +- Root deserialization container memory budgets are owned by `ReadContext`. + `WithMaxContainerMemoryBytes` defaults to `-1 / auto`; byte-slice roots use + `inputBytes * 8 + 64 KiB`, and `DeserializeFromReader`/`DeserializeFromStream` + use fixed `128 MiB`. Charge Go slices, maps, map-backed sets, LIST-encoded + inline/value slices, and generated container reads before allocation. Fixed + arrays are caller-owned and normally not charged; `arrayDynSerializer` charges + its temporary slice. Skip only dedicated string, binary, BufferObject, + primitive ARRAY slice, and primitive array owners with byte checks. - Set `FORY_PANIC_ON_ERROR=1` when debugging a failing Go test so you get the full call stack. - Do not set `FORY_PANIC_ON_ERROR=1` when running the full Go test suite, because some tests assert on error contents. 
diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 3f4b165126..e88c67493e 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -14,6 +14,15 @@ Load this file when changing anything under `java/` or when Java drives a cross- values; use qualified names only when a real name conflict requires it. - If you run temporary tests with `java -cp`, run `mvn -T16 install -DskipTests` first so local Fory jars are current. - `WriteContext`, `ReadContext`, and `CopyContext` must stay explicit. Do not reintroduce `ThreadLocal` or ambient runtime-context patterns. +- Java root deserialization container memory budgeting belongs to `ReadContext` + and is initialized by `Fory` root APIs. Public config is + `maxContainerMemoryBytes` with `-1` auto, positive explicit override, + known-length auto `inputBytes * 8 + 64 KiB`, and stream/unknown auto + `128 MiB`. Collection/map/object-array serializers should charge estimated + container-owned memory before allocation while preserving existing + `checkReadableBytes` guards before backing allocation or capacity + reservation. Do not add nested serializer-path `try/finally`, per-element + work, or dynamic stream bytes-read accounting for this budget. - Generated serializers must not retain runtime context fields. `Fory` should stay a root-operation facade rather than accumulating serializer or convenience state. - When the serializer class and constructor shape are known at the call site, prefer direct constructor lambdas or direct instantiation over reflective `Serializers.newSerializer(...)`. - For GraalVM, use `fory codegen` to generate serializers when building native images. Do not add reflection configuration except for JDK `proxy`. diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index 9ebf99d3be..954e73a234 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -14,6 +14,13 @@ Load this file when changing `javascript/`. 
- Runtime value carriers such as decimal or reduced-precision numeric types belong under the core `types/` ownership boundary, with imports, exports, and codegen externals updated together. - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. +- JavaScript root deserialization container memory budgeting belongs to `ReadContext`. + `maxContainerMemoryBytes` uses `-1` auto, positive explicit limits, and known + `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. Generated and dynamic + list/set/map readers must reserve before allocation while preserving existing + byte checks. Keep dedicated string, binary, and dense typed-array owners out of + this budget; compatible list-to-typed-array reads must charge typed inline + storage. - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. - Compatible scalar conversion is immediate-field-only. 
Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix. diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 7365632c37..3ed69c6eb7 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -13,6 +13,12 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Cython mode owns the hot runtime path. Do not duplicate core runtime types between Python and Cython, tunnel Python facade methods into hidden Cython internals, or keep dead shims unless the user explicitly needs a compatibility module path. - Use explicit Cython fields and methods for fixed hot-path shapes. Avoid `__getattr__`, generic `object` fields, public bridge internals, or `Fory` backreferences where ownership can stay explicit. - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. +- Root deserialization container memory budgets are owned by pure-Python and Cython `ReadContext`. + Keep `max_container_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length + `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. Reserve fixed container cost plus + reference slots for list/tuple/set, map object/table/entry estimates for dict, and object-dtype + ndarray item storage. Keep string, bytes, `array.array`, primitive dense array, and primitive + ndarray owners skipped, and preserve byte-availability checks after budget reservation. - Public value constructors should accept normal Python values. Raw-bit, raw-buffer, and memoryview entry points should be explicit low-level APIs, and packed carriers should expose the buffer protocol from the actual storage owner when appropriate. 
- When debugging runtime or benchmark behavior, install the local package into the exact interpreter under test instead of relying on mixed `PYTHONPATH` state. - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names. diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index ffe5648330..a31126586c 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -18,6 +18,16 @@ Load this file when changing `rust/` or Rust xlang behavior. - If breakage is explicitly acceptable during a Rust module refactor, rewire macros, tests, and sibling crates directly to the new boundaries instead of adding compatibility re-exports. - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext` and is initialized by + the root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` + backed, so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. +- Rust `Vec<T>` stores inline element storage, so general LIST paths charge fixed `Vec` cost plus + `len * size_of::<T>()`, including `Vec` and `Vec`. Dedicated primitive dense + ARRAY `Vec` readers, strings, binary, and primitive fixed-array owners stay skipped and keep + their byte checks. 
+- Direct `Serializer` collection/map paths and derive `Codec` collection/map paths are separate + allocation owners. Keep reservations in both before `Vec::with_capacity`, + `HashMap::with_capacity`, or collection materialization; charge zero-size containers. ## Key Paths diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index cfd5198524..8a6556eccd 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -13,6 +13,8 @@ Load this file when changing `swift/` or Swift xlang behavior. - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or serializer-local budget state. +- For Swift container budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout<T>.stride` for value arrays/lists/maps and the 4-byte reference fallback for `Serializer.isRefType` / `FieldCodec.isRefType` paths. Dedicated `String`, `Data`/binary, and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must charge the target list materialization before allocation. 
## Commands diff --git a/AGENTS.md b/AGENTS.md index 5cd2581d5c..c8346e0f4f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,7 +32,8 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Respect ownership. Keep logic, state, and helpers in their natural owner, and do not move serializer-local, context-local, runtime-type-local, or protocol-local problems into global utilities. - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. -- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. +- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Container memory-budget reservation is accounting only and may happen before that byte check, but it must not replace the byte check. 
+- Root deserialization container memory budgets are estimated container-owned memory, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Charge fixed container cost, backing/reference/inline storage, map table and entry overhead, and zero-size containers; skip only dedicated string, binary, primitive array, and primitive dense-array owners. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. 
Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. @@ -111,6 +112,19 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - `docs/DEVELOPMENT.md` plus updates under `docs/guide/` and `docs/benchmarks/` are synced to `apache/fory-site`; other website content belongs there. - When benchmark logic, scripts, configuration, or compared serializers change, rerun the relevant benchmarks and refresh the artifacts under `docs/benchmarks/**`. +## Network Error Command Log + +- 2026-06-26: `cargo check` from `rust/` failed while updating crates.io through + `127.0.0.1:7890`; retried as + `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cargo check`, + which still used the configured proxy. `cargo check --offline` succeeded using the local Cargo + cache. +- 2026-06-26: `cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON +-DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON` from `cpp/` failed while FetchContent tried to + clone googletest through `127.0.0.1:7890`; retried as + `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON -DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON`, + which still used the configured proxy in the nested clone. + ## Shared Engineering Expectations - Favor zero-copy techniques, JIT or codegen opportunities, and cache-friendly memory access patterns in performance-critical paths. 
diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index b74c356a2b..1102e53f1c 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -109,6 +109,16 @@ cc_test( ], ) +cc_test( + name = "container_memory_budget_test", + srcs = ["container_memory_budget_test.cc"], + deps = [ + ":fory_serialization", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + cc_test( name = "variant_serializer_test", srcs = ["variant_serializer_test.cc"], diff --git a/cpp/fory/serialization/CMakeLists.txt b/cpp/fory/serialization/CMakeLists.txt index 0b88f0e5e3..7ff18d0320 100644 --- a/cpp/fory/serialization/CMakeLists.txt +++ b/cpp/fory/serialization/CMakeLists.txt @@ -102,6 +102,11 @@ if(FORY_BUILD_TESTS) target_link_libraries(fory_serialization_map_test fory_serialization GTest::gtest GTest::gtest_main) gtest_discover_tests(fory_serialization_map_test) + add_executable(fory_serialization_container_memory_budget_test container_memory_budget_test.cc) + fory_configure_target(fory_serialization_container_memory_budget_test) + target_link_libraries(fory_serialization_container_memory_budget_test fory_serialization GTest::gtest GTest::gtest_main) + gtest_discover_tests(fory_serialization_container_memory_budget_test) + add_executable(fory_serialization_variant_test variant_serializer_test.cc) fory_configure_target(fory_serialization_variant_test) target_link_libraries(fory_serialization_variant_test fory_serialization GTest::gtest GTest::gtest_main) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 473f9d6950..7c89e1a265 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -380,6 +380,34 @@ struct has_reserve inline constexpr bool has_reserve_v = has_reserve::value; +constexpr size_t kContainerEntryOverheadBytes = 16; +constexpr size_t kContainerReferenceBytes = sizeof(void *); + +template +struct 
is_std_vector_container : std::false_type {}; + +template +struct is_std_vector_container> : std::true_type {}; + +template +inline constexpr bool is_std_vector_container_v = + is_std_vector_container::value; + +template +constexpr size_t collection_element_memory_bytes() { + using Elem = typename Container::value_type; + if constexpr (is_std_vector_container_v) { + return sizeof(Elem); + } else { + static_assert(sizeof(Elem) <= std::numeric_limits::max() - + kContainerEntryOverheadBytes - + kContainerReferenceBytes * 2, + "container element memory estimate overflows"); + return sizeof(Elem) + kContainerEntryOverheadBytes + + kContainerReferenceBytes * 2; + } +} + template inline bool reserve_collection(Container &result, ReadContext &ctx, uint32_t length) { @@ -388,6 +416,12 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + constexpr size_t fixed_bytes = sizeof(Container); + constexpr size_t elem_bytes = collection_element_memory_bytes(); + if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< + fixed_bytes, elem_bytes>(length)))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -397,6 +431,14 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, return true; } +template +inline bool reserve_empty_collection(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + return ctx.reserve_container_memory(sizeof(Container)); +} + // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { @@ -412,9 +454,9 @@ template inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return 
result; + } return result; } @@ -443,6 +485,10 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + // Read elements if (is_same_type) { if (track_ref) { @@ -922,16 +968,20 @@ struct Serializer< return std::vector(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::vector(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::vector result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -961,7 +1011,6 @@ struct Serializer< } } - std::vector result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { return result; } @@ -1058,6 +1107,10 @@ struct Serializer< std::vector result; if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -1217,16 +1270,20 @@ template struct Serializer> { return std::list(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::list(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::list result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + 
return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1256,7 +1313,9 @@ template struct Serializer> { } } - std::list result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1349,6 +1408,16 @@ template struct Serializer> { } std::list result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1409,16 +1478,20 @@ template struct Serializer> { return std::deque(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::deque(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::deque result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1448,7 +1521,9 @@ template struct Serializer> { } } - std::deque result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1541,6 +1616,16 @@ template struct Serializer> { } std::deque result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + 
} + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1602,9 +1687,14 @@ struct Serializer> { return std::forward_list(); } + std::forward_list result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - return std::forward_list(); + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; } // Dispatch to slow path for polymorphic/shared-ref elements @@ -1620,7 +1710,7 @@ struct Serializer> { // Elements header bitmap (CollectionFlags) uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; bool has_null = (bitmap & COLL_HAS_NULL) != 0; @@ -1632,7 +1722,7 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } using ElemType = nullable_element_t; uint32_t expected = @@ -1644,8 +1734,12 @@ struct Serializer> { } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, length))) { - return std::forward_list(); + return result; } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1688,7 +1782,6 @@ struct Serializer> { } // Build forward_list in reverse order using push_front - std::forward_list result; for (auto it = temp.rbegin(); it != temp.rend(); ++it) { result.push_front(std::move(*it)); } @@ -1968,9 +2061,20 @@ struct Serializer> { return std::forward_list(); } + std::forward_list result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + 
return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } std::vector temp; if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, size))) { - return std::forward_list(); + return result; } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -1980,7 +2084,6 @@ struct Serializer> { temp.push_back(std::move(elem)); } // Build forward_list in reverse order - std::forward_list result; for (auto it = temp.rbegin(); it != temp.rend(); ++it) { result.push_front(std::move(*it)); } @@ -2069,16 +2172,20 @@ struct Serializer> { return std::set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2094,17 +2201,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::set(); + return result; } } - std::set result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if 
(FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2151,6 +2261,16 @@ struct Serializer> { } std::set result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -2244,17 +2364,22 @@ struct Serializer> { return std::unordered_set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::unordered_set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::unordered_set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>( + ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2270,20 +2395,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::unordered_set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::unordered_set(); + return result; } } - std::unordered_set result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2330,6 +2455,14 @@ struct Serializer> { } std::unordered_set result; + if (size == 0) { 
+ if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>( + ctx)))) { + return result; + } + return result; + } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index eac9c14436..a59b575f71 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,6 +52,10 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; + /// Maximum estimated container-owned memory accepted during one root + /// deserialization. `-1` selects the automatic input-shaped limit. + int64_t max_container_memory_bytes = -1; + /// Maximum accepted field count in one received struct TypeMeta. uint32_t max_type_fields = 512; diff --git a/cpp/fory/serialization/container_memory_budget_test.cc b/cpp/fory/serialization/container_memory_budget_test.cc new file mode 100644 index 0000000000..781e9312bc --- /dev/null +++ b/cpp/fory/serialization/container_memory_budget_test.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "fory/serialization/fory.h" +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include + +namespace fory { +namespace serialization { +namespace { + +constexpr size_t kKnownBudgetSlack = 64 * 1024; + +struct BudgetItem { + int32_t id = 0; + std::string name; + + bool operator==(const BudgetItem &other) const { + return id == other.id && name == other.name; + } + + FORY_STRUCT(BudgetItem, id, name); +}; + +struct BudgetSiblings { + std::vector left; + std::vector right; + + bool operator==(const BudgetSiblings &other) const { + return left == other.left && right == other.right; + } + + FORY_STRUCT(BudgetSiblings, left, right); +}; + +template +auto with_fory(int64_t max_container_memory_bytes, Fn &&fn) { + auto fory = Fory::builder() + .xlang(true) + .compatible(false) + .track_ref(false) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register_struct(1); + fory.register_struct(2); + return std::forward(fn)(fory); +} + +template std::vector serialize_value(const T &value) { + auto bytes = with_fory(-1, [&](Fory &fory) { return fory.serialize(value); }); + EXPECT_TRUE(bytes.ok()) << bytes.error().to_string(); + return std::move(bytes).value(); +} + +size_t nested_empty_budget(size_t count) { + using Inner = std::vector; + using Outer = std::vector; + return sizeof(Outer) + count * sizeof(Inner) + count * sizeof(Inner); +} + +TEST(ContainerMemoryBudgetTest, KnownLengthAutoBudget) { + constexpr size_t count = 3000; + std::vector> value(count); + auto bytes = serialize_value(value); + const size_t auto_limit = bytes.size() * 8 + kKnownBudgetSlack; + const size_t required = nested_empty_budget(count); + ASSERT_GT(required, auto_limit); + + auto default_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(default_result.ok()); + EXPECT_EQ(default_result.error().code(), ErrorCode::InvalidData); + + auto explicit_auto_result = + 
with_fory(static_cast(auto_limit), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(explicit_auto_result.ok()); + EXPECT_EQ(explicit_auto_result.error().code(), ErrorCode::InvalidData); + + auto explicit_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(explicit_result.ok()) << explicit_result.error().to_string(); + EXPECT_EQ(explicit_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { + constexpr size_t count = 10000; + std::vector> value(count); + auto bytes = serialize_value(value); + const size_t known_limit = bytes.size() * 8 + kKnownBudgetSlack; + ASSERT_GT(nested_empty_budget(count), known_limit); + + auto known_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(known_result.ok()); + EXPECT_EQ(known_result.error().code(), ErrorCode::InvalidData); + + std::string input(reinterpret_cast(bytes.data()), bytes.size()); + std::istringstream source(input); + StdInputStream stream(source, 8); + auto stream_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(stream); + }); + ASSERT_TRUE(stream_result.ok()) << stream_result.error().to_string(); + EXPECT_EQ(stream_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, ExplicitOverride) { + std::vector value(8); + auto bytes = serialize_value(value); + const size_t required = + sizeof(std::vector) + value.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, 
EmptyContainersChargeFixedCost) { + std::vector> value(1); + auto bytes = serialize_value(value); + const size_t required = nested_empty_budget(1); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { + BudgetSiblings value; + value.left.resize(16); + value.right.resize(16); + auto bytes = serialize_value(value); + const size_t one_vector = + sizeof(std::vector) + value.left.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(one_vector), [&](Fory &fory) { + return fory.deserialize(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto enough_result = + with_fory(static_cast(one_vector * 2), [&](Fory &fory) { + return fory.deserialize(bytes); + }); + ASSERT_TRUE(enough_result.ok()) << enough_result.error().to_string(); + EXPECT_EQ(enough_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, MapBudget) { + std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; + auto bytes = serialize_value(value); + const size_t entry_bytes = + sizeof(std::string) + sizeof(int32_t) + 16 + sizeof(void *) * 3; + const size_t required = sizeof(value) + value.size() * entry_bytes; + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + 
ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, DensePathsSkipped) { + { + std::string value = "container-budget-string"; + auto bytes = serialize_value(value); + auto result = with_fory( + 1, [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 7); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 42); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } +} + +TEST(ContainerMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { + Config config; + auto resolver = std::make_unique(); + ReadContext ctx(config, std::move(resolver)); + std::vector bytes{64}; + Buffer buffer(bytes.data(), static_cast(bytes.size()), false); + ctx.attach(buffer); + + auto result = Serializer>::read_data(ctx); + EXPECT_TRUE(result.empty()); + ASSERT_TRUE(ctx.has_error()); + EXPECT_EQ(ctx.error().code(), ErrorCode::BufferOutOfBound); +} + +} // namespace +} // namespace serialization +} // namespace fory diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index deff5ee16c..686db558ad 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -739,6 +739,43 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } +bool ReadContext::reserve_counted_container_checked(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes) { + if (FORY_PREDICT_FALSE( + elem_bytes != 0 && + static_cast(length) > + 
(std::numeric_limits::max() - fixed_bytes) / + elem_bytes)) { + return set_container_memory_overflow(length, elem_bytes); + } + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); +} + +bool ReadContext::set_container_memory_error(const std::string &message) { + set_error(Error::invalid_data(message)); + return false; +} + +bool ReadContext::set_container_memory_overflow(uint32_t length, + size_t elem_bytes) { + set_error(Error::invalid_data( + "container memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; +} + +bool ReadContext::set_container_memory_exceeded(size_t bytes, + size_t remaining) { + set_error(Error::invalid_data( + "estimated container memory request " + std::to_string(bytes) + + " bytes exceeds max_container_memory_bytes remaining budget " + + std::to_string(remaining) + " bytes out of effective limit " + + std::to_string(container_memory_limit_bytes_) + " bytes")); + return false; +} + void ReadContext::reset() { // Clear error state first error_ = Error(); @@ -747,6 +784,9 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; + // Root deserialization initializes the container budget before reading the + // header; direct ReadContext users start with the unlimited sentinel fields. + // Leave those fields untouched here so root guard cleanup stays store-light. 
if (meta_string_table_active_) { meta_string_table_.reset(); meta_string_table_active_ = false; diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 6af99c4ccc..5d2bbc3c60 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -32,6 +32,7 @@ #include "fory/util/result.h" #include +#include #include #include @@ -504,6 +505,98 @@ class ReadContext { } } + FORY_ALWAYS_INLINE bool init_container_budget_known(size_t root_bytes) { + size_t limit = 0; + if (config_->max_container_memory_bytes > 0) { + const uint64_t configured = + static_cast(config_->max_container_memory_bytes); + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + configured > + static_cast(std::numeric_limits::max()))) { + return set_container_memory_error( + "max_container_memory_bytes does not fit size_t"); + } + } + limit = static_cast(configured); + } else { + constexpr size_t max_root_bytes = (std::numeric_limits::max() - + kKnownContainerBudgetSlackBytes) / + kKnownContainerBudgetMultiplier; + if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { + return set_container_memory_error( + "root input size overflows automatic container memory budget"); + } + limit = root_bytes * kKnownContainerBudgetMultiplier + + kKnownContainerBudgetSlackBytes; + } + container_memory_limit_bytes_ = limit; + remaining_container_memory_bytes_ = limit; + return true; + } + + FORY_ALWAYS_INLINE bool init_container_budget_unknown() { + size_t limit = 0; + if (config_->max_container_memory_bytes > 0) { + const uint64_t configured = + static_cast(config_->max_container_memory_bytes); + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + configured > + static_cast(std::numeric_limits::max()))) { + return set_container_memory_error( + "max_container_memory_bytes does not fit size_t"); + } + } + limit = static_cast(configured); + } else { + limit = kUnknownContainerBudgetBytes; + } + 
container_memory_limit_bytes_ = limit; + remaining_container_memory_bytes_ = limit; + return true; + } + + FORY_ALWAYS_INLINE bool reserve_container_memory(size_t bytes) { + const size_t remaining = remaining_container_memory_bytes_; + if (FORY_PREDICT_FALSE(bytes > remaining)) { + return set_container_memory_exceeded(bytes, remaining); + } + remaining_container_memory_bytes_ = remaining - bytes; + return true; + } + + FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes) { + if (length == 0) { + return reserve_container_memory(fixed_bytes); + } + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if (FORY_PREDICT_TRUE(elem_bytes <= + (std::numeric_limits::max() - fixed_bytes) / + kMaxLength)) { + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); + } + return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); + } + + template + FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length) { + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if constexpr (elem_bytes <= + (std::numeric_limits::max() - fixed_bytes) / + kMaxLength) { + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); + } else { + return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); + } + } + // =========================================================================== // Read methods with Error& parameter // All methods accept Error& as parameter for reduced overhead. 
@@ -659,9 +752,22 @@ class ReadContext { inline const Config &config() const { return *config_; } private: + static constexpr size_t kKnownContainerBudgetMultiplier = 8; + static constexpr size_t kKnownContainerBudgetSlackBytes = 64 * 1024; + static constexpr size_t kUnknownContainerBudgetBytes = + 128ULL * 1024ULL * 1024ULL; + FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); + FORY_NOINLINE bool reserve_counted_container_checked(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes); + FORY_NOINLINE bool set_container_memory_error(const std::string &message); + FORY_NOINLINE bool set_container_memory_overflow(uint32_t length, + size_t elem_bytes); + FORY_NOINLINE bool set_container_memory_exceeded(size_t bytes, + size_t remaining); // Error state - accumulated during deserialization, checked at the end Error error_; @@ -671,6 +777,8 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; + size_t container_memory_limit_bytes_ = std::numeric_limits::max(); + size_t remaining_container_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) // Persistent cache storage for TypeInfo objects keyed by meta header. diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 36ef992d17..6d26c3bfa7 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -109,6 +109,16 @@ class ForyBuilder { return *this; } + /// Set maximum estimated container-owned memory for one root deserialization. + /// + /// Use `-1` for automatic limits. Positive values are explicit byte limits. 
+ ForyBuilder &max_container_memory_bytes(int64_t max_bytes) { + FORY_CHECK(max_bytes == -1 || max_bytes > 0) + << "max_container_memory_bytes must be positive or -1 for auto"; + config_.max_container_memory_bytes = max_bytes; + return *this; + } + /// Set maximum accepted field count in one received struct TypeMeta. ForyBuilder &max_type_fields(uint32_t max_fields) { FORY_CHECK(max_fields > 0) << "max_type_fields must be positive"; @@ -673,19 +683,7 @@ class Fory : public BaseFory { Buffer buffer(const_cast(data), static_cast(size), false); - - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from a byte vector. @@ -711,18 +709,7 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from an input stream. @@ -745,7 +732,10 @@ class Fory : public BaseFory { }; StreamShrinkGuard shrink_guard{&input_stream}; Buffer &buffer = input_stream.get_buffer(); - return deserialize(buffer); + if (FORY_PREDICT_FALSE(!finalized_)) { + ensure_finalized(); + } + return deserialize_buffer(buffer); } /// Deserialize an object from StdInputStream. 
@@ -883,6 +873,34 @@ class Fory : public BaseFory { return result; } + template + FORY_ALWAYS_INLINE Result deserialize_buffer(Buffer &buffer) { + const bool budget_ok = + unknown_root + ? read_ctx_->init_container_budget_unknown() + : read_ctx_->init_container_budget_known(buffer.remaining_size()); + if (FORY_PREDICT_FALSE(!budget_ok)) { + Error error = read_ctx_->take_error(); + read_ctx_->reset(); + return Unexpected(std::move(error)); + } + + Error header_error; + const uint8_t header = buffer.read_uint8(header_error); + if (FORY_PREDICT_FALSE(!header_error.ok())) { + read_ctx_->reset(); + return Unexpected(std::move(header_error)); + } + if (FORY_PREDICT_FALSE(header != precomputed_header_)) { + read_ctx_->reset(); + return Unexpected(invalid_root_header(header)); + } + + read_ctx_->attach(buffer); + ReadContextGuard guard(*read_ctx_); + return deserialize_impl(buffer); + } + template Result cached_write_root_type_info() { constexpr uint64_t ctid = type_index(); diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 830e5fbae5..a7a3bc615d 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -21,6 +21,7 @@ #include "fory/serialization/serializer.h" #include +#include #include #include #include @@ -81,6 +82,9 @@ struct MapReserver inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { // Lazy error propagation may continue into later readers; do not let that @@ -88,6 +92,20 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + using Key = typename MapType::key_type; + using Value = typename MapType::mapped_type; + static_assert(sizeof(Key) <= std::numeric_limits::max() - + sizeof(Value) - kMapEntryBudgetBytes - + kMapReferenceBudgetBytes * 3, + "map entry memory estimate overflows"); + constexpr size_t fixed_bytes = sizeof(MapType); + constexpr size_t elem_bytes = 
sizeof(Key) + sizeof(Value) + + kMapEntryBudgetBytes + + kMapReferenceBudgetBytes * 3; + if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< + fixed_bytes, elem_bytes>(length)))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -95,6 +113,13 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { return true; } +template inline bool reserve_empty_map(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + return ctx.reserve_container_memory(sizeof(MapType)); +} + /// write chunk size at header offset inline void write_chunk_size(WriteContext &ctx, size_t header_offset, uint8_t size) { @@ -567,6 +592,9 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { @@ -699,6 +727,9 @@ template inline MapType read_map_data_slow(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 4da4f7751c..00acf71b2a 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -897,9 +897,9 @@ Container read_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -916,6 +916,9 @@ Container 
read_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } const RefMode elem_ref_mode = track_ref ? RefMode::Tracking : (has_null ? RefMode::NullOnly : RefMode::None); @@ -939,7 +942,13 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( using Elem = element_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; - if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -1051,6 +1060,9 @@ MapType read_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/union_serializer.h b/cpp/fory/serialization/union_serializer.h index d5247d431f..8a8bc99fe3 100644 --- a/cpp/fory/serialization/union_serializer.h +++ b/cpp/fory/serialization/union_serializer.h @@ -466,9 +466,9 @@ Container read_union_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -483,6 +483,9 @@ Container read_union_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } for (uint32_t i = 0; i < length; ++i) { if constexpr (ElemNode >= 0) { auto elem = read_union_configured_value( 
@@ -553,6 +556,9 @@ MapType read_union_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 8e051da478..50c4682515 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -1163,10 +1163,12 @@ private static void EmitReadCompatibleListArrayPayload( uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); if (codec.CarrierKind == CarrierKind.Array) { + sb.AppendLine($"{indent}context.ReserveArrayMemory<{elementTypeName}>({lengthVar});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else { + sb.AppendLine($"{indent}context.ReserveListMemory<{elementTypeName}>({lengthVar});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index c407153fd5..2e5a610d00 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -201,6 +201,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea int length = checked((int)context.Reader.ReadVarUInt32()); if (length == 0) { + context.ReserveListMemory(length); return []; } @@ -213,6 +214,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; + context.ReserveNonEmptyListMemory(length); context.Reader.CheckBound(length); List values = new(length); if (!sameType) @@ -522,6 +524,7 @@ 
public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveArrayMemory(values.Count); return values.ToArray(); } } @@ -554,7 +557,9 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { - return [.. CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return [.. values]; } } @@ -570,7 +575,9 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { - return [.. CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return [.. 
values]; } } @@ -586,7 +593,9 @@ public override void WriteData(WriteContext context, in ImmutableHashSet valu public override ImmutableHashSet ReadData(ReadContext context) { - return ImmutableHashSet.CreateRange(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return ImmutableHashSet.CreateRange(values); } } @@ -602,7 +611,9 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { - return new LinkedList(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return new LinkedList(values); } } @@ -619,6 +630,7 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); Queue queue = new(values.Count); for (int i = 0; i < values.Count; i++) { @@ -655,6 +667,7 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); Stack stack = new(values.Count); for (int i = 0; i < values.Count; i++) { diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 438039d2c8..1947bac29c 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -28,6 +28,7 @@ internal Config( bool compatible, bool checkStructVersion, int maxDepth, + long 
maxContainerMemoryBytes, int maxTypeFields, int maxTypeMetaBytes, int maxSchemaVersionsPerType, @@ -37,6 +38,12 @@ internal Config( { throw new ArgumentOutOfRangeException(nameof(maxDepth), "MaxDepth must be greater than 0."); } + if (maxContainerMemoryBytes != -1 && maxContainerMemoryBytes <= 0) + { + throw new ArgumentOutOfRangeException( + nameof(maxContainerMemoryBytes), + "MaxContainerMemoryBytes must be positive or -1 for auto."); + } if (maxTypeFields <= 0) { throw new ArgumentOutOfRangeException(nameof(maxTypeFields), "MaxTypeFields must be greater than 0."); @@ -58,6 +65,7 @@ internal Config( Compatible = compatible; CheckStructVersion = checkStructVersion; MaxDepth = maxDepth; + MaxContainerMemoryBytes = maxContainerMemoryBytes; MaxTypeFields = maxTypeFields; MaxTypeMetaBytes = maxTypeMetaBytes; MaxSchemaVersionsPerType = maxSchemaVersionsPerType; @@ -84,6 +92,11 @@ internal Config( /// public int MaxDepth { get; } + /// + /// Gets the maximum estimated container-owned memory accepted during one root deserialization. + /// + public long MaxContainerMemoryBytes { get; } + /// /// Gets the maximum accepted field count in one received struct TypeMeta. /// @@ -114,6 +127,7 @@ public sealed class ForyBuilder private bool? _compatible; private bool _checkStructVersion; private int _maxDepth = 20; + private long _maxContainerMemoryBytes = -1; private int _maxTypeFields = 512; private int _maxTypeMetaBytes = 4096; private int _maxSchemaVersionsPerType = 10; @@ -169,6 +183,23 @@ public ForyBuilder MaxDepth(int value) return this; } + /// + /// Sets the maximum estimated container-owned memory accepted during one root deserialization. + /// Use -1 for the automatic root-size-based limit, or a positive byte limit. 
+ /// + public ForyBuilder MaxContainerMemoryBytes(long value) + { + if (value != -1 && value <= 0) + { + throw new ArgumentOutOfRangeException( + nameof(value), + "MaxContainerMemoryBytes must be positive or -1 for auto."); + } + + _maxContainerMemoryBytes = value; + return this; + } + /// /// Sets the maximum accepted field count in one received struct TypeMeta. /// @@ -235,6 +266,7 @@ private Config BuildConfig() compatible: compatible, checkStructVersion: compatible ? false : _checkStructVersion, maxDepth: _maxDepth, + maxContainerMemoryBytes: _maxContainerMemoryBytes, maxTypeFields: _maxTypeFields, maxTypeMetaBytes: _maxTypeMetaBytes, maxSchemaVersionsPerType: _maxSchemaVersionsPerType, diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index 5aa49dfa75..bdcec3222a 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -214,9 +214,11 @@ public override TDictionary ReadData(ReadContext context) int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return CreateMap(0); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 9bbafd1775..edcfdb13b2 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -190,6 +190,7 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); + _readContext.InitContainerBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -210,6 +211,7 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); + _readContext.InitContainerBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 
0) { @@ -230,6 +232,7 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); + _readContext.InitContainerBudgetKnown(bytes.Length); T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index d6c8caab47..fe573cae02 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -537,9 +537,11 @@ public override NullableKeyDictionary ReadData(ReadContext context int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return new NullableKeyDictionary(); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index a136bd57bd..e753280f72 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -672,9 +672,11 @@ public static TMap ReadMap( int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return TMapOps.Create(0); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); TMap map = TMapOps.Create(totalLength); TypeId keyTypeId = TKeyCodec.WireTypeId; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index f83ac0e99e..31cc878714 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -15,11 +15,21 @@ // specific language governing permissions and limitations // under the License. 
+using System.ComponentModel; +using System.Runtime.CompilerServices; + namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; + internal const long KnownContainerBudgetSlackBytes = 64 * 1024; + internal const long UnknownContainerBudgetBytes = 128L * 1024 * 1024; + internal const int ContainerFixedBytes = 32; + internal const int ArrayHeaderBytes = 24; + internal const int ReferenceBytes = 4; + internal const int CollectionEntryOverheadBytes = 16; + internal const int MapEntryOverheadBytes = 24; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -40,6 +50,8 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; + private long _containerMemoryLimitBytes = long.MaxValue; + private long _remainingContainerMemoryBytes = long.MaxValue; public ReadContext( ByteReader reader, @@ -70,6 +82,134 @@ public ReadContext( internal RefReader RefReader { get; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int ElementBytes() => ContainerElementBytes.Value; + + private static class ContainerElementBytes + { + internal static readonly int Value = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + } + + private static class MapElementBytes + { + internal static readonly int Value = + ElementBytes() + ElementBytes() + MapEntryOverheadBytes + ReferenceBytes; + } + + /// + /// Reserves estimated list-owned memory for generated serializer code. + /// Configure instead of calling this directly. 
+ /// + [EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveListMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveNonEmptyListMemory(length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveNonEmptyListMemory(int length) + { + ReserveContainerMemory((long)(uint)length * ElementBytes() + ContainerFixedBytes + ArrayHeaderBytes); + } + + /// + /// Reserves estimated array-owned memory for generated serializer code. + /// Configure instead of calling this directly. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveArrayMemory(int length) + { + ReserveCountedContainerMemory( + length, + ArrayHeaderBytes, + ElementBytes()); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveLinkedCollectionMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveContainerMemory( + (long)(uint)length * (ElementBytes() + CollectionEntryOverheadBytes + ReferenceBytes * 2) + + ContainerFixedBytes); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveMapMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveNonEmptyMapMemory(length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveNonEmptyMapMemory(int length) + { + ReserveContainerMemory( + (long)(uint)length * MapElementBytes.Value + ContainerFixedBytes + ArrayHeaderBytes * 2); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void InitContainerBudgetKnown(int rootBytes) + { + long limit = _config.MaxContainerMemoryBytes; + if (limit < 0) + { + limit = (long)rootBytes * 8 + KnownContainerBudgetSlackBytes; + } + + _containerMemoryLimitBytes = limit; + 
_remainingContainerMemoryBytes = limit; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveContainerMemory(long bytes) + { + long remaining = _remainingContainerMemoryBytes; + if ((ulong)bytes > (ulong)remaining) + { + ThrowContainerBudgetExceeded(bytes, remaining, _containerMemoryLimitBytes); + } + + _remainingContainerMemoryBytes = remaining - bytes; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveCountedContainerMemory(int count, int fixedBytes, int elementBytes) + { + ReserveContainerMemory((long)(uint)count * elementBytes + fixedBytes); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowContainerBudgetOverflow() + { + throw new InvalidDataException("container memory estimate overflows"); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowContainerBudgetExceeded(long bytes, long remaining, long limit) + { + throw new InvalidDataException( + $"estimated container memory request {bytes} bytes exceeds MaxContainerMemoryBytes remaining budget {remaining} bytes out of effective limit {limit} bytes"); + } + internal void ResetFor(ByteReader reader) { Reader = reader; diff --git a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs new file mode 100644 index 0000000000..dabaa03b28 --- /dev/null +++ b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Buffers; +using Apache.Fory; +using ForyRuntime = Apache.Fory.Fory; + +namespace Apache.Fory.Tests; + +[ForyStruct] +public sealed class BudgetItem +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyStruct] +public sealed class BudgetSiblings +{ + public List Left { get; set; } = []; + public List Right { get; set; } = []; +} + +[ForyStruct] +public sealed class BudgetArrayHolder +{ + public BudgetItem[] Values { get; set; } = []; +} + +public sealed class ContainerMemoryBudgetTests +{ + private static ForyRuntime NewFory(long maxContainerMemoryBytes = -1) + { + return ForyRuntime.Builder() + .Compatible(false) + .TrackRef(false) + .MaxContainerMemoryBytes(maxContainerMemoryBytes) + .Build() + .Register(1001) + .Register(1002) + .Register(1003); + } + + private static byte[] Serialize(T value) + { + return NewFory().Serialize(value); + } + + private static long ListBudget(int count) + { + return count == 0 + ? ReadContext.ContainerFixedBytes + : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes + + (long)count * ReadContext.ElementBytes(); + } + + private static long ArrayBudget(int count) + { + return ReadContext.ArrayHeaderBytes + (long)count * ReadContext.ElementBytes(); + } + + private static long MapBudget(int count) + { + return count == 0 + ? 
ReadContext.ContainerFixedBytes + : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes * 2 + + (long)count * (ReadContext.ElementBytes() + ReadContext.ElementBytes() + + ReadContext.MapEntryOverheadBytes + ReadContext.ReferenceBytes); + } + + [Fact] + public void KnownLengthAutoBudgetRejectsLargeNestedEmpties() + { + const int count = 3000; + List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); + byte[] bytes = Serialize(value); + long autoLimit = bytes.LongLength * 8 + ReadContext.KnownContainerBudgetSlackBytes; + long required = ListBudget>(count) + count * ListBudget(0); + Assert.True(required > autoLimit); + + Assert.Throws(() => NewFory().Deserialize>>(bytes)); + + List> result = NewFory(required).Deserialize>>(bytes); + Assert.Equal(count, result.Count); + } + + [Fact] + public void ReadOnlySequenceUsesKnownLengthAutoBudget() + { + const int count = 3000; + List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); + byte[] bytes = Serialize(value); + ReadOnlySequence sequence = new(bytes); + + Assert.Throws(() => NewFory().Deserialize>>(ref sequence)); + } + + [Fact] + public void ExplicitConfigOverridesAutoBudget() + { + List value = Enumerable.Range(0, 8).Select(i => new BudgetItem { Id = i }).ToList(); + byte[] bytes = Serialize(value); + long required = ListBudget(value.Count); + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + List result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value.Count, result.Count); + } + + [Fact] + public void SiblingContainersShareOneBudget() + { + BudgetSiblings value = new() + { + Left = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + Right = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + }; + byte[] bytes = Serialize(value); + long oneList = ListBudget(16); + + Assert.Throws(() => NewFory(oneList).Deserialize(bytes)); + BudgetSiblings result = NewFory(oneList * 2).Deserialize(bytes); + 
Assert.Equal(16, result.Left.Count); + Assert.Equal(16, result.Right.Count); + } + + [Fact] + public void MapBudgetIsCharged() + { + Dictionary value = new() { ["a"] = 1, ["b"] = 2, ["c"] = 3 }; + byte[] bytes = Serialize(value); + long required = MapBudget(value.Count); + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + Dictionary result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value, result); + } + + [Fact] + public void ReferenceArrayAndInlineValueListAreCharged() + { + BudgetArrayHolder holder = new() + { + Values = Enumerable.Range(0, 4).Select(i => new BudgetItem { Id = i }).ToArray(), + }; + byte[] holderBytes = Serialize(holder); + long holderRequired = ListBudget(4) + ArrayBudget(4); + Assert.Throws(() => NewFory(holderRequired - 1).Deserialize(holderBytes)); + Assert.Equal(4, NewFory(holderRequired).Deserialize(holderBytes).Values.Length); + + List ints = [1, 2, 3, 4]; + byte[] intBytes = Serialize(ints); + long listRequired = ListBudget(ints.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize>(intBytes)); + Assert.Equal(ints, NewFory(listRequired).Deserialize>(intBytes)); + } + + [Fact] + public void DenseStringBinaryAndPrimitiveArraysAreSkipped() + { + Assert.Equal("budget", NewFory(1).Deserialize(Serialize("budget"))); + Assert.Equal(new byte[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new byte[] { 1, 2, 3 }))); + Assert.Equal(new[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new[] { 1, 2, 3 }))); + } + + [Fact] + public void ByteAvailabilityCheckStillRejectsLargeLength() + { + byte[] bytes = [64, 0]; + ReadContext context = new(new ByteReader(bytes), new TypeResolver(), NewFory().Config); + + Assert.Throws(() => new ListSerializer().ReadData(context)); + } +} diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index d5d248cd36..6f529ecc3c 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -28,6 +28,7 @@ 
final class Config { static const int defaultMaxTypeMetaBytes = 4096; static const int defaultMaxSchemaVersionsPerType = 10; static const int defaultMaxAverageSchemaVersionsPerType = 3; + static const int defaultMaxContainerMemoryBytes = -1; /// Enables compatible struct encoding and decoding. /// @@ -56,6 +57,11 @@ final class Config { /// types. final int maxAverageSchemaVersionsPerType; + /// Maximum estimated container-owned memory per root deserialization. + /// + /// `-1` means auto. Positive values are explicit byte limits. + final int maxContainerMemoryBytes; + /// Creates an immutable configuration object. /// /// Invalid numeric limits fail fast. When [compatible] is `true`, @@ -69,6 +75,7 @@ final class Config { this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, this.maxAverageSchemaVersionsPerType = defaultMaxAverageSchemaVersionsPerType, + this.maxContainerMemoryBytes = defaultMaxContainerMemoryBytes, }) : checkStructVersion = compatible ? false : checkStructVersion, assert(maxDepth > 0, 'maxDepth must be positive'), assert(maxTypeFields > 0, 'maxTypeFields must be positive'), @@ -80,5 +87,9 @@ final class Config { assert( maxAverageSchemaVersionsPerType > 0, 'maxAverageSchemaVersionsPerType must be positive', + ), + assert( + maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, + 'maxContainerMemoryBytes must be -1 or positive', ); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index faa8191aba..1acf28c0d6 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -45,6 +45,15 @@ import 'package:fory/src/types/uint64.dart'; /// deserialization operation. Application code normally interacts with [Fory] /// instead of preparing contexts directly. 
final class ReadContext { + static const int _knownRootBudgetMultiplier = 8; + static const int _knownRootBudgetSlackBytes = 64 * 1024; + static const int _collectionObjectBytes = 24; + static const int _mapObjectBytes = 48; + static const int _arrayHeaderBytes = 16; + static const int _mapEntryBytes = 32; + static const int _referenceBytes = 4; + static const int _maxSafeBudgetBytes = 9007199254740991; + /// Effective runtime configuration for the active operation. final Config config; final TypeResolver _typeResolver; @@ -54,6 +63,8 @@ final class ReadContext { late Buffer _buffer; final List _sharedTypes = []; int _depth = 0; + int _effectiveContainerMemoryBytes = 0; + int _remainingContainerMemoryBytes = 0; @internal ReadContext( @@ -64,8 +75,20 @@ final class ReadContext { ); @internal + @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; + final configured = config.maxContainerMemoryBytes; + final limit = + configured > 0 + ? configured + : buffer.readableBytes * _knownRootBudgetMultiplier + + _knownRootBudgetSlackBytes; + if (limit > _maxSafeBudgetBytes) { + _throwContainerMemoryOverflow(limit); + } + _effectiveContainerMemoryBytes = limit; + _remainingContainerMemoryBytes = limit; } @internal @@ -74,6 +97,8 @@ final class ReadContext { _refReader.reset(); _metaStringReader.reset(); _depth = 0; + _effectiveContainerMemoryBytes = 0; + _remainingContainerMemoryBytes = 0; } /// The active input buffer for the current operation. 
@@ -85,6 +110,76 @@ final class ReadContext { @internal RefReader get refReader => _refReader; + @internal + int get effectiveContainerMemoryBytes => _effectiveContainerMemoryBytes; + + @internal + int get remainingContainerMemoryBytes => _remainingContainerMemoryBytes; + + @internal + @pragma('vm:prefer-inline') + void reserveCollectionMemory(int numElements) { + final bytes = _collectionObjectBytes + numElements * _referenceBytes; + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + @pragma('vm:prefer-inline') + void reserveMapMemory(int numElements) { + final bytes = + _mapObjectBytes + + numElements * + (_referenceBytes * 2 + _mapEntryBytes + _referenceBytes * 3); + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + @pragma('vm:prefer-inline') + void reserveTypedArrayMemory(int numElements, int elementBytes) { + final bytes = _arrayHeaderBytes + numElements * elementBytes; + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + void reserveContainerMemory(int bytes) { + if (bytes < 0 || bytes > _maxSafeBudgetBytes) { + _throwContainerMemoryOverflow(bytes); + } + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @pragma('vm:never-inline') + Never _throwContainerMemoryOverflow(int bytes) { + throw StateError( + 'maxContainerMemoryBytes overflow: requested $bytes estimated container bytes.', + ); + } + + @pragma('vm:never-inline') + Never _throwContainerMemoryExceeded(int bytes) { + throw StateError( + 'maxContainerMemoryBytes 
exceeded: requested $bytes estimated container bytes, ' + '$_remainingContainerMemoryBytes remaining, effective limit ' + '$_effectiveContainerMemoryBytes.', + ); + } + @internal @pragma('vm:prefer-inline') TypeInfo readTypeMetaValue([TypeInfo? expectedNamedType]) => diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index adc6091a8d..48a5f9b133 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -62,7 +62,16 @@ final class Fory { int maxSchemaVersionsPerType = Config.defaultMaxSchemaVersionsPerType, int maxAverageSchemaVersionsPerType = Config.defaultMaxAverageSchemaVersionsPerType, + int maxContainerMemoryBytes = Config.defaultMaxContainerMemoryBytes, }) { + if (maxContainerMemoryBytes != Config.defaultMaxContainerMemoryBytes && + maxContainerMemoryBytes <= 0) { + throw ArgumentError.value( + maxContainerMemoryBytes, + 'maxContainerMemoryBytes', + 'must be -1 or positive', + ); + } final config = Config( compatible: compatible, checkStructVersion: checkStructVersion, @@ -71,6 +80,7 @@ final class Fory { maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType, + maxContainerMemoryBytes: maxContainerMemoryBytes, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 4e2a8050c0..b80839f2d6 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -270,10 +270,10 @@ final class ListSerializer extends Serializer { } final declaredTypeInfo = elementFieldType == null || - elementFieldType.isDynamic || - elementFieldType.typeId == TypeIds.unknown - ? 
null - : context.typeResolver.resolveFieldType(elementFieldType); + elementFieldType.isDynamic || + elementFieldType.typeId == TypeIds.unknown + ? null + : context.typeResolver.resolveFieldType(elementFieldType); final usesDeclaredType = declaredTypeInfo != null && usesDeclaredTypeInfo( @@ -296,8 +296,9 @@ final class ListSerializer extends Serializer { sameType: analysis.sameType, ); context.buffer.writeUint8(header); - final sameTypeInfo = - !usesDeclaredType && analysis.sameType ? analysis.sameTypeInfo : null; + final sameTypeInfo = !usesDeclaredType && analysis.sameType + ? analysis.sameTypeInfo + : null; if (!usesDeclaredType && sameTypeInfo != null && analysis.firstNonNull != null) { @@ -378,13 +379,13 @@ final class SetSerializer extends Serializer { FieldType? elementFieldType, { bool hasPreservedRef = false, }) { - return Set.of( - ListSerializer.readPayload( - context, - elementFieldType, - hasPreservedRef: hasPreservedRef, - ), + final values = ListSerializer.readPayload( + context, + elementFieldType, + hasPreservedRef: hasPreservedRef, ); + context.reserveCollectionMemory(values.length); + return Set.of(values); } } @@ -401,8 +402,9 @@ Object? readCompatibleMatchedCollectionArrayField( final remoteType = remoteField.fieldType; if (isCompatibleArrayType(localType.typeId) && remoteType.typeId == TypeIds.list) { - final elementType = - remoteType.arguments.isEmpty ? null : remoteType.arguments.single; + final elementType = remoteType.arguments.isEmpty + ? null + : remoteType.arguments.single; if (elementType == null || _arrayElementTypeId(localType.typeId) != _compatibleArrayElementTypeId(elementType.typeId)) { @@ -419,8 +421,9 @@ Object? readCompatibleMatchedCollectionArrayField( } if (localType.typeId == TypeIds.list && isCompatibleArrayType(remoteType.typeId)) { - final localElementType = - localType.arguments.isEmpty ? null : localType.arguments.single; + final localElementType = localType.arguments.isEmpty + ? 
null + : localType.arguments.single; if (localElementType == null || _arrayElementTypeId(remoteType.typeId) != _compatibleArrayElementTypeId(localElementType.typeId)) { @@ -429,7 +432,7 @@ Object? readCompatibleMatchedCollectionArrayField( ); } final raw = readCompatibleField(context, remoteField); - return _arrayToListValue(raw); + return _arrayToListValue(context, raw); } return readFieldValue(context, localField); } @@ -490,8 +493,9 @@ bool _listElementMatchesArray( int arrayTypeId, { required bool requireUnframedElement, }) { - final elementType = - listType.arguments.isEmpty ? null : listType.arguments.single; + final elementType = listType.arguments.isEmpty + ? null + : listType.arguments.single; // Nullable element schema is allowed for list -> array; actual // null payload elements fail in the dense-array reader. Ref-tracked // element framing is rejected here because this path stays primitive-only. @@ -508,6 +512,7 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); + context.reserveTypedArrayMemory(size, _arrayElementBytes(arrayTypeId)); if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -570,6 +575,21 @@ int _compatibleArrayElementTypeId(int typeId) { }; } +int _arrayElementBytes(int arrayTypeId) { + return switch (arrayTypeId) { + TypeIds.boolArray || TypeIds.int8Array || TypeIds.uint8Array => 1, + TypeIds.int16Array || + TypeIds.uint16Array || + TypeIds.float16Array || + TypeIds.bfloat16Array => 2, + TypeIds.int32Array || TypeIds.uint32Array || TypeIds.float32Array => 4, + TypeIds.int64Array || TypeIds.uint64Array || TypeIds.float64Array => 8, + _ => throw StateError( + 'Unsupported compatible array field type $arrayTypeId.', + ), + }; +} + Object _newArrayValue(int arrayTypeId, int length) { return switch (arrayTypeId) { TypeIds.boolArray => BoolList(length), @@ -585,8 +605,9 @@ Object _newArrayValue(int arrayTypeId, int length) { TypeIds.bfloat16Array => Bfloat16List(length), 
TypeIds.float32Array => Float32List(length), TypeIds.float64Array => Float64List(length), - _ => - throw StateError('Unsupported compatible array field type $arrayTypeId.'), + _ => throw StateError( + 'Unsupported compatible array field type $arrayTypeId.', + ), }; } @@ -601,8 +622,9 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.int32Array: (target as Int32List)[index] = value as int; case TypeIds.int64Array: - (target as Int64List)[index] = - value is int ? Int64(value) : value as Int64; + (target as Int64List)[index] = value is int + ? Int64(value) + : value as Int64; case TypeIds.uint8Array: (target as Uint8List)[index] = value as int; case TypeIds.uint16Array: @@ -610,8 +632,9 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.uint32Array: (target as Uint32List)[index] = value as int; case TypeIds.uint64Array: - (target as Uint64List)[index] = - value is int ? Uint64(value) : value as Uint64; + (target as Uint64List)[index] = value is int + ? Uint64(value) + : value as Uint64; case TypeIds.float16Array: (target as Float16List)[index] = value as double; case TypeIds.bfloat16Array: @@ -625,11 +648,13 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { } } -Object _arrayToListValue(Object? raw) { +Object _arrayToListValue(ReadContext context, Object? raw) { if (raw is BoolList) { + context.reserveCollectionMemory(raw.length); return raw.toList(); } if (raw is Iterable) { + context.reserveCollectionMemory(raw.length); return raw.toList(); } throw StateError('Expected compatible array payload.'); @@ -650,29 +675,29 @@ List readTypedListPayload( } final directTypeInfo = state.declaredTypeInfo ?? state.sameTypeInfo; if (directTypeInfo != null && !state.trackRef && !state.hasNull) { - final directFieldType = - state.declaredTypeInfo != null ? state.elementFieldType : null; + final directFieldType = state.declaredTypeInfo != null + ? 
state.elementFieldType + : null; if (directTypeInfo.type == T && directTypeInfo.kind == RegistrationKind.struct) { final structSerializer = directTypeInfo.structSerializer!; context.buffer.checkReadableBytes(state.size); - final result = - directTypeInfo.remoteTypeDef == null - ? List.generate( - state.size, - (_) => structSerializer.readValue(context, directTypeInfo) as T, - growable: false, - ) - : List.generate( - state.size, - (_) => - structSerializer.readGeneratedCompatibleValue( - context, - directTypeInfo, - ) - as T, - growable: false, - ); + final result = directTypeInfo.remoteTypeDef == null + ? List.generate( + state.size, + (_) => structSerializer.readValue(context, directTypeInfo) as T, + growable: false, + ) + : List.generate( + state.size, + (_) => + structSerializer.readGeneratedCompatibleValue( + context, + directTypeInfo, + ) + as T, + growable: false, + ); if (state.tracksDepth) { context.decreaseDepth(); } @@ -719,7 +744,9 @@ Set readTypedSetPayload( FieldType? elementFieldType, T Function(Object? value) convert, ) { - return Set.of(readTypedListPayload(context, elementFieldType, convert)); + final values = readTypedListPayload(context, elementFieldType, convert); + context.reserveCollectionMemory(values.length); + return Set.of(values); } void writeTypedListPayload( @@ -910,6 +937,7 @@ _PreparedListRead _prepareListRead( FieldType? elementFieldType, ) { final size = context.buffer.readVarUint32(); + context.reserveCollectionMemory(size); if (size == 0) { return _PreparedListRead( size: 0, @@ -936,15 +964,13 @@ _PreparedListRead _prepareListRead( elementFieldType != null && (usesDeclaredType || (sameType && TypeIds.isUserType(elementFieldType.typeId))); - final expectedElementTypeInfo = - needsExpectedElementType - ? context.typeResolver.tryResolveFieldType(elementFieldType) - : null; + final expectedElementTypeInfo = needsExpectedElementType + ? 
context.typeResolver.tryResolveFieldType(elementFieldType) + : null; final declaredTypeInfo = usesDeclaredType ? expectedElementTypeInfo : null; - final sameTypeInfo = - (!usesDeclaredType && sameType) - ? context.readTypeMetaValue(expectedElementTypeInfo) - : null; + final sameTypeInfo = (!usesDeclaredType && sameType) + ? context.readTypeMetaValue(expectedElementTypeInfo) + : null; final tracksDepth = (declaredTypeInfo != null && tracksNestedPayloadDepth(declaredTypeInfo)) || diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index 051454c3d6..0391699b23 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -56,14 +56,13 @@ final class MapSerializer extends Serializer { required bool trackRef, }) { context.buffer.writeVarUint32(values.length); - final declaredKeyTypeInfo = - keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final keyDeclared = declaredKeyTypeInfo != null && usesDeclaredTypeInfo( @@ -106,17 +105,17 @@ final class MapSerializer extends Serializer { (keyDeclared ? declaredKeyTypeInfo.supportsRef : (key == null || - context.typeResolver - .resolveValue(key as Object) - .supportsRef)); + context.typeResolver + .resolveValue(key as Object) + .supportsRef)); final valueTrackRef = valueRequestedRef && (valueDeclared ? 
declaredValueTypeInfo.supportsRef : (value == null || - context.typeResolver - .resolveValue(value as Object) - .supportsRef)); + context.typeResolver + .resolveValue(value as Object) + .supportsRef)); _writeNullChunk( context, key, @@ -132,14 +131,12 @@ final class MapSerializer extends Serializer { ); continue; } - final chunkKeyTypeInfo = - keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(key as Object); - final chunkValueTypeInfo = - valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(value as Object); + final chunkKeyTypeInfo = keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(key as Object); + final chunkValueTypeInfo = valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(value as Object); final chunkKeyTrackRef = keyRequestedRef && chunkKeyTypeInfo.supportsRef; final chunkValueTrackRef = valueRequestedRef && chunkValueTypeInfo.supportsRef; @@ -189,14 +186,12 @@ final class MapSerializer extends Serializer { pendingEntry = nextEntry; break; } - final nextKeyTypeInfo = - keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(nextKey as Object); - final nextValueTypeInfo = - valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(nextValue as Object); + final nextKeyTypeInfo = keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(nextKey as Object); + final nextValueTypeInfo = valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(nextValue as Object); final nextKeyTrackRef = keyRequestedRef && nextKeyTypeInfo.supportsRef; final nextValueTrackRef = valueRequestedRef && nextValueTypeInfo.supportsRef; @@ -257,14 +252,15 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); - final declaredKeyTypeInfo = - keyFieldType == null || keyFieldType.isDynamic - ? 
null - : context.typeResolver.resolveFieldType(keyFieldType); + context.reserveMapMemory(remaining); + context.buffer.checkReadableBytes(remaining); + final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final result = {}; if (hasPreservedRef) { context.reference(result); @@ -312,34 +308,32 @@ Map readTypedMapPayload( context.increaseDepth(); } for (var index = 0; index < chunkSize; index += 1) { - final key = - keyDeclared - ? _readDeclaredMapValue( - context, - keyFieldType!, - declaredKeyTypeInfo!, - trackRef: keyTrackRef, - ) - : _readResolvedMapValue( - context, - keyTypeInfo!, - null, - trackRef: keyTrackRef, - ); - final value = - valueDeclared - ? _readDeclaredMapValue( - context, - valueFieldType!, - declaredValueTypeInfo!, - trackRef: valueTrackRef, - ) - : _readResolvedMapValue( - context, - valueTypeInfo!, - null, - trackRef: valueTrackRef, - ); + final key = keyDeclared + ? _readDeclaredMapValue( + context, + keyFieldType!, + declaredKeyTypeInfo!, + trackRef: keyTrackRef, + ) + : _readResolvedMapValue( + context, + keyTypeInfo!, + null, + trackRef: keyTrackRef, + ); + final value = valueDeclared + ? 
_readDeclaredMapValue( + context, + valueFieldType!, + declaredValueTypeInfo!, + trackRef: valueTrackRef, + ) + : _readResolvedMapValue( + context, + valueTypeInfo!, + null, + trackRef: valueTrackRef, + ); result[convertKey(key)] = convertValue(value); } if (tracksDepth) { diff --git a/dart/packages/fory/test/container_memory_budget_test.dart b/dart/packages/fory/test/container_memory_budget_test.dart new file mode 100644 index 0000000000..61d6970300 --- /dev/null +++ b/dart/packages/fory/test/container_memory_budget_test.dart @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import 'dart:typed_data'; + +import 'package:fory/fory.dart'; +import 'package:fory/src/context/meta_string_reader.dart'; +import 'package:fory/src/context/ref_reader.dart'; +import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/collection_serializers.dart'; +import 'package:fory/src/serializer/map_serializers.dart'; +import 'package:test/test.dart'; + +part 'container_memory_budget_test.fory.dart'; + +const Matcher _throwsContainerBudget = ThrowsContainerBudget(); + +@ForyStruct() +class BudgetGeneratedEnvelope { + BudgetGeneratedEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List ids = []; + + @SetField(element: StringType()) + Set tags = {}; + + @MapField( + key: StringType(), + value: Int32Type(encoding: Encoding.fixed), + ) + Map counts = {}; +} + +@ForyStruct() +class BudgetCompatibleListEnvelope { + BudgetCompatibleListEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class BudgetCompatibleArrayEnvelope { + BudgetCompatibleArrayEnvelope(); + + @ArrayField(element: Int32Type()) + Int32List values = Int32List(0); +} + +final class ThrowsContainerBudget extends Matcher { + const ThrowsContainerBudget(); + + @override + Description describe(Description description) { + return description.add('throws a maxContainerMemoryBytes StateError'); + } + + @override + bool matches(Object? item, Map matchState) { + if (item is! 
Function) { + return false; + } + try { + item(); + } on StateError catch (error) { + return error.message.contains('maxContainerMemoryBytes'); + } + return false; + } +} + +void _registerGenerated(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetGeneratedEnvelope, + name: 'test.BudgetGeneratedEnvelope', + ); +} + +void _registerCompatibleList(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleListEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +void _registerCompatibleArray(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleArrayEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +ReadContext _readContext(Buffer buffer, {int maxContainerMemoryBytes = -1}) { + final config = Config(maxContainerMemoryBytes: maxContainerMemoryBytes); + final resolver = TypeResolver(config); + return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) + ..prepare(buffer); +} + +Uint8List _serialize(Object? value) => Fory().serialize(value); + +Object? _readWithBudget(Object? 
value, int budget) { + return Fory( + maxContainerMemoryBytes: budget, + ).deserialize(_serialize(value)); +} + +void main() { + group('container memory budget', () { + test('known length auto derives from input bytes', () { + final buffer = Buffer.wrap(Uint8List(17)); + final context = _readContext(buffer); + + expect(context.effectiveContainerMemoryBytes, equals(17 * 8 + 64 * 1024)); + expect( + () => context.reserveContainerMemory(17 * 8 + 64 * 1024), + returnsNormally, + ); + expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); + }); + + test('explicit config overrides auto', () { + final buffer = Buffer.wrap(Uint8List(4096)); + final context = _readContext(buffer, maxContainerMemoryBytes: 31); + + expect(context.effectiveContainerMemoryBytes, equals(31)); + expect(() => context.reserveContainerMemory(31), returnsNormally); + expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); + expect(() => Fory(maxContainerMemoryBytes: 0), throwsArgumentError); + expect(() => Fory(maxContainerMemoryBytes: -2), throwsArgumentError); + }); + + test('charges nested empty containers', () { + final value = [[]]; + + expect(() => _readWithBudget(value, 51), _throwsContainerBudget); + expect(_readWithBudget(value, 52), equals(value)); + }); + + test('charges sibling containers cumulatively', () { + final value = [[], [], []]; + + expect(() => _readWithBudget(value, 107), _throwsContainerBudget); + expect(_readWithBudget(value, 108), equals(value)); + }); + + test('charges map table and entries', () { + final value = {'a': 1}; + + expect(() => _readWithBudget(value, 99), _throwsContainerBudget); + expect(_readWithBudget(value, 100), equals(value)); + }); + + test('charges generated list set and map reads', () { + final writer = Fory(); + _registerGenerated(writer); + final bytes = writer.serialize( + BudgetGeneratedEnvelope() + ..ids = [1] + ..tags = {'x'} + ..counts = {'one': 1}, + ); + + final failingReader = Fory(maxContainerMemoryBytes: 
183); + _registerGenerated(failingReader); + expect( + () => failingReader.deserialize(bytes), + _throwsContainerBudget, + ); + + final passingReader = Fory(maxContainerMemoryBytes: 184); + _registerGenerated(passingReader); + final roundTrip = passingReader.deserialize( + bytes, + ); + expect(roundTrip.ids, equals([1])); + expect(roundTrip.tags, equals({'x'})); + expect(roundTrip.counts, equals({'one': 1})); + }); + + test('charges compatible list array materialization', () { + final listWriter = Fory(); + _registerCompatibleList(listWriter); + final listBytes = listWriter.serialize( + BudgetCompatibleListEnvelope()..values = [1, 2, 3], + ); + + final arrayFail = Fory(maxContainerMemoryBytes: 27); + _registerCompatibleArray(arrayFail); + expect( + () => arrayFail.deserialize(listBytes), + _throwsContainerBudget, + ); + + final arrayPass = Fory(maxContainerMemoryBytes: 28); + _registerCompatibleArray(arrayPass); + expect( + arrayPass + .deserialize(listBytes) + .values + .toList(), + equals([1, 2, 3]), + ); + + final arrayWriter = Fory(); + _registerCompatibleArray(arrayWriter); + final arrayBytes = arrayWriter.serialize( + BudgetCompatibleArrayEnvelope() + ..values = Int32List.fromList([1, 2, 3]), + ); + + final listFail = Fory(maxContainerMemoryBytes: 35); + _registerCompatibleList(listFail); + expect( + () => listFail.deserialize(arrayBytes), + _throwsContainerBudget, + ); + + final listPass = Fory(maxContainerMemoryBytes: 36); + _registerCompatibleList(listPass); + expect( + listPass.deserialize(arrayBytes).values, + equals([1, 2, 3]), + ); + }); + + test('skips strings binary and dense typed arrays', () { + final fory = Fory(maxContainerMemoryBytes: 1); + final text = List.filled(128, 'x').join(); + + expect(fory.deserialize(Fory().serialize(text)), hasLength(128)); + expect( + fory.deserialize(Fory().serialize(Uint8List(128))).length, + equals(128), + ); + expect( + fory.deserialize(Fory().serialize(Int32List(32))).length, + equals(32), + ); + }); + + 
test('keeps byte availability checks before allocation', () { + final listBuffer = Buffer() + ..writeVarUint32(64) + ..writeUint8(0); + final listContext = _readContext(listBuffer); + expect( + () => ListSerializer.readPayload(listContext, null), + throwsStateError, + ); + + final mapBuffer = Buffer()..writeVarUint32(64); + final mapContext = _readContext(mapBuffer); + expect( + () => MapSerializer.readPayload(mapContext, null, null), + throwsStateError, + ); + }); + }); +} diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index d617450041..aee6e633d5 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -96,6 +96,29 @@ When enabled, avoids duplicating shared objects and handles cycles. **Default:** `true` +### max_container_memory_bytes(int64_t) + +Set the maximum estimated memory that container objects may reserve during one +root deserialization. + +```cpp +auto fory = Fory::builder() + .max_container_memory_bytes(64 * 1024 * 1024) + .build(); +``` + +Use `-1` for the automatic limit. For byte-array and `Buffer` roots, the +automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For +stream roots, the automatic limit is `128 MiB` because the full root size is not +known up front. Positive values always override the automatic limit. + +This budget is an estimate for container-owned memory such as collection +objects, backing storage, map entries, and object/reference arrays. It is not an +exact process heap limit. Dedicated string, binary, and primitive dense-array +payloads continue to rely on their byte-availability checks instead. + +**Default:** `-1` + ### max_dyn_depth(uint32_t) Set maximum allowed nesting depth for dynamically-typed objects. 
@@ -205,6 +228,7 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory | `xlang(bool)` | Use xlang mode | `true` | | `compatible(bool)` | Enable schema evolution | `true` | | `track_ref(bool)` | Enable reference tracking | `true` | +| `max_container_memory_bytes(int64_t)` | Max estimated container memory per root read | `-1` | | `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | | `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | @@ -218,6 +242,8 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. +- Leave `max_container_memory_bytes(-1)` enabled for automatic root-size-based container limits, or + set a positive value for a stricter trusted-workload envelope. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. 
- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index e7c0c24d42..c9e8e80cf6 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -41,6 +41,7 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); | `Compatible` | `true` | Compatible schema-evolution metadata enabled | | `CheckStructVersion` | `false` | Struct schema hash checks disabled | | `MaxDepth` | `20` | Max dynamic nesting depth | +| `MaxContainerMemoryBytes` | `-1` | Auto container memory budget | | `MaxTypeFields` | `512` | Max fields in one received struct metadata body | | `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | | `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | @@ -96,6 +97,20 @@ Fory fory = Fory.Builder() `value` must be greater than `0`. +### `MaxContainerMemoryBytes(long value)` + +Sets the maximum estimated container-owned memory accepted during one root deserialization. + +```csharp +Fory fory = Fory.Builder() + .MaxContainerMemoryBytes(64L * 1024 * 1024) + .Build(); +``` + +Use `-1` for the default automatic limit. For current C# inputs, auto uses the root input byte +length times `8`, plus `64 KiB`. A positive value overrides the automatic limit. `0` and negative +values other than `-1` are rejected. + ### `MaxTypeFields(int value)` Sets the maximum fields accepted in one received remote struct metadata body. @@ -173,6 +188,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. +- Set `MaxContainerMemoryBytes(...)` to cap estimated list, array, set, and map memory during one + root deserialization. 
- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 6a4c640f6a..c7f851e253 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -38,6 +38,7 @@ final fory = Fory( maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, maxAverageSchemaVersionsPerType: 3, + maxContainerMemoryBytes: 64 * 1024 * 1024, ); ``` @@ -107,6 +108,27 @@ final fory = Fory( - `maxAverageSchemaVersionsPerType` limits the average across accepted remote types. The effective global floor is `8192` schemas. +### `maxContainerMemoryBytes` + +Limits estimated container-owned memory for one root deserialization. The budget covers Dart lists, +sets, maps, object/reference arrays, and compatible list/array materialization. It does not count +strings, binary values, or dense typed-array payloads, which are protected by byte-availability +checks. + +The default is `-1`, which means auto. Dart root inputs are memory-backed, so auto derives from the +root input size: + +```text +inputBytes * 8 + 64 KiB +``` + +Set a positive value when a trusted workload legitimately contains compact, container-heavy +payloads: + +```dart +final fory = Fory(maxContainerMemoryBytes: 256 * 1024 * 1024); +``` + ## Defaults | Option | Default | @@ -118,6 +140,7 @@ final fory = Fory( | `maxTypeMetaBytes` | 4096 | | `maxSchemaVersionsPerType` | 10 | | `maxAverageSchemaVersionsPerType` | 3 | +| `maxContainerMemoryBytes` | -1 | ## Xlang Notes @@ -134,6 +157,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. 
- Set `maxDepth` to reject unexpectedly deep payload shapes. +- Keep `maxContainerMemoryBytes` at the auto default for most inputs, or set an explicit positive + byte limit for known trusted container-heavy payloads. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index 20d9012aee..d1cb294c9c 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -39,6 +39,7 @@ Default settings: | MaxDepth | 20 | Maximum nesting depth | | IsXlang | true | Xlang mode enabled | | Compatible | true | Compatible schema-evolution metadata enabled | +| MaxContainerMemoryBytes | -1 | Automatic container memory limit per root read | | MaxTypeFields | 512 | Max fields in one received struct metadata body | | MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | | MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | @@ -51,6 +52,7 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), + fory.WithMaxContainerMemoryBytes(-1), fory.WithMaxTypeFields(512), fory.WithMaxTypeMetaBytes(4096), fory.WithMaxSchemaVersionsPerType(10), @@ -127,6 +129,27 @@ f := fory.New(fory.WithMaxDepth(30)) - Protects against deeply nested, recursive structures or malicious data - Serialization fails with error when exceeded +### WithMaxContainerMemoryBytes + +Limit estimated container-owned memory accepted during one root deserialization: + +```go +f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) +``` + +The default `-1` selects an automatic limit. 
Byte-slice roots use: + +```text +inputBytes * 8 + 64 KiB +``` + +`DeserializeFromReader` and `DeserializeFromStream` use `128 MiB` because the +full root length is unknown. The budget covers Go slices, maps, sets, and +generated container reads. Strings, binary blobs, and primitive dense array +owners keep their byte-availability checks and are not charged to this budget. +Set a positive value when a service needs a stricter or larger limit for trusted +data. + ### WithMaxTypeFields Set the maximum fields accepted in one received remote struct metadata body: diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 7b3ce60bf6..4e46d512a7 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,6 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | +| `maxContainerMemoryBytes` | Maximum estimated container-owned memory accepted during one root deserialization. `-1` derives an automatic limit from the input shape: known-length inputs use `inputBytes * 8 + 64 KiB`, and stream or unknown-length inputs use `128 MiB`. Positive values set an explicit byte limit. | `-1` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. 
| `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -90,6 +91,7 @@ Keep class registration enabled for production and any untrusted payload source: Fory fory = Fory.builder() .requireClassRegistration(true) .withMaxDepth(50) + .withMaxContainerMemoryBytes(-1) .build(); ``` @@ -97,6 +99,9 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. +- `withMaxContainerMemoryBytes(...)` bounds estimated container-owned memory during one root + deserialization. Keep `-1` for the automatic input-shaped default, or set a positive byte limit + when trusted payloads need a larger or smaller limit. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 058bccf4b3..71175c301e 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,6 +43,7 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, + maxContainerMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -56,6 +57,7 @@ const fory = new Fory({ | `ref` | `false` | Enable reference tracking for shared or circular object graphs | | `compatible` | `true` | Allow field additions/removals without breaking existing messages | | `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. 
Increase for deeply nested structures | +| `maxContainerMemoryBytes` | `-1` | Maximum estimated container-owned memory accepted during one root deserialization | | `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | | `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | | `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | @@ -92,6 +94,26 @@ to that struct. For cross-language payloads, set `compatible: false` only after verifying that every language uses the same schema, or when native types are generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). +## Container Memory Budget + +`maxContainerMemoryBytes` limits estimated memory committed by arrays, sets, +maps, and container backing storage during one root deserialization. The default +`-1` derives an automatic limit from the input bytes. JavaScript deserializes +from `Uint8Array` roots, so the automatic limit is `inputBytes * 8 + 64 KiB`. + +Use a positive byte value to set an explicit lower or higher limit: + +```ts +const fory = new Fory({ + maxContainerMemoryBytes: 32 * 1024 * 1024, +}); +``` + +String, binary, and dedicated dense primitive array payloads keep their normal +byte-size checks and do not consume this container budget. Raise the limit only +for trusted workloads that legitimately contain very compact, container-heavy +graphs. + ## Optional HPS String Path `@apache-fory/hps` provides an optional Node.js string fast path: @@ -110,6 +132,8 @@ Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. - Set `maxDepth` for the maximum nesting depth your service accepts. +- Set `maxContainerMemoryBytes` for the maximum container memory your service + accepts from one root payload. 
- Keep `maxTypeFields` and `maxTypeMetaBytes` at their defaults unless the data is not malicious and a trusted peer sends larger remote metadata. - Keep `maxSchemaVersionsPerType` and diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index fdd6459fea..26cb42e50f 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -40,6 +40,7 @@ class Fory: max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, + max_container_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -70,6 +71,7 @@ class ThreadSafeFory: | `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | | `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | | `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | +| `max_container_memory_bytes` | `int` | `-1` | Maximum estimated container-owned memory for one root deserialization. `-1` selects the automatic limit. | | `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | | `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | | `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. 
| @@ -197,6 +199,7 @@ fory = pyfory.Fory( max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_container_memory_bytes=-1, ) fory.register(UserModel, name="example.User") @@ -222,6 +225,10 @@ Received remote metadata is also limited: - `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. - `max_schema_versions_per_type` limits accepted remote metadata versions for one logical type. - `max_average_schema_versions_per_type` limits the average across accepted remote types. +- `max_container_memory_bytes` limits estimated list, tuple, set, dict, and object-array storage + created during one root deserialization. The default `-1` uses `input_bytes * 8 + 64 KiB` for + known-length inputs and `128 MiB` for stream inputs. Set a positive byte value for trusted + payloads that legitimately contain larger container graphs. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. @@ -278,6 +285,7 @@ unchanged. - Register all expected application types before deserialization. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. +- Keep `max_container_memory_bytes=-1` unless a trusted workload needs a higher explicit limit. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 58bd070567..6e04126bf0 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -110,6 +110,30 @@ let fory = Fory::builder() - `max_average_schema_versions_per_type` defaults to `3` and limits the average across accepted remote types. The effective global floor is `8192` schemas. 
+### Container Memory Budget + +`max_container_memory_bytes(...)` limits the estimated memory that deserialization may allocate for +containers such as lists, sets, and maps during one root read. The default is `-1`, which selects an +automatic limit based on the input size: + +```rust +let fory = Fory::builder().max_container_memory_bytes(-1).build(); +``` + +For byte-slice and `Reader` roots, the automatic limit is: + +```text +input bytes * 8 + 64 KiB +``` + +Set a positive byte value when trusted payloads need a larger or smaller limit: + +```rust +let fory = Fory::builder() + .max_container_memory_bytes(256 * 1024 * 1024) + .build(); +``` + ### Explicit Xlang Examples Set `.xlang(true)` explicitly for xlang serialization examples: @@ -135,6 +159,11 @@ let fory = Fory::builder().xlang(false).compatible(false).build(); // Custom depth limit let fory = Fory::builder().max_dyn_depth(10).build(); +// Custom container memory budget +let fory = Fory::builder() + .max_container_memory_bytes(256 * 1024 * 1024) + .build(); + // Combined configuration let fory = Fory::builder() .xlang(false) @@ -149,6 +178,7 @@ let fory = Fory::builder() | `compatible(bool)` | Enable schema evolution | `true` | | `xlang(bool)` | Use xlang mode | `true` | | `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| `max_container_memory_bytes(i64)` | Estimated container memory per root read | `-1` | | `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | | `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | @@ -169,6 +199,8 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. 
+- Keep `max_container_memory_bytes(-1)` for the default input-shaped container budget, or set a + positive byte limit for trusted workloads with larger legitimate containers. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index e3a0478817..40b80744fb 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -31,6 +31,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxContainerMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -90,8 +91,14 @@ let fory = Fory(compatible: false, checkClassVersion: true) ### Size and Depth Limits -`maxDepth` bounds decoded payload nesting depth. Compatible-mode remote metadata -is also limited: +`maxDepth` bounds decoded payload nesting depth. + +`maxContainerMemoryBytes` bounds the estimated container-owned memory accepted during one root +deserialization. Use `-1` for the default automatic limit. Swift roots are currently `Data` or +`ByteBuffer`, so auto uses the root input byte length times `8`, plus `64 KiB`. A positive value +overrides the automatic limit. `0` and negative values other than `-1` are rejected. + +Compatible-mode remote metadata is also limited: - `maxTypeFields` defaults to `512` and limits fields in one received struct metadata body. 
- `maxTypeMetaBytes` defaults to `4096` and limits encoded body bytes in one received TypeMeta body, @@ -104,6 +111,7 @@ is also limited: ```swift let fory = Fory( maxDepth: 5, + maxContainerMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -140,5 +148,7 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. +- Set `maxContainerMemoryBytes` to cap estimated list, set, array, and map memory during one root + deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 97d47b981c..330cd575e8 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -149,11 +149,13 @@ For buffer-backed input: comparison. - Multi-byte element arrays should compute the required byte size with overflow checks before allocation. -- Container readers that allocate, reserve, or size-hint from a declared - logical element count should first call the byte owner's readability check for - that count. This is not a full container-body validation; it is the allocation - proof that the sender has supplied at least proportional input bytes before - the reader preallocates from the count. +- Container readers that allocate backing storage or size-hint from a declared + logical element count should call the byte owner's readability check for that + count before that backing allocation or capacity reservation. This is not a + full container-body validation; it is the allocation proof that the sender has + supplied at least proportional input bytes before the reader preallocates from + the count. 
Estimated memory-budget accounting may reserve budget before this + byte check because it does not allocate backing storage. For stream-backed input: @@ -198,6 +200,42 @@ validation can cause a no-progress loop, unbounded resource growth, retained state, or success across a Fory policy boundary. Protocol-allowed chunk segmentation is normal input and is not a security issue by itself. +## Container Memory Budget + +Runtimes should enforce a root-deserialization budget for estimated +container-owned memory. This is cumulative accounting for containers created by +one root read; it is not exact heap measurement and it is not a raw element-slot +limit. + +The public configuration option should be named `maxContainerMemoryBytes`, adapted to each runtime's naming convention. +`-1` means automatic input-shaped budgeting. Positive user configuration always +wins. For known-length root input, the automatic budget is +`inputBytes * 8 + 64 KiB`. For true stream or otherwise unknown-length root +input, the automatic budget is fixed at `128 MiB`. Stream budgeting should not +depend on dynamic bytes-read accounting. + +Container budget accounting should: + +- happen in root-operation read state, with cleanup owned by the root + deserialization `finally`; +- reject arithmetic overflow before comparing budget or allocating; +- charge fixed container object cost, backing capacity, map table and entry + overhead, reference arrays, and inline or value storage where a runtime stores + elements inline; +- charge fixed cost even for zero-size containers; +- preserve existing byte-availability checks before backing allocation or + capacity reservation; +- skip dedicated string, binary, primitive array, and primitive dense-array + owner paths. + +Each runtime must inspect the concrete container path before choosing formulas. +Reference-backed containers should charge reference storage, using a 4-byte +reference slot when the actual reference slot size is not cheap or reliable to +query.
Inline/value containers such as a value-type vector or list must charge +the inline element storage instead of treating those elements as references. +General inline-value containers must not be skipped just because dedicated +primitive dense arrays are skipped. + ## Skip Semantics Skipping unknown or incompatible data is classified by concrete impact, not by diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index d8fb205012..53f41887fd 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -388,9 +388,9 @@ chunk, nullability, reference, and type-dispatch semantics. It is still the right allocation proof for count-based preallocation: after validating a non-empty count and reading any serializer-owned header or type metadata that precedes allocation, call `checkReadableBytes(logicalCount)` before allocating, -reserving, or size-hinting from that count. The byte owner handles buffer versus -stream readiness; the container serializer then allocates with the declared -count and reads elements through its normal owner path. +reserving backing capacity, or size-hinting from that count. The byte owner +handles buffer versus stream readiness; the container serializer then allocates +with the declared count and reads elements through its normal owner path. This check is not a full container-body validation. It only prevents a small or truncated input from causing a large count-based preallocation. Chunk sizes, @@ -398,6 +398,25 @@ duplicate keys, element value semantics, and protocol strictness remain owned by the container/map serializer and should be validated only when they protect a real owner invariant. +Container readers should also charge a root-operation estimated container memory +budget before allocation or size hinting. 
The budget belongs to `ReadContext` or +the equivalent root read state, not to serializers and not to ambient +thread-local state. Positive `maxContainerMemoryBytes` configuration wins; auto +configuration uses `inputBytes * 8 + 64 KiB` for known-length root input and +fixed `128 MiB` for true stream or unknown-length root input. Do not add dynamic +stream bytes-read accounting for this budget. + +The budget estimates container-owned memory, not exact heap bytes. Charge fixed +container object cost, backing capacity, map table and entry overhead, +reference arrays, and inline/value element storage where the runtime stores +container elements inline. Charge zero-size containers for their fixed cost. +Skip dedicated string, binary, primitive array, and primitive dense-array owners, +but do not skip general inline-value containers such as vectors or lists of +value objects. If reference slot size is not cheap or reliable to query, use a +4-byte reference slot. Reject arithmetic overflow before budget comparison or +allocation, and keep the existing `checkReadableBytes` proof before backing +allocation or capacity reservation. + For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. 
Field-list allocation should happen after that body readability check and should not use a separate small initial-capacity diff --git a/go/fory/README.md b/go/fory/README.md index 8b5ec3392c..e9c633ad8c 100644 --- a/go/fory/README.md +++ b/go/fory/README.md @@ -93,11 +93,15 @@ f := fory.New(fory.WithXlang(false), fory.WithCompatible(false)) // Set maximum nesting depth f := fory.New(fory.WithMaxDepth(20)) +// Set maximum estimated container memory for one root read +f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) + // Combine multiple options f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(20), + fory.WithMaxContainerMemoryBytes(-1), ) ``` diff --git a/go/fory/array.go b/go/fory/array.go index f99f6ff39f..93b81a85c2 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -290,7 +290,7 @@ func (s *arrayConcreteValueSerializer) ReadWithTypeInfo(ctx *ReadContext, refMod // arrayDynSerializer wraps sliceDynSerializer for arrays with interface element types. // It converts arrays to slices and delegates to sliceDynSerializer. 
type arrayDynSerializer struct { - sliceSerializer sliceDynSerializer + sliceSerializer *sliceDynSerializer } func newArrayDynSerializer(elemType reflect.Type) (arrayDynSerializer, error) { @@ -318,6 +318,9 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) + if !ctx.reserveSliceTypeMemory(value.Len(), value.Type().Elem()) { + return + } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) s.sliceSerializer.readData(ctx, tempSlice, value.Len()) if ctx.HasError() { diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 0e57021343..3a25284909 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -172,6 +172,9 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -200,6 +203,9 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} 
else {\n") @@ -501,6 +507,12 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -519,6 +531,9 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -545,6 +560,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s} else {\n", 
indent) @@ -568,6 +586,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -831,6 +852,12 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: maps are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tmapLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -849,6 +876,9 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, 
mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -884,6 +914,9 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -978,6 +1011,10 @@ func getGoTypeString(t types.Type) string { return t.String() } +func unsafeSizeExpr(t types.Type) string { + return fmt.Sprintf("int64(unsafe.Sizeof(*new(%s)))", getGoTypeString(t)) +} + // generateMapKeyRead generates code to read a map key // Uses error-aware methods for deferred error checking func generateMapKeyRead(buf *bytes.Buffer, keyType types.Type, varName string) error { diff --git a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index dbe7842da4..30fcbd7fb3 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "go/format" + "go/types" "io/ioutil" "log" "os" @@ -33,6 +34,16 @@ import ( var logger = log.New(os.Stdout, "", 0) +func typeNeedsContainerReservation(t types.Type) bool { + if _, ok := t.(*types.Slice); ok { + return true + } + if _, ok := t.(*types.Map); ok { + return true + } + return false +} + // GeneratorOptions contains configuration for the code generator type GeneratorOptions struct { TypeList string // comma-separated list of types to generate code for @@ -286,6 +297,7 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil needsTime := false needsReflect := false needsOptional := false + needsUnsafe := false for _, s := range 
structs { for _, field := range s.Fields { @@ -295,6 +307,12 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil } if field.IsOptional { needsOptional = true + if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + needsUnsafe = true + } + } + if typeNeedsContainerReservation(field.Type) { + needsUnsafe = true } // We need reflect for the interface compatibility methods needsReflect = true @@ -310,6 +328,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") @@ -551,6 +572,7 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { needsTime := false needsReflect := false needsOptional := false + needsUnsafe := false for _, s := range structs { for _, field := range s.Fields { @@ -560,6 +582,12 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { } if field.IsOptional { needsOptional = true + if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + needsUnsafe = true + } + } + if typeNeedsContainerReservation(field.Type) { + needsUnsafe = true } // We need reflect for the interface compatibility methods needsReflect = true @@ -575,6 +603,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go new file mode 100644 index 0000000000..16959b3d0a --- /dev/null +++ 
b/go/fory/container_memory_budget_test.go @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type budgetItem struct { + A int32 +} + +type budgetSiblings struct { + A []string + B []string +} + +func TestContainerMemoryBudgetConfig(t *testing.T) { + require.Equal(t, int64(-1), New().config.MaxContainerMemoryBytes) + require.Equal(t, int64(123), New(WithMaxContainerMemoryBytes(123)).config.MaxContainerMemoryBytes) + require.Panics(t, func() { New(WithMaxContainerMemoryBytes(0)) }) + require.Panics(t, func() { New(WithMaxContainerMemoryBytes(-2)) }) +} + +func TestContainerMemoryBudgetAutoLimits(t *testing.T) { + ctx := NewReadContext(false) + ctx.initContainerMemoryBudget(10, false) + require.False(t, ctx.HasError()) + require.Equal(t, int64(10)*knownRootBudgetMultiplier+knownRootBudgetSlackBytes, ctx.containerMemoryLimitBytes) + require.True(t, ctx.ReserveContainerMemory(ctx.containerMemoryLimitBytes)) + require.False(t, ctx.ReserveContainerMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") + + ctx = NewReadContext(false) + ctx.initContainerMemoryBudget(10, 
true) + require.False(t, ctx.HasError()) + require.Equal(t, streamRootBudgetBytes, ctx.containerMemoryLimitBytes) + require.True(t, ctx.ReserveContainerMemory(streamRootBudgetBytes)) + require.False(t, ctx.ReserveContainerMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") + + ctx = NewReadContext(false) + ctx.maxContainerMemoryBytes = 77 + ctx.initContainerMemoryBudget(10, true) + require.False(t, ctx.HasError()) + require.Equal(t, int64(77), ctx.containerMemoryLimitBytes) +} + +func TestContainerMemoryBudgetKnownVsStreamRoot(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var fromBytes []any + err = New(WithCompatible(false)).Deserialize(data, &fromBytes) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + var fromStream []any + err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) + require.NoError(t, err) + require.Len(t, fromStream, len(values)) +} + +func TestContainerMemoryBudgetExplicitOverride(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false), WithMaxContainerMemoryBytes(4*1024*1024)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(values)) +} + +func TestContainerMemoryBudgetEmptyAndCumulative(t *testing.T) { + data, err := New(WithCompatible(false)).Serialize([]any{}) + require.NoError(t, err) + var empty []any + err = New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes-1)).Deserialize(data, &empty) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + writer := New(WithCompatible(false)) + 
require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + data, err = writer.Serialize(&budgetSiblings{A: []string{}, B: []string{}}) + require.NoError(t, err) + reader := New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + var out budgetSiblings + err = reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") +} + +func TestContainerMemoryBudgetMapAndOverflow(t *testing.T) { + data, err := New().Serialize(map[string]string{"k": "v"}) + require.NoError(t, err) + var out map[string]string + oneEntryBudget := mapObjectBytes + + 2*referenceSlotBytes + + mapEntryOverheadBytes + referenceSlotBytes + + containerSizeOf[string]() + containerSizeOf[string]() + err = New(WithMaxContainerMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + ctx := NewReadContext(false) + ctx.initContainerMemoryBudget(0, true) + require.False(t, ctx.ReserveMapMemory(MaxInt, MaxInt64, 1)) + require.Contains(t, ctx.CheckError().Error(), "overflows") +} + +func TestContainerMemoryBudgetSlicesAndInlineValues(t *testing.T) { + data, err := New().Serialize([]string{"a"}) + require.NoError(t, err) + var stringsOut []string + err = New(WithMaxContainerMemoryBytes(sliceObjectBytes+containerSizeOf[string]()-1)).Deserialize(data, &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + writer := New() + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err = writer.Serialize([]budgetItem{{A: 1}}) + require.NoError(t, err) + reader := New(WithMaxContainerMemoryBytes(sliceObjectBytes + containerSizeOf[budgetItem]() - 1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var items []budgetItem + 
err = reader.Deserialize(data, &items) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") +} + +func TestContainerMemoryBudgetSkipsDenseOwners(t *testing.T) { + f := New(WithMaxContainerMemoryBytes(1)) + + stringData, err := New().Serialize(strings.Repeat("x", 128)) + require.NoError(t, err) + var s string + require.NoError(t, f.Deserialize(stringData, &s)) + require.Len(t, s, 128) + + bytesData, err := New().Serialize([]byte{1, 2, 3, 4}) + require.NoError(t, err) + var b []byte + require.NoError(t, f.Deserialize(bytesData, &b)) + require.Equal(t, []byte{1, 2, 3, 4}, b) + + intsData, err := New().Serialize([]int32{1, 2, 3, 4}) + require.NoError(t, err) + var ints []int32 + require.NoError(t, f.Deserialize(intsData, &ints)) + require.Equal(t, []int32{1, 2, 3, 4}, ints) +} + +func TestContainerMemoryBudgetPreservesByteChecks(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(LIST)) + buf.WriteLength(1024) + buf.WriteInt8(int8(CollectionIsSameType)) + buf.WriteUint8(uint8(STRING)) + + var stringsOut []string + err := New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") + + buf = NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(INT32_ARRAY)) + buf.WriteLength(4096) + + var ints []int32 + err = New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &ints) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") +} diff --git a/go/fory/field_serializer.go b/go/fory/field_serializer.go index d91b6ec5e1..d3ef3a787b 100644 --- a/go/fory/field_serializer.go +++ b/go/fory/field_serializer.go @@ -42,7 +42,7 @@ func serializerNeedsGenericDispatch(serializer Serializer) bool { switch serializer.(type) { case *sliceSerializer, primitiveListSerializer, - 
sliceDynSerializer, + *sliceDynSerializer, setSerializer, mapSerializer, stringSliceSerializer, @@ -68,10 +68,13 @@ func newDeclaredSliceSerializer(type_ reflect.Type, elemSerializer Serializer, r if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Interface { return nil, fmt.Errorf("slice serializer does not support pointer to interface element type: %v", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: referencable, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } diff --git a/go/fory/fory.go b/go/fory/fory.go index 412fc46449..7bfb9867ef 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -69,6 +69,7 @@ type Config struct { MaxDepth int IsXlang bool Compatible bool // Schema evolution compatibility mode + MaxContainerMemoryBytes int64 MaxTypeFields int MaxTypeMetaBytes int MaxSchemaVersionsPerType int @@ -82,6 +83,7 @@ func defaultConfig() Config { MaxDepth: 20, IsXlang: true, MaxTypeFields: 512, + MaxContainerMemoryBytes: -1, MaxTypeMetaBytes: 4096, MaxSchemaVersionsPerType: 10, MaxAverageSchemaVersionsPerType: 3, @@ -110,6 +112,17 @@ func WithMaxDepth(depth int) Option { } } +// WithMaxContainerMemoryBytes sets the maximum estimated container-owned memory accepted during one root deserialization. +// Use -1 for the automatic input-shaped limit. 
+func WithMaxContainerMemoryBytes(size int64) Option { + if size != -1 && size <= 0 { + panic("MaxContainerMemoryBytes must be positive or -1 for auto") + } + return func(f *Fory) { + f.config.MaxContainerMemoryBytes = size + } +} + // WithXlang sets cross-language serialization mode func WithXlang(enabled bool) Option { return func(f *Fory) { @@ -218,6 +231,7 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) + f.readCtx.maxContainerMemoryBytes = f.config.MaxContainerMemoryBytes f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible @@ -556,6 +570,10 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) + f.readCtx.initContainerMemoryBudget(len(data), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -1016,6 +1034,10 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) + f.readCtx.initContainerMemoryBudget(len(data), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } // ReadData and validate header readHeader(f.readCtx) diff --git a/go/fory/map.go b/go/fory/map.go index 8b1d82cc95..fdb8ebbc53 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -303,6 +303,9 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { iface := reflect.TypeOf((*any)(nil)).Elem() mapType = reflect.MapOf(iface, iface) } + if !ctx.reserveMapTypeMemory(size, mapType.Key(), mapType.Elem()) { + return + } if size == 0 { if value.IsNil() { value.Set(reflect.MakeMap(mapType)) diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index d520e97bd5..5a4e925ade 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -25,6 +25,27 @@ 
import ( // Optimized map serializers for common types // ============================================================================ +var ( + stringStringMapElemBytes = mapElementMemory(stringElementBytes, stringElementBytes) + stringStringMapMaxLength = maxMapLength(stringStringMapElemBytes) + stringInt64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int64]()) + stringInt64MapMaxLength = maxMapLength(stringInt64MapElemBytes) + stringInt32MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int32]()) + stringInt32MapMaxLength = maxMapLength(stringInt32MapElemBytes) + stringIntMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int]()) + stringIntMapMaxLength = maxMapLength(stringIntMapElemBytes) + stringFloat64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[float64]()) + stringFloat64MapMaxLength = maxMapLength(stringFloat64MapElemBytes) + stringBoolMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[bool]()) + stringBoolMapMaxLength = maxMapLength(stringBoolMapElemBytes) + int32Int32MapElemBytes = mapElementMemory(containerSizeOf[int32](), containerSizeOf[int32]()) + int32Int32MapMaxLength = maxMapLength(int32Int32MapElemBytes) + int64Int64MapElemBytes = mapElementMemory(containerSizeOf[int64](), containerSizeOf[int64]()) + int64Int64MapMaxLength = maxMapLength(int64Int64MapElemBytes) + intIntMapElemBytes = mapElementMemory(containerSizeOf[int](), containerSizeOf[int]()) + intIntMapMaxLength = maxMapLength(intIntMapElemBytes) +) + // writeMapStringString writes map[string]string using chunk protocol // When hasGenerics=true, element types are known so we set DECL_TYPE flags and skip type info func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool) { @@ -68,10 +89,16 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } } -func readTypedMapSize(ctx *ReadContext) (int, bool) { +func readTypedMapSize(ctx *ReadContext, 
elemBytes int64, maxLength int64) (int, bool) { size := ctx.ReadCollectionLength() - if size == 0 || ctx.HasError() { - return size, false + if ctx.HasError() { + return 0, false + } + if !ctx.reserveMapMemory(size, elemBytes, maxLength) { + return 0, false + } + if size == 0 { + return size, true } if !ctx.Buffer().CheckReadable(size, ctx.Err()) { return 0, false @@ -83,12 +110,11 @@ func readTypedMapSize(ctx *ReadContext) (int, bool) { func readMapStringString(ctx *ReadContext) map[string]string { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]string) + size, ok := readTypedMapSize(ctx, stringStringMapElemBytes, stringStringMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]string, size) + result := make(map[string]string, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -171,12 +197,11 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) func readMapStringInt64(ctx *ReadContext) map[string]int64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int64) + size, ok := readTypedMapSize(ctx, stringInt64MapElemBytes, stringInt64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int64, size) + result := make(map[string]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -256,12 +281,11 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) func readMapStringInt32(ctx *ReadContext) map[string]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int32) + size, ok := readTypedMapSize(ctx, stringInt32MapElemBytes, stringInt32MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int32, size) + result := make(map[string]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -341,12 +365,11 @@ func writeMapStringInt(buf *ByteBuffer, m 
map[string]int, hasGenerics bool) { func readMapStringInt(ctx *ReadContext) map[string]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int) + size, ok := readTypedMapSize(ctx, stringIntMapElemBytes, stringIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int, size) + result := make(map[string]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -426,12 +449,11 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo func readMapStringFloat64(ctx *ReadContext) map[string]float64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]float64) + size, ok := readTypedMapSize(ctx, stringFloat64MapElemBytes, stringFloat64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]float64, size) + result := make(map[string]float64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -511,12 +533,11 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) { func readMapStringBool(ctx *ReadContext) map[string]bool { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]bool) + size, ok := readTypedMapSize(ctx, stringBoolMapElemBytes, stringBoolMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]bool, size) + result := make(map[string]bool, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -600,12 +621,11 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int32]int32) + size, ok := readTypedMapSize(ctx, int32Int32MapElemBytes, int32Int32MapMaxLength) if !ok { - return result + return nil } - result = make(map[int32]int32, size) + result := make(map[int32]int32, size) for size > 0 { 
chunkHeader := buf.ReadUint8(err) @@ -685,12 +705,11 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int64]int64) + size, ok := readTypedMapSize(ctx, int64Int64MapElemBytes, int64Int64MapMaxLength) if !ok { - return result + return nil } - result = make(map[int64]int64, size) + result := make(map[int64]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -770,12 +789,11 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { func readMapIntInt(ctx *ReadContext) map[int]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int]int) + size, ok := readTypedMapSize(ctx, intIntMapElemBytes, intIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[int]int, size) + result := make(map[int]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -831,12 +849,12 @@ func (s stringStringMapSerializer) Write(ctx *WriteContext, refMode RefMode, wri } func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringString(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringStringMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -866,12 +884,12 @@ func (s stringInt64MapSerializer) Write(ctx *WriteContext, refMode RefMode, writ } func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt64(ctx) + if ctx.HasError() { + return + } 
value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringInt64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -901,12 +919,12 @@ func (s stringIntMapSerializer) Write(ctx *WriteContext, refMode RefMode, writeT } func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -936,12 +954,12 @@ func (s stringFloat64MapSerializer) Write(ctx *WriteContext, refMode RefMode, wr } func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringFloat64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringFloat64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -971,12 +989,12 @@ func (s stringBoolMapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringBool(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringBoolMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1006,12 +1024,12 @@ func (s int32Int32MapSerializer) Write(ctx *WriteContext, refMode 
RefMode, write } func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt32Int32(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int32Int32MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1041,12 +1059,12 @@ func (s int64Int64MapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt64Int64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int64Int64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1076,12 +1094,12 @@ func (s intIntMapSerializer) Write(ctx *WriteContext, refMode RefMode, writeType } func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapIntInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s intIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { diff --git a/go/fory/reader.go b/go/fory/reader.go index 3985bb4e2b..b3d6301d65 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,21 +29,60 @@ import ( // ReadContext holds all state needed during deserialization. 
type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - rootHeader byte - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking in native-mode paths - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking in native-mode paths + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo + maxContainerMemoryBytes int64 + containerMemoryLimitBytes int64 + remainingContainerMemoryBytes int64 +} + +const ( + knownRootBudgetMultiplier = int64(8) + knownRootBudgetSlackBytes = int64(64 * 1024) + streamRootBudgetBytes = int64(128 * 1024 * 1024) + sliceObjectBytes = int64(unsafe.Sizeof([]byte(nil))) + mapObjectBytes = int64(48) + mapEntryOverheadBytes = int64(16) +) + +var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) +var stringElementBytes = containerSizeOf[string]() +var stringSliceMaxLength = 
maxSliceLength(stringElementBytes) + +func containerSizeOf[T any]() int64 { + var v T + return int64(unsafe.Sizeof(v)) +} + +func maxSliceLength(elemBytes int64) int64 { + if elemBytes == 0 { + return MaxInt64 + } + return (MaxInt64 - sliceObjectBytes) / elemBytes +} + +func mapElementMemory(keyBytes int64, valueBytes int64) int64 { + return keyBytes + valueBytes + mapEntryOverheadBytes + referenceSlotBytes + 2*referenceSlotBytes +} + +func maxMapLength(elemBytes int64) int64 { + if elemBytes == 0 { + return MaxInt64 + } + return (MaxInt64 - mapObjectBytes) / elemBytes } // IsXlang returns whether cross-language serialization mode is enabled @@ -54,10 +93,11 @@ func (c *ReadContext) IsXlang() bool { // NewReadContext creates a new read context func NewReadContext(trackRef bool) *ReadContext { return &ReadContext{ - buffer: NewByteBuffer(nil), - refReader: NewRefReader(trackRef), - trackRef: trackRef, - maxDepth: 128, // Default maximum nesting depth + buffer: NewByteBuffer(nil), + refReader: NewRefReader(trackRef), + trackRef: trackRef, + maxDepth: 128, // Default maximum nesting depth + maxContainerMemoryBytes: -1, } } @@ -67,6 +107,8 @@ func (c *ReadContext) Reset() { c.outOfBandBuffers = nil c.outOfBandIndex = 0 c.err = Error{} // Clear error state + // Container budget state is overwritten by each root read before deserialization. + // Avoid extra reset stores on the successful root hot path. 
if c.refResolver != nil { c.refResolver.resetRead() } @@ -75,6 +117,157 @@ func (c *ReadContext) Reset() { } } +func (c *ReadContext) initContainerMemoryBudget(rootInputBytes int, unknownLengthInput bool) { + limit := c.maxContainerMemoryBytes + if limit <= 0 { + if unknownLengthInput { + limit = streamRootBudgetBytes + } else { + if rootInputBytes < 0 { + c.setContainerMemoryError("root input size must be non-negative: %d", rootInputBytes) + return + } + if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { + c.setContainerMemoryError("root input size %d overflows automatic container memory budget", rootInputBytes) + return + } + limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes + } + } + c.containerMemoryLimitBytes = limit + c.remainingContainerMemoryBytes = limit +} + +// ReserveSliceMemory reserves estimated memory for a Go slice backing array before allocation. +func (c *ReadContext) ReserveSliceMemory(length int, elemBytes int64) bool { + if elemBytes < 0 { + c.setContainerMemoryError("negative container element size: %d", elemBytes) + return false + } + return c.reserveSliceMemory(length, elemBytes, maxSliceLength(elemBytes)) +} + +func (c *ReadContext) reserveSliceMemory(length int, elemBytes int64, maxLength int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if int64(length) > maxLength { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + bytes := sliceObjectBytes + int64(length)*elemBytes + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - bytes + return true +} + +func (c *ReadContext) reserveSliceTypeMemory(length int, elemType reflect.Type) bool { + elemBytes := referenceSlotBytes + if elemType != 
nil { + elemBytes = int64(elemType.Size()) + } + return c.ReserveSliceMemory(length, elemBytes) +} + +// ReserveMapMemory reserves estimated memory for a Go map before allocation or size hinting. +func (c *ReadContext) ReserveMapMemory(length int, keyBytes int64, valueBytes int64) bool { + if keyBytes < 0 || valueBytes < 0 { + c.setContainerMemoryError("negative map element size: key=%d value=%d", keyBytes, valueBytes) + return false + } + perEntry := keyBytes + valueBytes + if perEntry < keyBytes || perEntry > MaxInt64-mapEntryOverheadBytes-referenceSlotBytes { + c.setContainerMemoryError("map element size overflows: key=%d value=%d", keyBytes, valueBytes) + return false + } + perEntry += mapEntryOverheadBytes + referenceSlotBytes + if perEntry > MaxInt64-2*referenceSlotBytes { + c.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + return false + } + elemBytes := perEntry + 2*referenceSlotBytes + return c.reserveMapMemory(length, elemBytes, maxMapLength(elemBytes)) +} + +func (c *ReadContext) reserveMapTypeMemory(length int, keyType reflect.Type, valueType reflect.Type) bool { + keyBytes := referenceSlotBytes + valueBytes := referenceSlotBytes + if keyType != nil { + keyBytes = int64(keyType.Size()) + } + if valueType != nil { + valueBytes = int64(valueType.Size()) + } + return c.ReserveMapMemory(length, keyBytes, valueBytes) +} + +func (c *ReadContext) reserveMapMemory(length int, elemBytes int64, maxLength int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if int64(length) > maxLength { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + bytes := mapObjectBytes + int64(length)*elemBytes + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - 
bytes + return true +} + +func (c *ReadContext) reserveCountedMemory(length int, fixedBytes int64, elemBytes int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if fixedBytes < 0 || elemBytes < 0 { + c.setContainerMemoryError("negative container memory estimate: fixed=%d elem=%d", fixedBytes, elemBytes) + return false + } + if elemBytes != 0 && int64(length) > (MaxInt64-fixedBytes)/elemBytes { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + return c.ReserveContainerMemory(fixedBytes + int64(length)*elemBytes) +} + +// ReserveContainerMemory reserves raw estimated container-owned bytes. +func (c *ReadContext) ReserveContainerMemory(bytes int64) bool { + if bytes < 0 { + c.setContainerMemoryError("estimated container memory must be non-negative, got %d bytes", bytes) + return false + } + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - bytes + return true +} + +//go:noinline +func (c *ReadContext) setContainerMemoryError(format string, args ...any) { + c.SetError(DeserializationErrorf(format, args...)) +} + +//go:noinline +func (c *ReadContext) setContainerMemoryExceeded(bytes int64, remaining int64) { + c.SetError(DeserializationErrorf( + "estimated container memory request %d bytes exceeds maxContainerMemoryBytes remaining budget %d bytes out of effective limit %d bytes", + bytes, remaining, c.containerMemoryLimitBytes)) +} + // SetData sets new input data (for buffer reuse) // Reuses existing buffer to avoid allocation func (c *ReadContext) SetData(data []byte) { @@ -536,7 +729,42 @@ func (c *ReadContext) ReadStringSlice(refMode RefMode, readType bool) []string { if readType { _ = c.buffer.ReadUint8(err) } - return ReadStringSlice(c.buffer, err) + return 
c.readStringSliceData() +} + +func (c *ReadContext) readStringSliceData() []string { + buf := c.buffer + err := c.Err() + length := buf.ReadLength(err) + if c.HasError() { + return nil + } + if !c.reserveSliceMemory(length, containerSizeOf[string](), stringSliceMaxLength) { + return nil + } + if length == 0 { + return make([]string, 0) + } + collectFlag := buf.ReadInt8(err) + if (collectFlag&CollectionIsSameType) != 0 && (collectFlag&CollectionIsDeclElementType) == 0 { + _ = buf.ReadUint8(err) + } + if c.HasError() || !buf.CheckReadable(length, err) { + return nil + } + result := make([]string, length) + trackRefs := (collectFlag & CollectionTrackingRef) != 0 + hasNull := (collectFlag & CollectionHasNull) != 0 + for i := 0; i < length; i++ { + if trackRefs || hasNull { + rf := buf.ReadInt8(err) + if rf == NullFlag { + continue + } + } + result[i] = readString(buf, err) + } + return result } // ReadStringStringMap reads map[string]string with optional ref/type info diff --git a/go/fory/set.go b/go/fory/set.go index 1a42739547..652f8ddca9 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -318,6 +318,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } if length == 0 { + if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + return + } // Initialize empty set if length is 0 value.Set(reflect.MakeMap(type_)) return @@ -356,6 +359,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !buf.CheckReadable(length, err) { return } + if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + return + } // Initialize set if nil if value.IsNil() { diff --git a/go/fory/slice.go b/go/fory/slice.go index 6d941b3bf6..56d4d4845f 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -124,6 +124,8 @@ type sliceSerializer struct { type_ reflect.Type elemSerializer Serializer referencable bool + elemBytes int64 + maxLength int64 } // newSliceSerializer creates a sliceSerializer for slices with 
concrete element types. @@ -144,10 +146,13 @@ func newSliceSerializer(type_ reflect.Type, elemSerializer Serializer, xlang boo reflect.Uint8, reflect.Float32, reflect.Float64: return nil, fmt.Errorf("sliceSerializer does not support primitive element type %v: use dedicated primitive slice serializer", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: isRefType(elem, xlang), + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } @@ -308,6 +313,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } isArrayType := value.Type().Kind() == reflect.Array + if !isArrayType && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + return + } if length == 0 { if !isArrayType { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 907fcddd4f..d341e9455b 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -31,35 +31,43 @@ type sliceDynSerializer struct { elemType reflect.Type isInterfaceElem bool isPointerElem bool + elemBytes int64 + maxLength int64 } // newSliceDynSerializer creates a new sliceDynSerializer. // This serializer is ONLY for slices with interface or pointer to interface element types. // For other slice types, use sliceSerializer instead. 
-func newSliceDynSerializer(elemType reflect.Type) (sliceDynSerializer, error) { +func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { // Nil element type is allowed for fully dynamic slices (e.g., []any) if elemType == nil { - return sliceDynSerializer{ + elemBytes := containerSizeOf[any]() + return &sliceDynSerializer{ isInterfaceElem: true, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } // Validate element type is interface or pointer to interface isInterface := elemType.Kind() == reflect.Interface isPointerToInterface := elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Interface if !isInterface && !isPointerToInterface { - return sliceDynSerializer{}, fmt.Errorf( + return nil, fmt.Errorf( "sliceDynSerializer only supports interface or pointer to interface element types, got %v; use sliceSerializer for other types", elemType) } - return sliceDynSerializer{ + elemBytes := int64(elemType.Size()) + return &sliceDynSerializer{ elemType: elemType, isInterfaceElem: isInterface, isPointerElem: isPointerToInterface, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } // mustNewSliceDynSerializer is like newSliceDynSerializer but panics on error. // Used for initialization code where the element type is known to be valid. 
-func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer { +func mustNewSliceDynSerializer(elemType reflect.Type) *sliceDynSerializer { s, err := newSliceDynSerializer(elemType) if err != nil { panic(err) @@ -67,7 +75,7 @@ func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer { return s } -func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { +func (s *sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { done := writeSliceRefAndType(ctx, refMode, writeType, value, LIST) if done || ctx.HasError() { return @@ -75,7 +83,7 @@ func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType s.WriteData(ctx, value) } -func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { +func (s *sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { buf := ctx.Buffer() // Get slice length and handle empty slice case length := value.Len() @@ -103,7 +111,7 @@ func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // - Type consistency flags // - Element type information (if homogeneous) // Returns pointer to TypeInfo to avoid copy overhead. 
-func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { +func (s *sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { collectFlag := CollectionDefaultFlag var elemTypeInfo *TypeInfo hasNull := false @@ -161,7 +169,7 @@ func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, valu } // writeSameType efficiently serializes a slice where all elements share the same type -func (s sliceDynSerializer) writeSameType( +func (s *sliceDynSerializer) writeSameType( ctx *WriteContext, buf *ByteBuffer, value reflect.Value, typeInfo *TypeInfo, flag byte) { if typeInfo == nil { return @@ -194,7 +202,7 @@ func (s sliceDynSerializer) writeSameType( } // writeDifferentTypes handles serialization of slices with mixed element types -func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { +func (s *sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -246,7 +254,7 @@ func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuff } } -func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { +func (s *sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { done, typeId := readSliceRefAndType(ctx, refMode, readType, value) if done || ctx.HasError() { return @@ -258,11 +266,11 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo s.ReadData(ctx, value) } -func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { +func (s *sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { s.readData(ctx, value, -1) } -func (s sliceDynSerializer) 
readData(ctx *ReadContext, value reflect.Value, expectedLength int) { +func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expectedLength int) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() @@ -274,6 +282,10 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe ctx.SetError(DeserializationErrorf("array length %d does not match serialized length %d", expectedLength, length)) return } + allocatedByCaller := expectedLength >= 0 + if !allocatedByCaller && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + return + } if length == 0 { value.Set(reflect.MakeSlice(sliceType, 0, 0)) return @@ -305,7 +317,9 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readSameType(ctx, buf, value, elemType, elemSerializer, collectFlag, length) return @@ -313,18 +327,20 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readDifferentTypes(ctx, buf, value, collectFlag, length) } -func (s sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { +func (s *sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { // typeInfo is already read, don't read it again s.Read(ctx, refMode, false, false, value) } // readSameType handles deserialization of slices where all elements share the same type -func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf 
*ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { +func (s *sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 ctxErr := ctx.Err() @@ -402,7 +418,7 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu } // readDifferentTypes handles deserialization of slices with mixed element types -func (s sliceDynSerializer) readDifferentTypes( +func (s *sliceDynSerializer) readDifferentTypes( ctx *ReadContext, buf *ByteBuffer, value reflect.Value, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -464,7 +480,7 @@ func (s sliceDynSerializer) readDifferentTypes( // 1. Slice element type is pointer-to-interface and the deserialized type is not a pointer, OR // 2. 
Slice element type is interface and the deserialized type doesn't directly implement it // but the pointer type does (common case where interface has pointer receivers) -func (s sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { +func (s *sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { if elemType.Kind() == reflect.Ptr { return elemType, serializer } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 9b92691ac8..88e5d50b08 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,6 +652,9 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) + if !ctx.reserveSliceMemory(length, stringElementBytes, stringSliceMaxLength) { + return + } if length == 0 { *ptr = make([]string, 0) return diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index 0335b2a08e..e033fb4409 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -25,6 +25,18 @@ import ( type primitiveListSerializer struct { type_ reflect.Type elemTypeID TypeId + elemBytes int64 + maxLength int64 +} + +func newPrimitiveList(type_ reflect.Type, elemTypeID TypeId, elemType reflect.Type) primitiveListSerializer { + elemBytes := int64(elemType.Size()) + return primitiveListSerializer{ + type_: type_, + elemTypeID: elemTypeID, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), + } } type compatiblePrimitiveListToArraySerializer struct { @@ -39,43 +51,43 @@ func newPrimitiveListSerializer(type_ reflect.Type, elemTypeID TypeId) (Serializ elemType := type_.Elem() switch elemType.Kind() { case reflect.Bool: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BOOL + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BOOL case 
reflect.Int8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT8 case reflect.Uint8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT8 case reflect.Int16: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT16 case reflect.Uint16: if elemType == float16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT16 } if elemType == bfloat16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BFLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BFLOAT16 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT16 case reflect.Int32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Int64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 case reflect.Uint64: - return 
primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 case reflect.Int: if reflect.TypeOf(int(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint: if reflect.TypeOf(uint(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Float32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT32 case reflect.Float64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT64 default: return nil, false } @@ -167,6 +179,9 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } + if !ctx.reserveSliceMemory(length, 
s.elemBytes, s.maxLength) { + return + } if length == 0 { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) return @@ -228,6 +243,9 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { + if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + return + } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) } else if value.Len() != 0 { ctx.SetError(DeserializationErrorf("array-compatible list length %d does not match array length %d", length, value.Len())) @@ -266,6 +284,9 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { + if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + return + } temp := reflect.New(value.Type()).Elem() s.listReader.readValues(buf, err, temp, length, false) if ctx.HasError() { diff --git a/go/fory/stream.go b/go/fory/stream.go index bb86689598..45111695e5 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -96,6 +96,13 @@ func (is *InputStream) Shrink() { func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer + f.readCtx.initContainerMemoryBudget(0, true) + if f.readCtx.HasError() { + err := f.readCtx.TakeError() + f.readCtx.buffer = origBuffer + f.resetReadState() + return err + } defer func() { f.readCtx.buffer = origBuffer f.resetReadState() @@ -123,6 +130,10 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { defer f.resetReadState() // Always reset to enforce stateless semantics. 
f.readCtx.buffer.ResetWithReader(r, 0) + f.readCtx.initContainerMemoryBudget(0, true) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index 742135a8ba..ca639979a5 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,12 +1,13 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-06-12T06:41:26+08:00 +// generated at: 2026-06-26T15:00:42+08:00 package fory import ( "github.com/apache/fory/go/fory" "reflect" + "unsafe" ) func init() { @@ -189,6 +190,9 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -217,6 +221,9 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -662,6 +669,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -709,6 +722,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + return 
ctx.TakeError() + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -755,6 +771,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -802,6 +824,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -848,6 +873,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -895,6 +926,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -1250,6 +1284,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if 
!ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1289,6 +1329,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1327,6 +1370,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1366,6 +1415,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1404,6 +1456,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1443,6 +1501,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1481,6 
+1542,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1528,6 +1595,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index f0966b0596..ba4d3e5dc9 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -415,7 +415,7 @@ func (r *TypeResolver) initialize() { {stringPtrType, STRING, ptrToStringSerializer{}}, // Register interface types first so typeIDToTypeInfo maps to generic types // that can hold any element type when deserializing into any - {interfaceSliceType, LIST, sliceDynSerializer{}}, + {interfaceSliceType, LIST, mustNewSliceDynSerializer(interfaceType)}, {interfaceMapType, MAP, mapSerializer{type_: interfaceMapType, keyReferencable: true, valueReferencable: true}}, // stringSliceType uses dedicated stringSliceSerializer for optimized serialization // This ensures CollectionIsDeclElementType is set for Java compatibility @@ -1779,7 +1779,7 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s } // For dynamic types, use dynamic slice serializer if isDynamicType(elem) { - return sliceDynSerializer{}, nil + return newSliceDynSerializer(elem) } else { elemSerializer, err := r.getSerializerByType(type_.Elem(), false) if err != nil { diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java 
b/java/fory-core/src/main/java/org/apache/fory/Fory.java index 5ad083d69e..1b90fcfcc8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -425,12 +425,17 @@ public T deserialize(byte[] bytes, Class type) { @Override public T deserialize(MemoryBuffer buffer, Class type) { + return deserialize(buffer, type, false); + } + + private T deserialize(MemoryBuffer buffer, Class type, boolean unknownLengthInput) { ensureRegistrationFinished(); + int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); if (bitmap != headerBitmap) { checkHeaderBitmapWithoutOutOfBand(bitmap); } - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); try { try { jitContext.lock(); @@ -451,7 +456,7 @@ public T deserialize(MemoryBuffer buffer, Class type) { @Override public T deserialize(ForyInputStream inputStream, Class type) { try { - return deserialize(inputStream.getBuffer(), type); + return deserialize(inputStream.getBuffer(), type, true); } finally { inputStream.shrinkBuffer(); } @@ -459,7 +464,7 @@ public T deserialize(ForyInputStream inputStream, Class type) { @Override public T deserialize(ForyReadableChannel channel, Class type) { - return deserialize(channel.getBuffer(), type); + return deserialize(channel.getBuffer(), type, true); } @Override @@ -487,7 +492,13 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { + return deserialize(buffer, outOfBandBuffers, false); + } + + private Object deserialize( + MemoryBuffer buffer, Iterable outOfBandBuffers, boolean unknownLengthInput) { ensureRegistrationFinished(); + int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); boolean peerOutOfBandEnabled = false; if (bitmap != headerBitmap) { @@ -505,7 +516,11 @@ public Object deserialize(MemoryBuffer buffer, Iterable outOfBandB 
+ "produced with bufferCallback null."); } readContext.prepare( - buffer, peerOutOfBandEnabled ? outOfBandBuffers : null, peerOutOfBandEnabled); + buffer, + peerOutOfBandEnabled ? outOfBandBuffers : null, + peerOutOfBandEnabled, + rootInputBytes, + unknownLengthInput); try { try { jitContext.lock(); @@ -532,7 +547,7 @@ public Object deserialize(ForyInputStream inputStream) { public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { try { MemoryBuffer buf = inputStream.getBuffer(); - return deserialize(buf, outOfBandBuffers); + return deserialize(buf, outOfBandBuffers, true); } finally { inputStream.shrinkBuffer(); } @@ -546,7 +561,7 @@ public Object deserialize(ForyReadableChannel channel) { @Override public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { MemoryBuffer buf = channel.getBuffer(); - return deserialize(buf, outOfBandBuffers); + return deserialize(buf, outOfBandBuffers, true); } @SuppressWarnings("unchecked") diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 2b8db3ec66..b96e4aa83d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -68,6 +68,7 @@ public class Config implements Serializable { private final int maxTypeMetaBytes; private final int maxSchemaVersionsPerType; private final int maxAverageSchemaVersionsPerType; + private final long maxContainerMemoryBytes; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -114,6 +115,7 @@ public Config(ForyBuilder builder) { maxTypeMetaBytes = builder.maxTypeMetaBytes; maxSchemaVersionsPerType = builder.maxSchemaVersionsPerType; maxAverageSchemaVersionsPerType = builder.maxAverageSchemaVersionsPerType; + maxContainerMemoryBytes = builder.maxContainerMemoryBytes; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = 
builder.forVirtualThread; } @@ -320,6 +322,11 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } + /** Returns the root-operation estimated container memory limit in bytes, or -1 for auto. */ + public long maxContainerMemoryBytes() { + return maxContainerMemoryBytes; + } + /** Returns loadFactor of MacRef's writtenObjects. */ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -368,6 +375,7 @@ public boolean equals(Object o) { && maxTypeMetaBytes == config.maxTypeMetaBytes && maxSchemaVersionsPerType == config.maxSchemaVersionsPerType && maxAverageSchemaVersionsPerType == config.maxAverageSchemaVersionsPerType + && maxContainerMemoryBytes == config.maxContainerMemoryBytes && Objects.equals(defaultJDKStreamSerializerType, config.defaultJDKStreamSerializerType) && longEncoding == config.longEncoding && forVirtualThread == config.forVirtualThread; @@ -403,6 +411,7 @@ public int hashCode() { maxTypeMetaBytes, maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType, + maxContainerMemoryBytes, metaShareEnabled, scopedMetaShareEnabled, metaCompressor, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 48d9dcb433..93d4943940 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -103,6 +103,7 @@ public final class ForyBuilder { int maxTypeMetaBytes = 4096; int maxSchemaVersionsPerType = 10; int maxAverageSchemaVersionsPerType = 3; + long maxContainerMemoryBytes = -1; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -571,6 +572,22 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi return this; } + /** + * Sets the maximum estimated container-owned memory accepted during one root deserialization. + * + *

The default is {@code -1}, which derives an automatic per-root budget from the input shape. + * Positive values are explicit byte limits. Other values are invalid. + */ + public ForyBuilder withMaxContainerMemoryBytes(long maxContainerMemoryBytes) { + Preconditions.checkArgument( + maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, + "maxContainerMemoryBytes must be positive or -1 for auto but got %s", + maxContainerMemoryBytes); + this.maxContainerMemoryBytes = maxContainerMemoryBytes; + recordAction(b -> b.withMaxContainerMemoryBytes(maxContainerMemoryBytes)); + return this; + } + /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 6dca2e503a..46ca156525 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -51,6 +51,15 @@ */ @SuppressWarnings({"rawtypes", "unchecked"}) public final class ReadContext { + private static final long KNOWN_ROOT_BUDGET_MULTIPLIER = 8L; + private static final long KNOWN_ROOT_BUDGET_SLACK_BYTES = 64L * 1024; + private static final long STREAM_ROOT_BUDGET_BYTES = 128L * 1024 * 1024; + private static final long COLLECTION_OBJECT_BYTES = 24L; + private static final long MAP_OBJECT_BYTES = 48L; + private static final long ARRAY_HEADER_BYTES = 16L; + private static final long MAP_ENTRY_BYTES = 32L; + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private final Config config; private final Generics generics; private final TypeResolver typeResolver; @@ -61,6 +70,7 @@ public final class ReadContext { private final boolean compressInt; private final Int64Encoding longEncoding; private final int maxDepth; + private final 
long maxContainerMemoryBytes; private final boolean scopedMetaShareEnabled; private final boolean forVirtualThread; private final IdentityHashMap contextObjects = new IdentityHashMap<>(); @@ -69,6 +79,8 @@ public final class ReadContext { private MetaReadContext metaReadContext; private boolean peerOutOfBandEnabled; private int depth; + private long containerMemoryLimitBytes; + private long remainingContainerMemoryBytes; /** * Creates read-side runtime state for one {@code Fory} instance. @@ -92,6 +104,7 @@ public ReadContext( compressInt = config.compressInt(); longEncoding = config.longEncoding(); maxDepth = config.maxDepth(); + maxContainerMemoryBytes = config.maxContainerMemoryBytes(); forVirtualThread = config.forVirtualThread(); scopedMetaShareEnabled = config.isScopedMetaShareEnabled(); if (scopedMetaShareEnabled) { @@ -104,10 +117,32 @@ public ReadContext( * flag for one operation. */ public void prepare( - MemoryBuffer buffer, Iterable outOfBandBuffers, boolean peerOutOfBandEnabled) { + MemoryBuffer buffer, + Iterable outOfBandBuffers, + boolean peerOutOfBandEnabled, + int rootInputBytes, + boolean unknownLengthInput) { this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? 
null : outOfBandBuffers.iterator(); + initContainerMemoryBudget(rootInputBytes, unknownLengthInput); + } + + private void initContainerMemoryBudget(int rootInputBytes, boolean unknownLengthInput) { + long limit = maxContainerMemoryBytes; + if (limit <= 0) { + if (unknownLengthInput) { + limit = STREAM_ROOT_BUDGET_BYTES; + } else { + if (rootInputBytes < 0) { + throw new IllegalArgumentException( + "Root input size must be non-negative: " + rootInputBytes); + } + limit = rootInputBytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES; + } + } + containerMemoryLimitBytes = limit; + remainingContainerMemoryBytes = limit; } /** @@ -303,6 +338,8 @@ public void reset() { outOfBandBuffers = null; peerOutOfBandEnabled = false; depth = 0; + containerMemoryLimitBytes = 0; + remainingContainerMemoryBytes = 0; } /** Returns the immutable runtime configuration for this context. */ @@ -310,6 +347,52 @@ public Config getConfig() { return config; } + public void reserveCollectionMemory(int numElements) { + reserveContainerMemory(COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES); + } + + public void reserveCollectionCapacity(int numElements, int capacity) { + reserveContainerMemory((long) (capacity - numElements) * REFERENCE_BYTES); + } + + public void reserveMapMemory(int numElements) { + long entries = (long) numElements; + long tableBytes = entries * 2 * REFERENCE_BYTES; + long entryBytes = entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); + reserveContainerMemory(MAP_OBJECT_BYTES + tableBytes + entryBytes); + } + + public void reserveObjectArrayMemory(int numElements) { + reserveContainerMemory(ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES); + } + + public void reserveContainerMemory(long bytes) { + if (bytes < 0) { + throwNegativeContainerMemory(bytes); + } + long remaining = remainingContainerMemoryBytes; + if (bytes > remaining) { + throwContainerMemoryExceeded(bytes, remaining); + } + remainingContainerMemoryBytes = remaining - 
bytes; + } + + private void throwNegativeContainerMemory(long bytes) { + throw new InsecureException( + "Estimated container memory must be non-negative, but got " + bytes + " bytes."); + } + + private void throwContainerMemoryExceeded(long bytes, long remaining) { + throw new InsecureException( + "Estimated container memory request " + + bytes + + " bytes exceeds maxContainerMemoryBytes remaining budget " + + remaining + + " bytes out of effective limit " + + containerMemoryLimitBytes + + " bytes. If the data is trusted, increase ForyBuilder#withMaxContainerMemoryBytes."); + } + /** Returns the generics stack shared by the owning runtime. */ public Generics getGenerics() { return generics; diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index 94dd961889..2f19f0981a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -77,6 +77,7 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET; private static final int FLOAT_ARRAY_OFFSET; private static final int DOUBLE_ARRAY_OFFSET; + private static final int OBJECT_ARRAY_INDEX_SCALE; // GraalVM native-image recognizes arrayBaseOffset only when the call stores directly into the // target static field. 
Keep these assignments in this shape so native images recompute heap array @@ -91,6 +92,7 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = 0; FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; + OBJECT_ARRAY_INDEX_SCALE = 4; } else { BOOLEAN_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(boolean[].class); BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); @@ -100,6 +102,7 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); FLOAT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + OBJECT_ARRAY_INDEX_SCALE = UNSAFE.arrayIndexScale(Object[].class); } } @@ -4183,6 +4186,10 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } + public static int objectArrayIndexScale() { + return OBJECT_ARRAY_INDEX_SCALE > 0 ? OBJECT_ARRAY_INDEX_SCALE : 4; + } + /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 9fe08fdfb5..52237e7082 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -51,6 +51,19 @@ private static void throwInvalidObjectArraySize(int size) { throw new DeserializationException("Object array size must be non-negative: " + size); } + private static int readObjectArraySize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = buffer.readVarUInt32Small7(); + // Keep this as direct primitive branches. Object-array reads allocate immediately; using + // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. 
+ if (numElements < 0) { + throwInvalidObjectArraySize(numElements); + } + readContext.reserveObjectArrayMemory(numElements); + buffer.checkReadableBytes(numElements); + return numElements; + } + /** * Returns the object-array serializer for {@code cls}. * @@ -128,14 +141,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -213,14 +219,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -654,14 +653,7 @@ public void write(WriteContext writeContext, Object[] value) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. 
Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 35eeca550a..b5853d433b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -343,18 +343,18 @@ private static Object readNotNull( if (array == null) { return null; } - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_LIST_TO_LIST) { return readListBodyAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); } if (readMode == READ_ARRAY_TO_LIST) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_ARRAY_TO_ARRAY) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } throw new IllegalStateException("Unexpected compatible read mode " + readMode); } @@ -621,7 +621,7 @@ private static Object readListBodyAsListTarget( validateElementCount(numElements); if (numElements == 0) { Object array = readListPrimitiveElements(buffer, 0, arrayTypeId, elementTypeId, false); - return materializeTarget(array, 
arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } int flags = buffer.readByte(); boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; @@ -654,11 +654,11 @@ private static Object readListBodyAsListTarget( throw new DeserializationException( "Cannot read null peer list element into local list field"); } - return readNullableListBoxedElements(buffer, numElements, arrayTypeId, elementTypeId); + return readNullableListBoxedElements(readContext, numElements, arrayTypeId, elementTypeId); } Object array = readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, false); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } private static Object readDenseArrayBody(ReadContext readContext, int arrayTypeId) { @@ -976,8 +976,11 @@ private static void readNonNullListElement(MemoryBuffer buffer) { } private static List readNullableListBoxedElements( - MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { - buffer.checkReadableBytes(minReadablePrimitiveListBytes(numElements, elementTypeId, true)); + ReadContext readContext, int numElements, int arrayTypeId, int elementTypeId) { + MemoryBuffer buffer = readContext.getBuffer(); + int bodyBytes = minReadablePrimitiveListBytes(numElements, elementTypeId, true); + readContext.reserveCollectionMemory(numElements); + buffer.checkReadableBytes(bodyBytes); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { byte headFlag = buffer.readByte(); @@ -1043,7 +1046,8 @@ private static Object readBoxedListElement( } } - private static Object materializeTarget(Object array, int arrayTypeId, Class targetType) { + private static Object materializeTarget( + ReadContext readContext, Object array, int arrayTypeId, Class targetType) { if (targetType.isArray()) { return array; } @@ -1058,7 +1062,7 @@ private static Object 
materializeTarget(Object array, int arrayTypeId, Class return primitiveList; } if (targetType.isAssignableFrom(ArrayList.class)) { - return materializeBoxedList(array, arrayTypeId); + return materializeBoxedList(readContext, array, arrayTypeId); } throw new DeserializationException("Unsupported compatible list/array target " + targetType); } @@ -1172,8 +1176,10 @@ private static boolean canMaterializePrimitiveListTarget(Class targetType, in } } - private static List materializeBoxedList(Object array, int arrayTypeId) { + private static List materializeBoxedList( + ReadContext readContext, Object array, int arrayTypeId) { int size = java.lang.reflect.Array.getLength(array); + readContext.reserveCollectionMemory(size); ArrayList list = new ArrayList<>(size); switch (arrayTypeId) { case Types.BOOL_ARRAY: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index f7840349ef..f58d56d08b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -249,7 +249,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -295,7 +295,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = 
readCollectionSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -403,7 +403,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 3915b5d888..b6151f45a9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -461,7 +461,7 @@ public T read(ReadContext readContext) { */ public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readCollectionSize(buffer); + numElements = readCollectionSize(readContext); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -560,9 +560,11 @@ protected void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readCollectionSize(MemoryBuffer buffer) { + protected final int readCollectionSize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); + readContext.reserveCollectionMemory(numElements); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index a81c38298f..2456377485 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -127,7 +127,7 @@ public ArrayListSerializer(TypeResolver typeResolver) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -189,7 +189,7 @@ public List read(ReadContext readContext) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -205,7 +205,7 @@ public HashSetSerializer(TypeResolver typeResolver) { @Override public HashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); HashSet hashSet = new HashSet(numElements); readContext.reference(hashSet); @@ -221,7 +221,7 @@ public LinkedHashSetSerializer(TypeResolver typeResolver) { @Override public LinkedHashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); LinkedHashSet hashSet = new LinkedHashSet(numElements); readContext.reference(hashSet); @@ -270,7 +270,7 @@ public Collection 
onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); T collection; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); @@ -335,7 +335,7 @@ public void write(WriteContext writeContext, List value) { @Override public List read(ReadContext readContext) { if (config.isXlang()) { - int numElements = readCollectionSize(readContext.getBuffer()); + int numElements = readCollectionSize(readContext); if (numElements != 0) { throw new DeserializationException( "Empty list body must have zero elements but got " + numElements); @@ -356,7 +356,7 @@ public CopyOnWriteArrayListSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -390,7 +390,7 @@ public CopyOnWriteArraySetSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -542,7 +542,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ConcurrentSkipListSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (config.isXlang()) { ConcurrentSkipListSet skipListSet = new ConcurrentSkipListSet(); @@ -726,7 +726,7 @@ public VectorSerializer(TypeResolver typeResolver, Class cls) { 
@Override public Vector newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Vector vector = new Vector<>(numElements); readContext.reference(vector); @@ -743,7 +743,7 @@ public ArrayDequeSerializer(TypeResolver typeResolver, Class cls) { @Override public ArrayDeque newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayDeque deque = new ArrayDeque(numElements); readContext.reference(deque); @@ -786,9 +786,9 @@ public void write(WriteContext writeContext, EnumSet object) { public EnumSet read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); Class elemClass = typeResolver.readTypeInfo(readContext).getType(); + int length = readCollectionSize(readContext); EnumSet object = EnumSet.noneOf(elemClass); Serializer elemSerializer = typeResolver.getSerializer(elemClass); - int length = readCollectionSize(buffer); for (int i = 0; i < length; i++) { object.add(elemSerializer.read(readContext)); } @@ -863,7 +863,7 @@ public Collection newCollection(CopyContext copyContext, Collection collection) public PriorityQueue newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); PriorityQueue queue = new PriorityQueue(comparator); @@ -923,10 +923,11 @@ public CollectionSnapshot onCollectionWrite( @Override public ArrayBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = 
readCollectionSize(readContext); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + readContext.reserveCollectionCapacity(numElements, capacity); buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); @@ -990,10 +991,12 @@ public CollectionSnapshot onCollectionWrite( @Override public LinkedBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + // LinkedBlockingQueue capacity is a logical bound, not preallocated backing storage. The + // current node storage is already charged by readCollectionSize(numElements). LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -1130,7 +1133,7 @@ public XlangListDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public List newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); @@ -1146,7 +1149,7 @@ public XlangSetDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public Set newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); HashSet set = new HashSet(numElements); readContext.reference(set); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index c28aa04561..a5aba71aaa 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -94,7 +94,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -127,7 +127,7 @@ public RegularImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer(numElements); } @@ -161,7 +161,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -203,7 +203,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedCollectionContainer(comparator, numElements); @@ -236,7 +236,7 @@ public 
GuavaMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); return new MapContainer(numElements); } @@ -264,7 +264,7 @@ public T onMapRead(Map map) { @Override public T read(ReadContext readContext) { - int size = readMapSize(readContext.getBuffer()); + int size = readMapSize(readContext); Map map = new HashMap(); readElements(readContext, size, map); return xnewInstance(map); @@ -574,7 +574,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedMapContainer<>(comparator, numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index cd69f2b6cf..7a9f9f017d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -125,7 +125,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -186,7 +186,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { 
@Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -247,7 +247,7 @@ public ImmutableMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new JDKImmutableMapContainer(numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 334bd8a35c..d4c98e2ad5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -957,7 +957,7 @@ public void onMapWriteFinish(Map map) {} */ public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readMapSize(buffer); + numElements = readMapSize(readContext); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -1026,12 +1026,14 @@ public void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readMapSize(MemoryBuffer buffer) { + protected final int readMapSize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); if (numElements > Integer.MAX_VALUE / 2) { throwInvalidMapBodySize(numElements); } + readContext.reserveMapMemory(numElements); buffer.checkReadableBytes(numElements << 1); 
return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 9b3825495a..91f9ba2d06 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -86,7 +86,7 @@ public HashMapSerializer(TypeResolver typeResolver) { @Override public HashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); HashMap hashMap = new HashMap(numElements); readContext.reference(hashMap); @@ -107,7 +107,7 @@ public LinkedHashMapSerializer(TypeResolver typeResolver) { @Override public LinkedHashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); LinkedHashMap hashMap = new LinkedHashMap(numElements); readContext.reference(hashMap); @@ -146,7 +146,7 @@ public LazyMapSerializer(TypeResolver typeResolver) { @Override public LazyMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); LazyMap map = new LazyMap(numElements); readContext.reference(map); @@ -200,7 +200,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - setNumElements(readMapSize(buffer)); + setNumElements(readMapSize(readContext)); T map; Comparator comparator = config.isXlang() ? 
null : (Comparator) readContext.readRef(); if (type == TreeMap.class) { @@ -322,7 +322,7 @@ public ConcurrentHashMapSerializer(TypeResolver typeResolver, Class keyType = typeResolver.readTypeInfo(readContext).getType(); EnumMap map = new EnumMap(keyType); readContext.reference(map); @@ -619,7 +619,7 @@ public Object onMapCopy(Map map) { public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); HashMap map = new HashMap<>(numElements); readContext.reference(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java index f11f4b79a3..c011c91277 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java @@ -158,7 +158,7 @@ public List read(ReadContext readContext) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); diff --git a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java index 7fbf021ff5..27cbd3643c 100644 --- a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java @@ -81,6 +81,7 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET = 0; private static final int FLOAT_ARRAY_OFFSET = 0; private static final int DOUBLE_ARRAY_OFFSET = 0; + private static 
final int OBJECT_ARRAY_INDEX_SCALE = 4; private static final VarHandle BYTE_ARRAY_CHAR = MethodHandles.byteArrayViewVarHandle(char[].class, NATIVE_ORDER); private static final VarHandle BYTE_ARRAY_SHORT = @@ -3922,6 +3923,10 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } + public static int objectArrayIndexScale() { + return OBJECT_ARRAY_INDEX_SCALE; + } + /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java index 63e3ffcdc1..94c0e893b8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java @@ -344,7 +344,7 @@ public static void withWriteContext( public static T withReadContext( Fory fory, MemoryBuffer buffer, Function action) { ReadContext context = (ReadContext) ReflectionUtils.getObjectFieldValue(fory, "readContext"); - context.prepare(buffer, null, false); + context.prepare(buffer, null, false, buffer.remaining(), false); try { return action.apply(context); } finally { diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java index 71f73582c2..3b379c7b3a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java @@ -53,7 +53,7 @@ public void testForyStructInput(boolean compressNumber) throws IOException { buffer.writeFloat32(4.1f); buffer.writeFloat64(4.2); new StringSerializer(fory.getConfig()).writeString(buffer, "abc"); - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try 
(MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java index a04a0e9f3c..f5b69b8b13 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java @@ -46,7 +46,7 @@ public void testForyStructOutput() throws IOException { output.writeChars("abc"); output.writeUTF("abc"); } - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try (MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index 0b94da34cb..157b3950f5 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -360,7 +360,7 @@ public void testRemoteTypeDefChecksTypeChecker() { ReadContext readContext = reader.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); buffer.writeVarUInt32(0); typeDef.writeTypeDef(buffer); buffer.readerIndex(0); @@ -473,7 +473,7 @@ public void testExactLocalEnumTypeDefBypassesLimit() { ReadContext readContext = fory.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, 
null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); buffer.writeVarUInt32(0); exact.writeTypeDef(buffer); buffer.readerIndex(0); @@ -792,7 +792,7 @@ public void testWriteClassName() { } finally { fory.getWriteContext().reset(); } - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try { Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index b0fb46f0f4..752eb77686 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -373,11 +373,15 @@ private static Object readPrimitiveArrayBody( MemoryBuffer control = MemoryBuffer.newHeapBuffer(1); control.writeBoolean(false); readContext.prepare( - control, Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), true); + control, + Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), + true, + control.remaining(), + false); } else { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); } return fory.getSerializer(arrayType).read(readContext); } @@ -387,8 +391,8 @@ private static Object readTruncatedPrimitiveArrayBody( ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare( - MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())), null, false); + MemoryBuffer truncated = 
MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + readContext.prepare(truncated, null, false, truncated.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } @@ -396,7 +400,7 @@ private static Object readPrimitiveArrayRawBody(Fory fory, Class arrayType) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } @@ -404,7 +408,7 @@ private static Object readObjectArrayBody(Fory fory, Class arrayType, int num ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(numElements); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java index cd3d31dcac..492bee83f8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java @@ -33,6 +33,7 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.TestUtils; import org.apache.fory.config.Language; +import org.apache.fory.context.ReadContext; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.serializer.collection.UnmodifiableSerializersTest; @@ -138,14 +139,21 @@ public void testWriteCompatibleBasic() throws Exception { public void testNullableListBodyBounds() throws Exception { Method method = CompatibleCollectionArrayReader.class.getDeclaredMethod( - 
"readNullableListBoxedElements", MemoryBuffer.class, int.class, int.class, int.class); + "readNullableListBoxedElements", ReadContext.class, int.class, int.class, int.class); method.setAccessible(true); MemoryBuffer buffer = MemoryUtils.buffer(0); - InvocationTargetException exception = - Assert.expectThrows( - InvocationTargetException.class, - () -> method.invoke(null, buffer, 1024, Types.INT32_ARRAY, Types.INT32)); - Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + Fory fory = builder().build(); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + InvocationTargetException exception = + Assert.expectThrows( + InvocationTargetException.class, + () -> method.invoke(null, readContext, 1024, Types.INT32_ARRAY, Types.INT32)); + Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + } finally { + readContext.reset(); + } } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java new file mode 100644 index 0000000000..09b73c25d8 --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.fory.Fory; +import org.apache.fory.ForyTestBase; +import org.apache.fory.collection.Int32List; +import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.InsecureException; +import org.apache.fory.io.ForyInputStream; +import org.apache.fory.memory.MemoryBuffer; +import org.testng.annotations.Test; + +public class ContainerMemoryBudgetTest extends ForyTestBase { + private static final long KNOWN_ROOT_MULTIPLIER = 8L; + private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; + private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; + private static final long COLLECTION_OBJECT_BYTES = 24L; + private static final long MAP_OBJECT_BYTES = 48L; + private static final long ARRAY_HEADER_BYTES = 16L; + private static final long MAP_ENTRY_BYTES = 32L; + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + + @Test + public void testConfigValidation() { + assertEquals(newFory(-1).getConfig().maxContainerMemoryBytes(), -1); + assertEquals(newFory(123).getConfig().maxContainerMemoryBytes(), 123); + assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(0)); + 
assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(-2)); + } + + @Test + public void testKnownAutoBudget() { + Fory fory = newFory(-1); + ReadContext readContext = prepareContext(fory, 17, false); + try { + long budget = knownAutoBytes(17); + readContext.reserveContainerMemory(budget); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void testStreamAutoBudget() { + Fory fory = newFory(-1); + ReadContext readContext = prepareContext(fory, 17, true); + try { + readContext.reserveContainerMemory(STREAM_ROOT_BYTES); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + + StreamPayload payload = findStreamPayload(); + assertThrows(InsecureException.class, () -> newFory(-1).deserialize(payload.bytes)); + Object copy = + newFory(-1).deserialize(new ForyInputStream(new ByteArrayInputStream(payload.bytes), 1)); + assertEquals(copy, payload.value); + } + + @Test + public void testExplicitBudgetWins() { + Fory fory = newFory(7); + ReadContext readContext = prepareContext(fory, 1024 * 1024, false); + try { + readContext.reserveContainerMemory(7); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void testNestedEmptyFixedCost() { + List value = emptyLists(1); + byte[] bytes = newFory(-1).serialize(value); + + assertThrows(InsecureException.class, () -> newFory(collectionBytes(1)).deserialize(bytes)); + assertEquals(newFory(collectionBytes(1) + collectionBytes(0)).deserialize(bytes), value); + } + + @Test + public void testSiblingBudgetIsCumulative() { + List value = nullLists(2, 64); + byte[] bytes = newFory(-1).serialize(value); + long firstChildOnly = collectionBytes(2) + collectionBytes(64); + + assertThrows(InsecureException.class, () -> 
newFory(firstChildOnly).deserialize(bytes)); + assertEquals(newFory(firstChildOnly + collectionBytes(64)).deserialize(bytes), value); + } + + @Test + public void testMapBudgetAndOverflow() { + Fory fory = newFory(mapBytes(1) - 1); + ReadContext readContext = prepareContext(fory, 8, false); + try { + assertThrows(InsecureException.class, () -> readContext.reserveMapMemory(1)); + } finally { + readContext.reset(); + } + + Fory exactFory = newFory(mapBytes(1)); + ReadContext exactContext = prepareContext(exactFory, 8, false); + try { + exactContext.reserveMapMemory(1); + assertThrows(InsecureException.class, () -> exactContext.reserveContainerMemory(1)); + } finally { + exactContext.reset(); + } + + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(Integer.MAX_VALUE); + buffer = trimBuffer(buffer); + Fory reader = newFory(STREAM_ROOT_BYTES); + ReadContext mapContext = reader.getReadContext(); + mapContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + assertThrows( + DeserializationException.class, + () -> reader.getSerializer(HashMap.class).read(mapContext)); + } finally { + mapContext.reset(); + } + } + + @Test + public void testObjectArrayBudget() { + Fory lowFory = newFory(objectArrayBytes(0) - 1); + ReadContext lowContext = lowFory.getReadContext(); + MemoryBuffer lowBuffer = objectArraySizeBuffer(0); + lowContext.prepare(lowBuffer, null, false, lowBuffer.remaining(), false); + try { + assertThrows( + InsecureException.class, () -> lowFory.getSerializer(Object[].class).read(lowContext)); + } finally { + lowContext.reset(); + } + + Fory exactFory = newFory(objectArrayBytes(0)); + ReadContext exactContext = exactFory.getReadContext(); + MemoryBuffer exactBuffer = objectArraySizeBuffer(0); + exactContext.prepare(exactBuffer, null, false, exactBuffer.remaining(), false); + try { + Object[] array = (Object[]) exactFory.getSerializer(Object[].class).read(exactContext); + assertEquals(array.length, 0); + } finally { + 
exactContext.reset(); + } + + Fory slotFory = newFory(objectArrayBytes(2) - 1); + ReadContext slotContext = slotFory.getReadContext(); + MemoryBuffer slotBuffer = objectArraySizeBuffer(2); + slotContext.prepare(slotBuffer, null, false, slotBuffer.remaining(), false); + try { + assertThrows( + InsecureException.class, () -> slotFory.getSerializer(Object[].class).read(slotContext)); + } finally { + slotContext.reset(); + } + } + + @Test + public void testScalarOwnersSkipBudget() { + Fory fory = newFory(1); + assertEquals(fory.deserialize(fory.serialize("container budget")), "container budget"); + + byte[] bytes = new byte[] {1, 2, 3}; + assertTrue(Arrays.equals((byte[]) fory.deserialize(fory.serialize(bytes)), bytes)); + + int[] ints = new int[] {4, 5, 6}; + assertTrue(Arrays.equals((int[]) fory.deserialize(fory.serialize(ints)), ints)); + + Int32List denseList = new Int32List(new int[] {7, 8, 9}); + assertEquals(fory.deserialize(fory.serialize(denseList)), denseList); + } + + @Test + public void testTruncatedCollectionStillFails() { + Fory fory = newFory(collectionBytes(3)); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(3); + buffer.writeByte(0); + buffer.writeByte(0); + buffer = trimBuffer(buffer); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + assertThrows( + IndexOutOfBoundsException.class, + () -> fory.getSerializer(ArrayList.class).read(readContext)); + } finally { + readContext.reset(); + } + } + + private static Fory newFory(long maxContainerMemoryBytes) { + return builder().withMaxContainerMemoryBytes(maxContainerMemoryBytes).build(); + } + + private static ReadContext prepareContext( + Fory fory, int rootInputBytes, boolean unknownLengthInput) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(0); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); + 
return readContext; + } + + private static long collectionBytes(int numElements) { + return COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long mapBytes(int numElements) { + long entries = numElements; + return MAP_OBJECT_BYTES + + entries * 2 * REFERENCE_BYTES + + entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); + } + + private static long objectArrayBytes(int numElements) { + return ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long knownAutoBytes(int inputBytes) { + return inputBytes * KNOWN_ROOT_MULTIPLIER + KNOWN_ROOT_SLACK_BYTES; + } + + private static List emptyLists(int numElements) { + List root = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + root.add(new ArrayList<>()); + } + return root; + } + + private static List nullLists(int siblings, int childElements) { + List root = new ArrayList<>(siblings); + for (int i = 0; i < siblings; i++) { + List child = new ArrayList<>(childElements); + for (int j = 0; j < childElements; j++) { + child.add(null); + } + root.add(child); + } + return root; + } + + private static List emptyMaps(int numElements) { + List root = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + root.add(new HashMap<>()); + } + return root; + } + + private static MemoryBuffer objectArraySizeBuffer(int numElements) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(numElements); + return trimBuffer(buffer); + } + + private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { + return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + } + + private static StreamPayload findStreamPayload() { + Fory writer = newFory(-1); + int numElements = 128; + while (numElements <= 1 << 20) { + List value = emptyMaps(numElements); + byte[] bytes = writer.serialize(value); + long estimatedMemory = collectionBytes(numElements) + (long) numElements * mapBytes(0); + if 
(estimatedMemory > knownAutoBytes(bytes.length) && estimatedMemory < STREAM_ROOT_BYTES) { + return new StreamPayload(value, bytes); + } + numElements <<= 1; + } + throw new AssertionError("Unable to build compact stream-budget payload"); + } + + private static final class StreamPayload { + final List value; + final byte[] bytes; + + StreamPayload(List value, byte[] bytes) { + this.value = value; + this.bytes = bytes; + } + } +} diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java index 2b0505e077..2ec74b12a0 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java @@ -179,7 +179,7 @@ public void testThrowableReadsMainWireOrderWithCyclicCause() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); readContext.preserveRefId(); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(RuntimeException.class); @@ -251,7 +251,7 @@ public void testThrowableRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(CustomException.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java index bc945d60ff..4c7d650331 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java +++ 
b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java @@ -307,7 +307,7 @@ private static Object readPrimitiveListBody(Fory fory, Class listType, int he MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(headerSize); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(listType).read(readContext); } @@ -315,7 +315,7 @@ private static Object readPrimitiveListRawBody(Fory fory, Class listType) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(listType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java index afc536be28..cfc4cfd60f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java @@ -143,7 +143,7 @@ public void testChildCollectionRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(ChildArrayList.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java index ebff4537ec..cabc02f65a 
100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java @@ -1218,7 +1218,7 @@ public void testBitSetRejectsNegativeBinary() { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); Assert.expectThrows( DeserializationException.class, () -> fory.getSerializer(BitSet.class).read(readContext)); } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index bccde267e8..f95c38d72f 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -531,6 +531,13 @@ export class WriteContext { export class ReadContext { private static readonly MIN_REMOTE_TYPE_META_LIMIT = 8192; + private static readonly KNOWN_ROOT_BUDGET_MULTIPLIER = 8; + private static readonly KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024; + private static readonly COLLECTION_OBJECT_BYTES = 24; + private static readonly MAP_OBJECT_BYTES = 48; + private static readonly ARRAY_HEADER_BYTES = 16; + private static readonly MAP_ENTRY_BYTES = 32; + private static readonly REFERENCE_BYTES = 4; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -548,6 +555,9 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; + private readonly maxContainerMemoryBytes: number; + private effectiveContainerMemoryBytes = 0; + private remainingContainerMemoryBytes = 0; private remoteSchemaVersionsByType: Map | undefined = undefined; @@ -559,6 +569,7 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 
50; + this.maxContainerMemoryBytes = config.maxContainerMemoryBytes; } reset(bytes: Uint8Array) { @@ -567,6 +578,71 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; + this.effectiveContainerMemoryBytes = this.maxContainerMemoryBytes > 0 + ? this.maxContainerMemoryBytes + : bytes.byteLength * ReadContext.KNOWN_ROOT_BUDGET_MULTIPLIER + + ReadContext.KNOWN_ROOT_BUDGET_SLACK_BYTES; + this.remainingContainerMemoryBytes = this.effectiveContainerMemoryBytes; + } + + reserveCollectionMemory(numElements: number) { + const bytes + = ReadContext.COLLECTION_OBJECT_BYTES + + numElements * ReadContext.REFERENCE_BYTES; + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveMapMemory(numElements: number) { + const bytes = ReadContext.MAP_OBJECT_BYTES + + numElements + * ( + ReadContext.REFERENCE_BYTES * 2 + + ReadContext.MAP_ENTRY_BYTES + + ReadContext.REFERENCE_BYTES * 3 + ); + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveTypedArrayMemory(numElements: number, elementBytes: number) { + const bytes = ReadContext.ARRAY_HEADER_BYTES + numElements * elementBytes; + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveContainerMemory(bytes: number) { + if (!Number.isSafeInteger(bytes) || bytes < 0) { + this.throwContainerMemoryOverflow(bytes); + } + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + private throwContainerMemoryOverflow(bytes: number): never { + 
throw new Error( + `maxContainerMemoryBytes overflow: requested ${bytes} estimated container bytes`, + ); + } + + private throwContainerBudgetExceeded(bytes: number): never { + throw new Error( + `maxContainerMemoryBytes exceeded: requested ${bytes} estimated container bytes, ` + + `${this.remainingContainerMemoryBytes} remaining, effective limit ` + + `${this.effectiveContainerMemoryBytes}`, + ); } isCompatible() { diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 05b424b1cc..da22e247ba 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -38,6 +38,7 @@ const DEFAULT_MAX_TYPE_FIELDS = 512 as const; const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; +const DEFAULT_MAX_CONTAINER_MEMORY_BYTES = -1 as const; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -105,10 +106,21 @@ export default class Fory { `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } + const maxContainerMemoryBytes + = config?.maxContainerMemoryBytes ?? 
DEFAULT_MAX_CONTAINER_MEMORY_BYTES; + if ( + !Number.isSafeInteger(maxContainerMemoryBytes) + || (maxContainerMemoryBytes !== -1 && maxContainerMemoryBytes <= 0) + ) { + throw new Error( + `maxContainerMemoryBytes must be -1 or a positive safe integer but got ${maxContainerMemoryBytes}`, + ); + } return { ref: Boolean(config?.ref), useSliceString: Boolean(config?.useSliceString), maxDepth: config?.maxDepth, + maxContainerMemoryBytes, maxTypeFields, maxTypeMetaBytes, maxSchemaVersionsPerType, diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index 74344fc00f..620543198b 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -100,6 +100,30 @@ function compatibleArrayCollectionExpr( } } +function compatibleArrayElementBytes(elementTypeId: number): number { + switch (elementTypeId) { + case TypeId.BOOL: + case TypeId.INT8: + case TypeId.UINT8: + return 1; + case TypeId.INT16: + case TypeId.UINT16: + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + return 2; + case TypeId.INT32: + case TypeId.UINT32: + case TypeId.FLOAT32: + return 4; + case TypeId.INT64: + case TypeId.UINT64: + case TypeId.FLOAT64: + return 8; + default: + return 4; + } +} + function compatibleArrayPutAccessor( elementTypeId: number, result: string, @@ -245,10 +269,12 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveCollectionMemory(len); if (len === 0) { return createCollection(len); } const flags = this.readContext.reader.readUint8(); + this.readContext.reader.checkReadableBytes(len); const result = createCollection(len); // IMPORTANT: collection readers must obey the ref/null bits written on the // wire, not local TypeScript metadata that may imply a different ref @@ -456,6 +482,9 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const newCollection = 
compatibleListToArray ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); + const reserveMemory = compatibleListToArray + ? `${readContextName}.reserveTypedArrayMemory(${len}, ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` + : `${readContextName}.reserveCollectionMemory(${len});`; const putAccessor = (item: string, index: string) => compatibleListToArray ? compatibleArrayPutAccessor( @@ -495,6 +524,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`; return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; + ${reserveMemory} let ${flags} = 0; if (${len} > 0) { ${flags} = ${this.builder.reader.readUint8()}; diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index c447691d40..dd24c45c3f 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -298,6 +298,7 @@ class MapAnySerializer { read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveMapMemory(count); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -527,6 +528,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { return ` let ${count} = ${this.builder.reader.readVarUint32Small7()}; + ${readContextName}.reserveMapMemory(${count}); const ${result} = new Map(); if (${refState}) { ${this.builder.referenceResolver.reference(result)} diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index bfb4a942fb..f08ca07e8d 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -292,6 +292,7 @@ export interface Config { ref: boolean; useSliceString: boolean; maxDepth?: number; + maxContainerMemoryBytes: number; maxTypeFields: number; maxTypeMetaBytes: number; 
maxSchemaVersionsPerType: number; diff --git a/javascript/test/containerMemoryBudget.test.ts b/javascript/test/containerMemoryBudget.test.ts new file mode 100644 index 0000000000..77907ea3e3 --- /dev/null +++ b/javascript/test/containerMemoryBudget.test.ts @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import Fory, { Type } from '../packages/core/index'; +import { describe, expect, test } from '@jest/globals'; + +const KNOWN_SLACK_BYTES = 64 * 1024; + +function serializeAny(value: unknown) { + return new Fory({ compatible: false, ref: true }).serialize(value); +} + +function deserializeAny(bytes: Uint8Array, maxContainerMemoryBytes: number) { + return new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes, + }).deserialize(bytes); +} + +describe('container memory budget', () => { + test('uses known length auto budget', () => { + const inputBytes = 17; + const fory = new Fory({ compatible: false }); + const budget = inputBytes * 8 + KNOWN_SLACK_BYTES; + + fory.readContext.reset(new Uint8Array(inputBytes)); + expect(() => fory.readContext.reserveContainerMemory(budget)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( + /maxContainerMemoryBytes/, + ); + }); + + test('validates explicit config', () => { + expect(() => new Fory({ maxContainerMemoryBytes: 0 })).toThrow( + /maxContainerMemoryBytes/, + ); + expect(() => new Fory({ maxContainerMemoryBytes: -2 })).toThrow( + /maxContainerMemoryBytes/, + ); + + const fory = new Fory({ maxContainerMemoryBytes: 24 }); + fory.readContext.reset(new Uint8Array(1)); + expect(() => fory.readContext.reserveCollectionMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( + /maxContainerMemoryBytes/, + ); + }); + + test('charges nested empty containers', () => { + const typeInfo = Type.struct('budget.nested.empty', { + values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ values: [[]] }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 52, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + 
maxContainerMemoryBytes: 51, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); + }); + + test('charges sibling containers cumulatively', () => { + const typeInfo = Type.struct('budget.sibling.empty', { + values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: [[], [], []], + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 108, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 107, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + values: [[], [], []], + }); + }); + + test('charges map entries', () => { + const bytes = serializeAny(new Map([[1, 2]])); + + expect(() => deserializeAny(bytes, 99)).toThrow(/maxContainerMemoryBytes/); + expect(deserializeAny(bytes, 100)).toEqual(new Map([[1, 2]])); + }); + + test('charges generated containers', () => { + const typeInfo = Type.struct('budget.generated', { + list: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), + set: Type.set(Type.string()).setId(2), + map: Type.map(Type.string(), Type.int32({ encoding: 'fixed' })).setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + list: [1], + set: new Set(['a']), + map: new Map([['k', 1]]), + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 156, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 155, + }).register(typeInfo); + + expect(() => 
failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + list: [1], + set: new Set(['a']), + map: new Map([['k', 1]]), + }); + }); + + test('charges compatible typed arrays', () => { + const writerType = Type.struct(9010, { + values: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), + }); + const readerType = Type.struct(9010, { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: true }); + const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); + const passingReader = new Fory({ + compatible: true, + maxContainerMemoryBytes: 28, + }).register(readerType); + const failingReader = new Fory({ + compatible: true, + maxContainerMemoryBytes: 27, + }).register(readerType); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([ + 1, + 2, + 3, + ]); + }); + + test('skips scalar dense owners', () => { + const typeInfo = Type.struct('budget.skipped', { + text: Type.string().setId(1), + binary: Type.binary().setId(2), + values: Type.int32Array().setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + text: 'hello', + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 1, + }).register(typeInfo); + + expect(reader.deserialize(bytes)).toEqual({ + text: 'hello', + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + }); + + test('keeps byte checks', () => { + const typeInfo = Type.struct('budget.bytecheck', { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: new Int32Array([1, 2, 3]), + }); + const reader = 
new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 1024 * 1024, + }).register(typeInfo); + + expect(() => reader.deserialize(bytes.slice(0, bytes.length - 1))).toThrow(); + }); +}); diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt index e3b36d4785..be2980313f 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt @@ -21,7 +21,6 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.context.ReadContext import org.apache.fory.context.WriteContext -import org.apache.fory.exception.DeserializationException import org.apache.fory.resolver.TypeResolver import org.apache.fory.serializer.collection.CollectionLikeSerializer @@ -57,15 +56,8 @@ public class KotlinArrayDequeSerializer( } override fun newCollection(readContext: ReadContext): Collection { - val buffer = readContext.buffer - val numElements = buffer.readVarUInt32Small7() - if (numElements < 0) { - throw DeserializationException("Collection size must be non-negative: $numElements") - } + val numElements = readCollectionSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } return ArrayDequeBuilder(ArrayDeque(numElements)) } } diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt index 5684821ebe..1d17f39b91 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt @@ -20,8 +20,10 @@ package 
org.apache.fory.serializer.kotlin import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.kotlin.ForyKotlin import org.testng.Assert.assertEquals +import org.testng.Assert.fail import org.testng.annotations.Test class CollectionSerializerTest { @@ -33,6 +35,22 @@ class CollectionSerializerTest { assertEquals(arrayDeque, fory.deserialize(fory.serialize(arrayDeque))) } + @Test + fun testArrayDequeContainerMemoryBudget() { + val writer: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() + val reader: Fory = + ForyKotlin.builder() + .withXlang(false) + .requireClassRegistration(true) + .withMaxContainerMemoryBytes(23) + .build() + + try { + reader.deserialize(writer.serialize(ArrayDeque())) + fail("Expected container memory budget failure") + } catch (ignored: InsecureException) {} + } + @Test fun testSerializeArrayList() { val fory: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index e4819ba424..0d6f4d0591 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -124,6 +124,7 @@ class Fory: "strict", "buffer", "max_depth", + "max_container_memory_bytes", "field_nullable", "policy", ) @@ -139,6 +140,7 @@ def __init__( max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, + max_container_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -183,6 +185,9 @@ def __init__( max_average_schema_versions_per_type: Average remote metadata versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. `-1` means auto; positive values are explicit byte limits. + policy: Custom deserialization policy for security checks. 
When provided, it controls which types can be deserialized, overriding the default policy. **Strongly recommended** when strict=False to maintain security controls. @@ -213,6 +218,13 @@ def __init__( raise ValueError("max_schema_versions_per_type must be a positive integer") if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > (1 << 63) - 1 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") + self.max_container_memory_bytes = max_container_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -225,6 +237,7 @@ def __init__( max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_container_memory_bytes=max_container_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -559,6 +572,7 @@ def _deserialize( buffers=buffers, unsupported_objects=unsupported_objects, peer_out_of_band_enabled=peer_out_of_band_enabled, + root_input_bytes=buffer.size() - reader_index, ) return read_context.read_ref() diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 0183b26231..6dd5c5c4dc 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -466,10 +466,24 @@ cdef class ListSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i + cdef int64_t container_bytes if len_ == 0: + container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + 
read_context.remaining_container_memory_bytes = container_bytes list_ = PyList_New(0) return list_ + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -583,10 +597,24 @@ cdef class TupleSerializer(CollectionSerializer): cdef bint has_null cdef int8_t head_flag cdef int64_t i + cdef int64_t container_bytes if len_ == 0: + container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes tuple_ = PyTuple_New(0) return tuple_ + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -684,7 +712,7 @@ cdef class StringArraySerializer(ListSerializer): @cython.final cdef class SetSerializer(CollectionSerializer): cpdef read(self, ReadContext read_context): - cdef set instance = set() + cdef set instance cdef int32_t len_ cdef int8_t collect_flag cdef TypeInfo typeinfo @@ -701,11 +729,29 @@ cdef class SetSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i + cdef int64_t container_bytes - read_context.reference(instance) len_ = buffer.read_var_uint32() if len_ == 0: + container_bytes = 
read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes + instance = set() + read_context.reference(instance) return instance + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes + read_context.check_readable_bytes(len_) + instance = set() + read_context.reference(instance) collect_flag = buffer.read_int8() if (collect_flag & COLL_IS_SAME_TYPE) != 0: @@ -1048,9 +1094,23 @@ cdef class MapSerializer(Serializer): cdef int32_t ref_id cdef dict map_ cdef int8_t chunk_header = 0 + cdef int64_t container_bytes if size == 0: + container_bytes = read_context.remaining_container_memory_bytes - _MAP_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_MAP_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes + map_ = {} + elif size < 0: + read_context.reserve_map_memory_c(size) map_ = {} else: + container_bytes = _MAP_OBJECT_BYTES + size * (_MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES) + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(size) chunk_header = read_context.read_uint8() map_ = _PyDict_NewPresized(size) diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index d78673a6dc..c2e2e2a058 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -176,6 +176,9 @@ def _write_different_types(self, write_context, 
value, collect_flag=0): def read(self, read_context): length = read_context.read_var_uint32() + read_context.reserve_collection_memory(length) + if length != 0: + read_context.check_readable_bytes(length) collection_ = self.new_instance(read_context, self.type_) if length == 0: return collection_ @@ -455,6 +458,9 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() + read_context.reserve_map_memory(size) + if size != 0: + read_context.check_readable_bytes(size) map_ = {} ref_reader = read_context.ref_reader read_context.reference(map_) diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 702f09769c..a27084d466 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -30,6 +30,14 @@ STRING_TYPE_ID = TypeId.STRING SMALL_STRING_THRESHOLD = 16 cdef int32_t MAX_CACHED_META_STRINGS = 8192 cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 +cdef int64_t _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +cdef int64_t _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +cdef int64_t _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +cdef int64_t _COLLECTION_OBJECT_BYTES = 56 +cdef int64_t _MAP_OBJECT_BYTES = 64 +cdef int64_t _MAP_ENTRY_BYTES = 32 +cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) +cdef int64_t _MAX_CONTAINER_MEMORY_BYTES = 9223372036854775807 cdef inline uint64_t _mix64(uint64_t x): @@ -746,6 +754,9 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth + cdef public int64_t max_container_memory_bytes + cdef public int64_t container_memory_limit_bytes + cdef public int64_t remaining_container_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext meta_share_context @@ -766,6 +777,9 @@ cdef class ReadContext: self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self.max_container_memory_bytes = 
config.max_container_memory_bytes + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -783,12 +797,26 @@ cdef class ReadContext: buffers=None, unsupported_objects=None, bint peer_out_of_band_enabled=False, + int64_t root_input_bytes=-1, ): + cdef int64_t limit + if self.max_container_memory_bytes > 0: + limit = self.max_container_memory_bytes + elif buffer.has_input_stream(): + limit = _STREAM_ROOT_BUDGET_BYTES + else: + if root_input_bytes < 0: + root_input_bytes = buffer.size() - buffer.get_reader_index() + if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_container_memory_bytes auto budget overflow") + limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.c_buffer = buffer.c_buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled + self.container_memory_limit_bytes = limit + self.remaining_container_memory_bytes = limit self.depth = 0 cpdef inline reset(self): @@ -803,8 +831,61 @@ cdef class ReadContext: self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.depth = 0 + cdef inline void reserve_container_memory_c(self, int64_t num_bytes): + cdef int64_t used + if num_bytes < 0: + raise ValueError("Estimated container memory is negative") + if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: + raise ValueError("Estimated container memory overflow") + if num_bytes > 
self.remaining_container_memory_bytes: + used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + ) + self.remaining_container_memory_bytes -= num_bytes + + cdef inline void reserve_container_memory_fast(self, int64_t num_bytes): + cdef int64_t used + if num_bytes > self.remaining_container_memory_bytes: + used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + ) + self.remaining_container_memory_bytes -= num_bytes + + cpdef inline reserve_container_memory(self, int64_t num_bytes): + self.reserve_container_memory_c(num_bytes) + + cdef inline void reserve_collection_memory_c(self, int64_t num_elements): + if num_elements < 0: + raise ValueError("Container element count is negative") + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory_c(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) + + cpdef inline reserve_collection_memory(self, int64_t num_elements): + self.reserve_collection_memory_c(num_elements) + + cdef inline void reserve_map_memory_c(self, int64_t num_elements): + cdef int64_t bytes_per_entry + if num_elements < 0: + raise ValueError("Map entry count is negative") + bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + raise ValueError("Estimated container 
memory overflow") + self.reserve_container_memory_c(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + + cpdef inline reserve_map_memory(self, int64_t num_elements): + self.reserve_map_memory_c(num_elements) + cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 3abfb46e3d..a923731c4b 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -17,6 +17,8 @@ from __future__ import annotations +import struct + from pyfory.serialization import Config from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -37,6 +39,14 @@ FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL STRING_TYPE_ID = TypeId.STRING +_KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +_KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +_STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +_COLLECTION_OBJECT_BYTES = 56 +_MAP_OBJECT_BYTES = 64 +_MAP_ENTRY_BYTES = 32 +_REFERENCE_BYTES = struct.calcsize("P") +_MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 def _mix64(x: int) -> int: @@ -470,6 +480,9 @@ class ReadContext: "field_nullable", "policy", "max_depth", + "max_container_memory_bytes", + "container_memory_limit_bytes", + "remaining_container_memory_bytes", "ref_reader", "meta_string_reader", "meta_share_context", @@ -490,6 +503,9 @@ def __init__(self, config: Config, type_resolver): self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self.max_container_memory_bytes = config.max_container_memory_bytes + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -520,11 +536,26 @@ def prepare( buffers=None, unsupported_objects=None, peer_out_of_band_enabled=False, + 
root_input_bytes=None, ): + if self.max_container_memory_bytes > 0: + limit = self.max_container_memory_bytes + elif buffer.has_input_stream(): + limit = _STREAM_ROOT_BUDGET_BYTES + else: + if root_input_bytes is None: + root_input_bytes = buffer.size() - buffer.get_reader_index() + if root_input_bytes < 0: + raise ValueError("root input byte count is negative") + if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_container_memory_bytes auto budget overflow") + limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled + self.container_memory_limit_bytes = limit + self.remaining_container_memory_bytes = limit self.depth = 0 def reset(self): @@ -538,8 +569,40 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.depth = 0 + def reserve_container_memory(self, num_bytes): + if num_bytes < 0: + raise ValueError("Estimated container memory is negative") + if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: + raise ValueError("Estimated container memory overflow") + remaining = self.remaining_container_memory_bytes + if num_bytes > remaining: + used = self.container_memory_limit_bytes - remaining + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." 
+ ) + self.remaining_container_memory_bytes = remaining - num_bytes + + def reserve_collection_memory(self, num_elements): + if num_elements < 0: + raise ValueError("Container element count is negative") + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) + + def reserve_map_memory(self, num_elements): + if num_elements < 0: + raise ValueError("Map entry count is negative") + bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 899adcaf3c..2e4fede422 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -113,6 +113,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. 
@@ -129,6 +131,7 @@ cdef class Config: cdef public int32_t max_type_meta_bytes cdef public int32_t max_schema_versions_per_type cdef public int32_t max_average_schema_versions_per_type + cdef public int64_t max_container_memory_bytes cdef public bint field_nullable cdef public object policy cdef public object meta_compressor @@ -147,6 +150,7 @@ cdef class Config: max_type_meta_bytes, max_schema_versions_per_type, max_average_schema_versions_per_type, + max_container_memory_bytes, field_nullable, policy, meta_compressor, @@ -166,6 +170,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. 
@@ -185,10 +191,17 @@ cdef class Config: raise ValueError("max_schema_versions_per_type must be a positive integer") if max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type self.max_average_schema_versions_per_type = max_average_schema_versions_per_type + self.max_container_memory_bytes = max_container_memory_bytes self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor @@ -829,6 +842,7 @@ cdef class Fory: cdef public bint compatible cdef public bint field_nullable cdef public int32_t max_depth + cdef public int64_t max_container_memory_bytes cdef public object policy cdef public Config config cdef public TypeResolver type_resolver @@ -847,6 +861,7 @@ cdef class Fory: max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_container_memory_bytes=-1, policy=None, field_nullable=False, meta_compressor=None, @@ -865,6 +880,8 @@ cdef class Fory: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. 
meta_compressor: Optional typedef/meta compressor implementation. @@ -882,6 +899,13 @@ cdef class Fory: self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") + self.max_container_memory_bytes = max_container_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -894,6 +918,7 @@ cdef class Fory: max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_container_memory_bytes=max_container_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -1051,6 +1076,8 @@ cdef class Fory: cdef int32_t reader_index cdef uint8_t bitmap cdef bint peer_out_of_band_enabled + cdef int64_t root_input_bytes + cdef int64_t container_memory_limit if isinstance(buffer, bytes): buffer = Buffer(buffer) read_buffer = buffer @@ -1066,6 +1093,13 @@ cdef class Fory: raise ValueError("Out-of-band buffers are required by the root header") if not peer_out_of_band_enabled and buffers is not None: raise ValueError("Out-of-band buffers were provided for an in-band root payload") + if self.max_container_memory_bytes > 0: + container_memory_limit = self.max_container_memory_bytes + elif read_buffer.has_input_stream(): + container_memory_limit = _STREAM_ROOT_BUDGET_BYTES + else: + root_input_bytes = read_buffer.size() - reader_index + container_memory_limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. 
read_context.buffer = read_buffer @@ -1075,6 +1109,8 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled + read_context.container_memory_limit_bytes = container_memory_limit + read_context.remaining_container_memory_bytes = container_memory_limit read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index d3e43de30f..8ed4aa2255 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -933,6 +933,7 @@ def read(self, read_context): if dtype.kind == "O": length = read_context.read_varint32() _check_non_negative_size(length, "ndarray object") + read_context.reserve_collection_memory(length) read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) diff --git a/python/pyfory/tests/test_container_memory_budget.py b/python/pyfory/tests/test_container_memory_budget.py new file mode 100644 index 0000000000..09069d412b --- /dev/null +++ b/python/pyfory/tests/test_container_memory_budget.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import array +import struct + +import pytest + +import pyfory +from pyfory.serialization import Buffer +from pyfory.serializer import ListSerializer + +try: + import numpy as np +except ImportError: + np = None + + +KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +COLLECTION_OBJECT_BYTES = 56 +MAP_OBJECT_BYTES = 64 +MAP_ENTRY_BYTES = 32 +REFERENCE_BYTES = struct.calcsize("P") +MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 + + +class OneByteStream: + def __init__(self, data: bytes): + self._data = data + self._offset = 0 + + def read(self, size=-1): + if self._offset >= len(self._data): + return b"" + if size < 0: + size = len(self._data) - self._offset + if size == 0: + return b"" + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + return self._data[start : start + read_size] + + def readinto(self, buffer): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if len(view) == 0: + return 0 + read_size = min(1, len(view), len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + def recv_into(self, buffer, size=-1): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if size < 0 or size > len(view): + size = len(view) + if size == 0: + return 0 + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + +def collection_memory(num_elements): + return COLLECTION_OBJECT_BYTES + num_elements * REFERENCE_BYTES + + +def map_memory(num_entries): + return MAP_OBJECT_BYTES + num_entries * (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + + +def new_fory(limit=-1, *, xlang=True): + return pyfory.Fory( + xlang=xlang, + ref=True, + strict=False, + 
compatible=xlang, + max_container_memory_bytes=limit, + ) + + +def expect_budget(value, budget, *, xlang=True): + writer = new_fory(xlang=xlang) + data = writer.serialize(value) + with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): + new_fory(budget - 1, xlang=xlang).deserialize(data) + return new_fory(budget, xlang=xlang).deserialize(data) + + +def varuint_payload(value): + buffer = Buffer.allocate(16) + buffer.write_var_uint32(value) + return buffer.to_bytes(0, buffer.get_writer_index()) + + +def test_known_length_auto_budget(): + fory = new_fory(xlang=False) + root_input_bytes = 17 + try: + fory.read_context.prepare(Buffer(b"x" * root_input_bytes), root_input_bytes=root_input_bytes) + expected = root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES + assert fory.read_context.container_memory_limit_bytes == expected + fory.read_context.reserve_container_memory(expected) + with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): + fory.read_context.reserve_container_memory(1) + finally: + fory.reset_read() + + +def test_stream_auto_budget(): + fory = new_fory(xlang=False) + try: + buffer = Buffer.from_stream(OneByteStream(b"streamed")) + fory.read_context.prepare(buffer, root_input_bytes=1) + assert fory.read_context.container_memory_limit_bytes == STREAM_ROOT_BUDGET_BYTES + finally: + fory.reset_read() + + +def test_explicit_config_overrides_auto(): + value = [1] + budget = collection_memory(1) + assert expect_budget(value, budget) == value + + +def test_nested_empty_containers_charge_fixed_cost(): + value = [[]] + budget = collection_memory(1) + collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_sibling_nested_containers_are_cumulative(): + value = [[], [], []] + budget = collection_memory(3) + 3 * collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_map_entry_budget_and_overflow(): + value = {"a": 1} + assert 
expect_budget(value, map_memory(1)) == value + + fory = new_fory(xlang=False) + try: + fory.read_context.prepare(Buffer(b""), root_input_bytes=0) + max_map_entries = (MAX_CONTAINER_MEMORY_BYTES - MAP_OBJECT_BYTES) // (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + with pytest.raises(ValueError, match="Estimated container memory overflow"): + fory.read_context.reserve_map_memory(max_map_entries + 1) + finally: + fory.reset_read() + + +def test_object_reference_array_budget(): + value = (1, 2, 3) + assert expect_budget(value, collection_memory(3), xlang=False) == value + + +def test_object_ndarray_budget(): + if np is None: + pytest.skip("numpy is not installed") + value = np.array([1, 2, 3], dtype=object) + restored = expect_budget(value, collection_memory(3), xlang=False) + np.testing.assert_array_equal(restored, value) + + +def test_string_binary_and_dense_arrays_skip_budget(): + values = [ + "x" * 256, + b"x" * 256, + array.array("i", range(32)), + ] + if np is not None: + values.append(np.array(list(range(32)), dtype=np.int32)) + for value in values: + fory = new_fory(1, xlang=False) + restored = fory.deserialize(fory.serialize(value)) + if np is not None and isinstance(value, np.ndarray): + np.testing.assert_array_equal(restored, value) + else: + assert restored == value + + +def test_declared_large_list_still_needs_bytes(): + fory = new_fory(10_000_000, xlang=False) + serializer = ListSerializer(fory.type_resolver, list) + try: + fory.read_context.prepare(Buffer(varuint_payload(1000)), root_input_bytes=1) + with pytest.raises(Exception) as exc_info: + serializer.read(fory.read_context) + assert "Estimated container memory" not in str(exc_info.value) + finally: + fory.reset_read() + + +@pytest.mark.parametrize("limit", [0, -2, 1 << 63]) +def test_invalid_config(limit): + with pytest.raises(ValueError, match="max_container_memory_bytes"): + new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 167d67e72b..f5d003cd0f 100644 --- 
a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,6 +40,9 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, + /// Maximum estimated container-owned memory accepted during one root deserialization. + /// `-1` selects the automatic input-shaped limit. + pub max_container_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, /// Maximum accepted body size in one received TypeMeta. @@ -61,6 +64,7 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, + max_container_memory_bytes: -1, max_type_fields: 512, max_type_meta_bytes: 4096, max_schema_versions_per_type: 10, @@ -123,6 +127,12 @@ impl Config { self.track_ref } + /// Get maximum estimated container-owned memory per root deserialization. + #[inline(always)] + pub fn max_container_memory_bytes(&self) -> i64 { + self.max_container_memory_bytes + } + /// Get maximum accepted field count in one received struct TypeMeta. #[inline(always)] pub fn max_type_fields(&self) -> usize { diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 260f94ea4c..f36e150d2f 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -31,6 +31,13 @@ use crate::type_id as types; use crate::TypeId; use std::rc::Rc; +const KNOWN_ROOT_BUDGET_MULTIPLIER: usize = 8; +const KNOWN_ROOT_BUDGET_SLACK_BYTES: usize = 64 * 1024; +const VEC_OBJECT_BYTES: usize = mem::size_of::>(); +const MAP_ENTRY_OVERHEAD_BYTES: usize = 16; +const REFERENCE_SLOT_BYTES: usize = mem::size_of::(); +const MAX_CONTAINER_LEN: usize = u32::MAX as usize; + /// Thread-local context cache with fast path for single Fory instance. /// Uses (cached_id, context) for O(1) access when using same Fory instance repeatedly. /// Falls back to HashMap for multiple Fory instances per thread. 
@@ -359,6 +366,9 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, + max_container_memory_bytes: i64, + container_memory_limit_bytes: usize, + remaining_container_memory_bytes: usize, // Context-specific fields pub reader: Reader<'a>, @@ -388,6 +398,9 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, + max_container_memory_bytes: config.max_container_memory_bytes, + container_memory_limit_bytes: 0, + remaining_container_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -443,6 +456,112 @@ impl<'a> ReadContext<'a> { self.reader = reader; } + #[inline(always)] + pub(crate) fn init_container_memory_budget( + &mut self, + root_input_bytes: usize, + ) -> Result<(), Error> { + let limit = if self.max_container_memory_bytes > 0 { + usize::try_from(self.max_container_memory_bytes).map_err(|_| { + container_memory_error("max_container_memory_bytes does not fit usize") + })? 
+ } else { + if root_input_bytes + > (usize::MAX - KNOWN_ROOT_BUDGET_SLACK_BYTES) / KNOWN_ROOT_BUDGET_MULTIPLIER + { + return Err(container_memory_error( + "root input size overflows automatic container memory budget", + )); + } + root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES + }; + self.container_memory_limit_bytes = limit; + self.remaining_container_memory_bytes = limit; + Ok(()) + } + + #[inline(always)] + pub(crate) fn reserve_vec_memory(&mut self, len: u32) -> Result { + let len = len as usize; + self.reserve_counted_memory(len, VEC_OBJECT_BYTES, mem::size_of::())?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_collection_memory(&mut self, len: u32) -> Result { + let len = len as usize; + let elem_size = mem::size_of::(); + if elem_size > usize::MAX - REFERENCE_SLOT_BYTES { + return Err(container_memory_overflow(len, elem_size)); + } + let elem_bytes = elem_size + REFERENCE_SLOT_BYTES; + self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_map_memory(&mut self, len: u32) -> Result { + let len = len as usize; + let key_size = mem::size_of::(); + let value_size = mem::size_of::(); + let overhead = MAP_ENTRY_OVERHEAD_BYTES + REFERENCE_SLOT_BYTES * 3; + if key_size > usize::MAX - value_size || key_size + value_size > usize::MAX - overhead { + return Err(container_memory_overflow(len, key_size)); + } + let elem_bytes = key_size + value_size + overhead; + self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_container_bytes(&mut self, bytes: usize) -> Result<(), Error> { + let remaining = self.remaining_container_memory_bytes; + if bytes > remaining { + return Err(container_memory_exceeded( + bytes, + remaining, + self.container_memory_limit_bytes, + )); + } + self.remaining_container_memory_bytes = remaining - bytes; + Ok(()) + } + + #[inline(always)] + fn 
reserve_counted_memory( + &mut self, + len: usize, + fixed_bytes: usize, + elem_bytes: usize, + ) -> Result<(), Error> { + if len == 0 { + return self.reserve_container_bytes(fixed_bytes); + } + if elem_bytes <= (usize::MAX - fixed_bytes) / MAX_CONTAINER_LEN { + return self.reserve_container_bytes(len * elem_bytes + fixed_bytes); + } + self.reserve_counted_memory_checked(len, fixed_bytes, elem_bytes) + } + + #[cold] + #[inline(never)] + fn reserve_counted_memory_checked( + &mut self, + len: usize, + fixed_bytes: usize, + elem_bytes: usize, + ) -> Result<(), Error> { + let elem_total = match len.checked_mul(elem_bytes) { + Some(bytes) => bytes, + None => return Err(container_memory_overflow(len, elem_bytes)), + }; + let bytes = match elem_total.checked_add(fixed_bytes) { + Some(bytes) => bytes, + None => return Err(container_memory_overflow(len, elem_bytes)), + }; + self.reserve_container_bytes(bytes) + } + #[inline(always)] pub fn detach_reader(&mut self) -> Reader<'_> { mem::take(&mut self.reader) @@ -552,3 +671,27 @@ impl<'a> ReadContext<'a> { self.current_depth = 0; } } + +#[cold] +#[inline(never)] +fn container_memory_error(message: &'static str) -> Error { + Error::invalid_data(message) +} + +#[cold] +#[inline(never)] +fn container_memory_overflow(len: usize, elem_bytes: usize) -> Error { + Error::invalid_data(format!( + "container memory estimate overflows: length={} elementBytes={}", + len, elem_bytes + )) +} + +#[cold] +#[inline(never)] +fn container_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { + Error::invalid_data(format!( + "estimated container memory request {} bytes exceeds max_container_memory_bytes remaining budget {} bytes out of effective limit {} bytes", + bytes, remaining, limit + )) +} diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 4b6c98419a..8eb9d3794f 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -261,6 +261,18 @@ impl ForyBuilder { self } + /// Sets the 
maximum estimated container-owned memory accepted during one root deserialization. + /// + /// Use `-1` for the automatic input-shaped limit. Positive values are explicit byte limits. + pub fn max_container_memory_bytes(mut self, max_bytes: i64) -> Self { + assert!( + max_bytes == -1 || max_bytes > 0, + "max_container_memory_bytes must be positive or -1 for auto" + ); + self.config.max_container_memory_bytes = max_bytes; + self + } + /// Sets the maximum depth for nested dynamic object serialization. /// /// # Arguments @@ -988,7 +1000,13 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = self.deserialize_with_context(context); + let result = match context.init_container_memory_budget(bf.len()) { + Ok(()) => self.deserialize_with_context(context), + Err(err) => { + context.reset(); + Err(err) + } + }; context.detach_reader(); result }) @@ -1050,8 +1068,15 @@ impl Fory { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(reader.bf) }; let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); + let root_input_bytes = reader.bf.len().saturating_sub(reader.cursor); context.attach_reader(new_reader); - let result = self.deserialize_with_context(context); + let result = match context.init_container_memory_budget(root_input_bytes) { + Ok(()) => self.deserialize_with_context(context), + Err(err) => { + context.reset(); + Err(err) + } + }; let end = context.detach_reader().get_cursor(); reader.set_cursor(end); result diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 34059103f5..675fc133e9 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -1700,6 +1700,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_vec_memory::(len)?; if len 
== 0 { return Ok(Vec::new()); } @@ -1728,6 +1729,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -2270,6 +2272,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } @@ -2289,6 +2292,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + let capacity = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } @@ -2299,7 +2303,8 @@ where { return read_map_dynamic::(context, len, remote_field_type); } - let mut map = HashMap::with_capacity(check_map_len(context, len)?); + context.reader.check_bound(capacity)?; + let mut map = HashMap::with_capacity(capacity); let mut len_counter = 0; while len_counter < len { let header = context.reader.read_u8()?; diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index ee16166bb4..b2dd1950f9 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -239,6 +239,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_collection_memory::(len)?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -257,7 +258,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - let _ = check_collection_len(context, len)?; + context.reader.check_bound(len_usize)?; if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -281,6 +282,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -297,7 +299,8 @@ where (header & IS_SAME_TYPE) != 0, 
Error::type_error("Type inconsistent, target type is not polymorphic") ); - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + context.reader.check_bound(len_usize)?; + let mut vec = Vec::with_capacity(len_usize); if !has_null { for _ in 0..len { vec.push(T::fory_read_data(context)?); @@ -343,7 +346,8 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); if elem_ref_mode == RefMode::None { for _ in 0..len { vec.push(T::fory_read_with_type_info( @@ -363,7 +367,8 @@ where } Ok(vec) } else { - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::fory_read(context, elem_ref_mode, true)?); } @@ -724,6 +729,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -748,8 +754,8 @@ where "array-compatible list must declare element type", )); } - context.reader.check_bound(len as usize)?; - let mut vec = Vec::with_capacity(len as usize); + context.reader.check_bound(len_usize)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::read_list_array_element(context, element_type.type_id)?); } diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 3d0dc094e7..158e020edc 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -35,12 +35,6 @@ const TRACKING_VALUE_REF: u8 = 0b1000; pub const VALUE_NULL: u8 = 0b10000; pub const DECL_VALUE_TYPE: u8 = 0b100000; -fn check_map_len(context: &ReadContext, len: u32) -> Result { - let len = len as usize; - 
context.reader.check_bound(len)?; - Ok(len) -} - fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) { context.writer.set_bytes(header_offset + 1, &[size]); } @@ -559,10 +553,11 @@ impl Result { let len = context.reader.read_var_u32()?; + let capacity = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } - let capacity = check_map_len(context, len)?; + context.reader.check_bound(capacity)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() @@ -711,10 +706,11 @@ impl Result { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(BTreeMap::new()); } - let _ = check_map_len(context, len)?; + context.reader.check_bound(len_usize)?; let mut map = BTreeMap::::new(); if K::fory_is_polymorphic() || K::fory_is_shared_ref() diff --git a/rust/tests/tests/mod.rs b/rust/tests/tests/mod.rs index c66d727c8a..74f62c87d3 100644 --- a/rust/tests/tests/mod.rs +++ b/rust/tests/tests/mod.rs @@ -18,6 +18,7 @@ mod compatible; mod test_any; mod test_collection; +mod test_container_memory_budget; mod test_field_meta; mod test_max_dyn_depth; mod test_tuple; diff --git a/rust/tests/tests/test_container_memory_budget.rs b/rust/tests/tests/test_container_memory_budget.rs new file mode 100644 index 0000000000..29f70d10bf --- /dev/null +++ b/rust/tests/tests/test_container_memory_budget.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fory_core::{Error, Fory, Reader}; +use fory_derive::ForyStruct; +use std::collections::HashMap; +use std::panic; + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetSiblings { + first: Vec, + second: Vec, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetItem { + left: u64, + right: u64, +} + +#[derive(ForyStruct, Debug)] +struct ListWireInts { + values: Vec>, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct DenseWireInts { + values: Vec, +} + +fn fory_with_budget(max_container_memory_bytes: i64) -> Fory { + let mut fory = Fory::builder() + .xlang(false) + .compatible(false) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register_by_name::("BudgetSiblings") + .unwrap(); + fory.register_by_name::("BudgetItem").unwrap(); + fory +} + +fn compatible_fory(max_container_memory_bytes: i64) -> Fory +where + T: fory_core::Serializer + fory_core::StructSerializer + fory_core::ForyDefault, +{ + let mut fory = Fory::builder() + .xlang(false) + .compatible(true) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register::(88_001).unwrap(); + fory +} + +fn compact_empty_lists(count: usize) -> Vec> { + (0..count).map(|_| Vec::new()).collect() +} + +fn assert_budget_error(err: Error, effective_limit: usize) { + let message = err.to_string(); + assert!( + message.contains("estimated container memory request"), + "{message}" + ); + assert!( + message.contains(&format!("effective limit {effective_limit}")), + "{message}" + ); +} + +#[test] +fn config_validation() { + 
assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(0)).is_err()); + assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(-2)).is_err()); + let _ = Fory::builder().max_container_memory_bytes(-1).build(); + let _ = Fory::builder().max_container_memory_bytes(1).build(); +} + +#[test] +fn known_auto_budget() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let auto_limit = bytes.len() * 8 + 64 * 1024; + + let err = writer.deserialize::>>(&bytes).unwrap_err(); + assert_budget_error(err, auto_limit); + + let explicit = fory_with_budget(auto_limit as i64); + let err = explicit + .deserialize::>>(&bytes) + .unwrap_err(); + assert_budget_error(err, auto_limit); +} + +#[test] +fn reader_known_auto_budget() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let auto_limit = bytes.len() * 8 + 64 * 1024; + + let mut reader = Reader::new(&bytes); + let err = writer + .deserialize_from::>>(&mut reader) + .unwrap_err(); + assert_budget_error(err, auto_limit); +} + +#[test] +fn explicit_override() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + assert!(writer.deserialize::>>(&bytes).is_err()); + + let vec_bytes = std::mem::size_of::>(); + let estimate = std::mem::size_of::>>() + value.len() * vec_bytes * 2; + let explicit = fory_with_budget(estimate as i64); + let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); + assert_eq!(decoded, value); +} + +#[test] +fn empty_container_cost() { + let value: Vec = Vec::new(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let fixed = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(fixed - 1); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn sibling_cumulative_budget() { + let 
value = BudgetSiblings { + first: Vec::new(), + second: Vec::new(), + }; + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let one_vec = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(one_vec); + assert!(limited.deserialize::(&bytes).is_err()); +} + +#[test] +fn map_budget() { + let value: HashMap = HashMap::new(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let fixed = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(fixed - 1); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn inline_value_vec_budget() { + let value = (0..16) + .map(|i| BudgetItem { + left: i, + right: i + 1, + }) + .collect::>(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let under_inline = + std::mem::size_of::>() + value.len() * std::mem::size_of::(); + + let limited = fory_with_budget(under_inline as i64); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn compatible_list_array_budget() { + let value = ListWireInts { + values: (0..64).map(Some).collect(), + }; + let writer = compatible_fory::(-1); + let bytes = writer.serialize(&value).unwrap(); + + let limited = compatible_fory::(std::mem::size_of::>() as i64); + assert!(limited.deserialize::(&bytes).is_err()); + + let enough = compatible_fory::(i64::MAX); + let decoded = enough.deserialize::(&bytes).unwrap(); + assert_eq!( + decoded, + DenseWireInts { + values: (0..64).collect() + } + ); +} + +#[test] +fn dense_paths_skipped() { + let fory = fory_with_budget(1); + + let string_bytes = fory_with_budget(-1) + .serialize(&"hello".to_string()) + .unwrap(); + let decoded: String = fory.deserialize(&string_bytes).unwrap(); + assert_eq!(decoded, "hello"); + + let binary = vec![1_u8, 2, 3, 4]; + let binary_bytes = fory_with_budget(-1).serialize(&binary).unwrap(); + let decoded: Vec = fory.deserialize(&binary_bytes).unwrap(); + assert_eq!(decoded, 
binary); + + let ints = vec![1_i32, 2, 3, 4]; + let int_bytes = fory_with_budget(-1).serialize(&ints).unwrap(); + let decoded: Vec = fory.deserialize(&int_bytes).unwrap(); + assert_eq!(decoded, ints); +} + +#[test] +fn byte_check_preserved() { + let writer = fory_with_budget(-1); + let mut bytes = writer.serialize(&Vec::::new()).unwrap(); + let last = bytes.len() - 1; + bytes[last] = 64; + + let reader = fory_with_budget(i64::MAX); + let err = reader.deserialize::>(&bytes).unwrap_err(); + assert!(matches!(err, Error::BufferOutOfBound(..)), "{err}"); +} diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala index 066e24c629..8b688bcb37 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala @@ -53,15 +53,10 @@ abstract class AbstractScalaCollectionSerializer[A, T <: Iterable[A]]( value: T): util.Collection[_] override def newCollection(readContext: ReadContext): util.Collection[_] = { - val buffer = readContext.getBuffer - val numElements = buffer.readVarUInt32() - checkCollectionSize(numElements) + val numElements = readCollectionSize(readContext) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[A, T]] val builder = factory.newBuilder - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } builder.sizeHint(numElements) new JavaCollectionBuilder[A, T](builder) } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala index 9c21954b7d..3891361615 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala @@ -50,15 +50,10 @@ abstract class 
AbstractScalaMapSerializer[K, V, T](typeResolver: TypeResolver, c def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _] override def newMap(readContext: ReadContext): util.Map[_, _] = { - val buffer = readContext.getBuffer - val numElements = buffer.readVarUInt32() - checkMapSize(numElements) + val numElements = readMapSize(readContext) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[(K, V), T]] val builder = factory.newBuilder - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } builder.sizeHint(numElements) new MapBuilder[K, V, T](builder) } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala index 9eeab286d2..9439f3493e 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -43,12 +43,8 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I } override def newCollection(readContext: ReadContext): util.Collection[_] = { - val buffer = readContext.getBuffer - val numElements = readCollectionSize(buffer) + val numElements = readCollectionSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } val builder = newBuilder(numElements) if (ScalaXlangCollectionShape.hasOptionElement(readContext)) { new XlangOptionCollectionBuilder[A, T](builder) @@ -368,12 +364,8 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K } override def newMap(readContext: ReadContext): util.Map[_, _] = { - val buffer = readContext.getBuffer - val numElements = readMapSize(buffer) + val numElements = readMapSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } val builder = 
ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements) val optionKey = ScalaXlangCollectionShape.hasOptionKey(readContext) diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index d1b7e67952..fc386639dd 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -20,6 +20,7 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.scala.ForyScala import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -89,6 +90,37 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { } } } + + "fory scala container memory budget" should { + def runtime(maxContainerMemoryBytes: Long = -1): Fory = { + val builder = ForyScala.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .withSerializerFactory(new ScalaSerializerFactory()) + if (maxContainerMemoryBytes > 0) { + builder.withMaxContainerMemoryBytes(maxContainerMemoryBytes) + } + builder.build() + } + + "charge scala collection fixed cost" in { + val writer = runtime() + val reader = runtime(maxContainerMemoryBytes = 23) + intercept[InsecureException] { + reader.deserialize(writer.serialize(List.empty[String])) + } + } + + "charge scala map fixed cost" in { + val writer = runtime() + val reader = runtime(maxContainerMemoryBytes = 47) + intercept[InsecureException] { + reader.deserialize(writer.serialize(Map("k" -> "v"))) + } + } + } } case class CollectionStruct1(list: List[String]) diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala 
b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index 6cc4f880a1..b637bc4f6c 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -20,6 +20,7 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.scala.ForyScala import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -120,5 +121,24 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { copiedCyclic should not be theSameInstanceAs(cyclic) copiedCyclic(0) shouldBe theSameInstanceAs(copiedCyclic) } + + "enforce container memory budget" in { + val writer = fory + val reader = ForyScala.builder() + .withXlang(true) + .withRefTracking(true) + .withRefCopy(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .withMaxContainerMemoryBytes(23) + .build() + + intercept[InsecureException] { + reader.deserialize(writer.serialize(List.empty[String])) + } + intercept[InsecureException] { + reader.deserialize(writer.serialize(Map("k" -> 1))) + } + } } } diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index fdd99b38a4..ca31cce2da 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -570,7 +570,11 @@ public func readListOfAny( refMode: refMode, readTypeInfo: readTypeInfo ) - return wrapped?.map { $0.anyValueForCollection() } + guard let wrapped else { + return nil + } + try context.reserveReferenceArrayMemory(count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } } public func writeMapStringToAny( @@ -604,6 +608,7 @@ public func readMapStringToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [String: Any] = [:] 
map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -643,6 +648,7 @@ public func readMapInt32ToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [Int32: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -682,6 +688,7 @@ public func readMapAnyHashableToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [AnyHashable: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -693,8 +700,10 @@ public func readMapAnyHashableToAny( func readDynamicAnyMapValue(context: ReadContext) throws -> Any { let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] if map.isEmpty { + try context.reserveReferenceMapMemory(count: 0) return [String: Any]() } + try context.reserveReferenceMapMemory(count: map.count) var stringMap: [String: Any] = [:] stringMap.reserveCapacity(map.count) for pair in map { @@ -708,6 +717,7 @@ func readDynamicAnyMapValue(context: ReadContext) throws -> Any { return stringMap } + try context.reserveReferenceMapMemory(count: map.count) var int32Map: [Int32: Any] = [:] int32Map.reserveCapacity(map.count) for pair in map { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 1be59fb6b4..a7b943a4b6 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -234,18 +234,35 @@ func writePrimitiveArray(_ value: [Element], context: Write } } -func readPrimitiveArray(_ context: ReadContext) throws -> [Element] { +@inline(__always) +private func preparePrimitiveArray( + _ context: ReadContext, + chargeContainerMemory: Bool, + type: Element.Type, + count: Int, + label: String +) throws { + try context.ensureCollectionLength(count, label: label) + if chargeContainerMemory { + try context.reserveArrayMemory(type, count: count) + } +} + +func readPrimitiveArray( + _ 
context: ReadContext, + chargeContainerMemory: Bool = false +) throws -> [Element] { let byteSize = Int(try context.buffer.readVarUInt32()) try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") if Element.self == UInt8.self { - try context.ensureCollectionLength(byteSize, label: "uint8_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "uint8_array") let bytes = try context.buffer.readBytes(count: byteSize) return uncheckedArrayCast(bytes, to: Element.self) } if Element.self == Bool.self { - try context.ensureCollectionLength(byteSize, label: "bool_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "bool_array") let out = try readArrayUninitialized(count: byteSize) { destination in for index in 0..(_ context: ReadContext) throws -> [ } if Element.self == Int8.self { - try context.ensureCollectionLength(byteSize, label: "int8_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "int8_array") var out = Array(repeating: Int8(0), count: byteSize) try out.withUnsafeMutableBytes { rawBytes in try context.buffer.readBytes(into: rawBytes) @@ -266,7 +283,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("int16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "int16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int16_array") if hostIsLittleEndian { var out = Array(repeating: Int16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -285,7 +302,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int32.self { if byteSize % 4 != 0 
{ throw ForyError.invalidData("int32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "int32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int32_array") if hostIsLittleEndian { var out = Array(repeating: Int32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -304,7 +321,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("uint32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "uint32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint32_array") if hostIsLittleEndian { var out = Array(repeating: UInt32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -323,7 +340,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("int64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "int64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int64_array") if hostIsLittleEndian { var out = Array(repeating: Int64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -342,7 +359,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("uint64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "uint64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint64_array") if hostIsLittleEndian { var out = 
Array(repeating: UInt64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -361,7 +378,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("uint16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "uint16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint16_array") if hostIsLittleEndian { var out = Array(repeating: UInt16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -380,7 +397,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Float16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("float16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "float16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == BFloat16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "bfloat16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "bfloat16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == Float.self { if byteSize % 4 != 0 { throw ForyError.invalidData("float32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "float32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: 
Element.self, count: count, label: "float32_array") if hostIsLittleEndian { var out = Array(repeating: Float(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -422,7 +439,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if byteSize % 8 != 0 { throw ForyError.invalidData("float64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "float64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float64_array") if hostIsLittleEndian { var out = Array(repeating: Double(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -532,6 +549,7 @@ extension Array: Serializer where Element: Serializer { let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveArrayMemory(Element.self, count: length) return [] } @@ -541,6 +559,7 @@ extension Array: Serializer where Element: Serializer { let declared = (header & CollectionHeader.declaredElementType) != 0 let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { + try context.reserveArrayMemory(Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") if trackRef { return try readArrayUninitialized(count: length) { destination in @@ -579,6 +598,7 @@ extension Array: Serializer where Element: Serializer { } let elementTypeInfo = declared ? 
nil : try Element.foryReadTypeInfo(context) + try context.reserveArrayMemory(Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { if trackRef { @@ -637,7 +657,9 @@ extension Set: Serializer where Element: Serializer & Hashable { } public static func foryReadData(_ context: ReadContext) throws -> Set { - Set(try [Element].foryReadData(context)) + let values = try [Element].foryReadData(context) + try context.reserveSetMemory(Element.self, count: values.count) + return Set(values) } } @@ -864,11 +886,13 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { + try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) return [:] } - var map: [Key: Value] = [:] + try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: [Key: Value] = [:] map.reserveCapacity(totalLength) let keyDynamicType = Key.staticTypeId == .unknown let valueDynamicType = Value.staticTypeId == .unknown diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index b9e8825967..87141e1798 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -840,7 +840,9 @@ public enum SetFieldCodec: FieldCodec where ElementCod } public static func readPayload(_ context: ReadContext) throws -> Value { - Set(try readCollectionPayload(context, elementCodec: ElementCodec.self)) + let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) + try context.reserveFieldSetMemory(ElementCodec.self, count: values.count) + return Set(values) } } @@ -959,11 +961,13 @@ where KeyCodec.Value: Hashable { let totalLength = Int(try 
context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { + try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) return [:] } - var map: Value = [:] + try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: Value = [:] map.reserveCapacity(totalLength) var readCount = 0 while readCount < totalLength { @@ -1323,8 +1327,11 @@ private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { } } -private func readIntArrayPayload(_ context: ReadContext) throws -> [Int] { +private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [Int] { let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") + if chargeContainerMemory { + try context.reserveArrayMemory(Int.self, count: count) + } var values: [Int] = [] values.reserveCapacity(count) for _ in 0.. 
[Int] { return values } -private func readUIntArrayPayload(_ context: ReadContext) throws -> [UInt] { +private func readUIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [UInt] { let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") + if chargeContainerMemory { + try context.reserveArrayMemory(UInt.self, count: count) + } var values: [UInt] = [] values.reserveCapacity(count) for _ in 0..( elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Bool], to: ElementCodec.Value.self) } if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int8], to: ElementCodec.Value.self) } if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int16], to: ElementCodec.Value.self) } if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int32], to: ElementCodec.Value.self) } if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + return 
uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int64], to: ElementCodec.Value.self) } if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) } if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt8], to: ElementCodec.Value.self) } if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt16], to: ElementCodec.Value.self) } if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt32], to: ElementCodec.Value.self) } if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt64], to: ElementCodec.Value.self) } if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context), 
to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readUIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) } if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float16], to: ElementCodec.Value.self) } if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [BFloat16], to: ElementCodec.Value.self) } if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float], to: ElementCodec.Value.self) } if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Double], to: ElementCodec.Value.self) } throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") } @@ -1590,6 +1600,7 @@ private func readCollectionPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) return [] } @@ -1604,6 +1615,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) try context.ensureRemainingBytes(length, label: "array") 
result.reserveCapacity(length) @@ -1688,6 +1700,7 @@ private func readListPayloadAsArrayPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) return [] } @@ -1715,6 +1728,7 @@ private func readListPayloadAsArrayPayload( } try context.ensureRemainingBytes(length, label: "array") var result: [ElementCodec.Value] = [] + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) result.reserveCapacity(length) return try ElementCodec.withTypeInfo(elementTypeInfo, context) { for _ in 0.. 0, + "maxContainerMemoryBytes must be positive or -1 for auto") precondition(maxTypeFields > 0, "maxTypeFields must be positive") precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") @@ -49,6 +54,7 @@ public struct Config { self.compatible = effectiveCompatible self.checkClassVersion = effectiveCheckClassVersion self.maxDepth = maxDepth + self.maxContainerMemoryBytes = maxContainerMemoryBytes self.maxTypeFields = maxTypeFields self.maxTypeMetaBytes = maxTypeMetaBytes self.maxSchemaVersionsPerType = maxSchemaVersionsPerType @@ -72,6 +78,7 @@ public final class Fory { compatible: Bool? = nil, checkClassVersion: Bool? 
= nil, maxDepth: Int = 5, + maxContainerMemoryBytes: Int64 = -1, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, @@ -83,6 +90,7 @@ public final class Fory { compatible: compatible, checkClassVersion: checkClassVersion, maxDepth: maxDepth, + maxContainerMemoryBytes: maxContainerMemoryBytes, maxTypeFields: maxTypeFields, maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, @@ -465,8 +473,9 @@ public final class Fory { func withReusableReadContext( data: Data, _ body: (ReadContext) throws -> R - ) rethrows -> R { + ) throws -> R { readContext.buffer.replace(with: data) + try readContext.initContainerMemoryBudgetKnown(rootBytes: data.count) defer { readContext.reset() } @@ -528,6 +537,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) + try readContext.initContainerMemoryBudgetKnown(rootBytes: readContext.buffer.remaining) defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 94242afafa..fd5fb63678 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -20,6 +20,15 @@ import Foundation private let typeMetaSizeMask = 0xFF public final class ReadContext { + static let knownContainerBudgetSlackBytes = 64 * 1024 + static let unknownContainerBudgetBytes = 128 * 1024 * 1024 + static let containerFixedBytes = 32 + static let arrayHeaderBytes = 24 + static let referenceBytes = 4 + static let collectionEntryOverheadBytes = 16 + static let mapEntryOverheadBytes = 24 + private static let maxKnownContainerRootBytes = (Int.max - knownContainerBudgetSlackBytes) / 8 + public let buffer: ByteBuffer let typeResolver: TypeResolver public let trackRef: Bool @@ -35,6 +44,8 @@ public final class ReadContext { private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: 
TypeInfo?)] = [] private var lastTypeInfo = TypeInfo.uncached private let config: Config + private let maxContainerMemoryBytes: Int + private var remainingContainerMemoryBytes = Int.max init( buffer: ByteBuffer, @@ -48,9 +59,166 @@ public final class ReadContext { self.checkClassVersion = config.checkClassVersion self.maxDepth = config.maxDepth self.config = config + self.maxContainerMemoryBytes = Int(config.maxContainerMemoryBytes) self.refReader = RefReader() } + @inline(__always) + func initContainerMemoryBudgetKnown(rootBytes: Int) throws { + var limit = maxContainerMemoryBytes + if limit < 0 { + if rootBytes > Self.maxKnownContainerRootBytes { + try throwContainerMemoryOverflow() + } + limit = rootBytes * 8 + Self.knownContainerBudgetSlackBytes + } + remainingContainerMemoryBytes = limit + } + + @inline(__always) + func reserveArrayMemory(_ type: Element.Type, count: Int) throws { + try reserveArrayMemory(count: count, elementBytes: containerElementBytes(type)) + } + + @inline(__always) + func reserveFieldArrayMemory( + _ codec: ElementCodec.Type, + count: Int + ) throws { + try reserveArrayMemory(count: count, elementBytes: fieldElementBytes(codec)) + } + + @inline(__always) + func reserveReferenceArrayMemory(count: Int) throws { + try reserveArrayMemory(count: count, elementBytes: Self.referenceBytes) + } + + @inline(__always) + func reserveSetMemory(_ type: Element.Type, count: Int) throws { + try reserveSetMemory(count: count, elementBytes: containerElementBytes(type)) + } + + @inline(__always) + func reserveFieldSetMemory( + _ codec: ElementCodec.Type, + count: Int + ) throws { + try reserveSetMemory(count: count, elementBytes: fieldElementBytes(codec)) + } + + @inline(__always) + func reserveMapMemory( + key _: Key.Type, + value _: Value.Type, + count: Int + ) throws { + try reserveMapMemory( + count: count, + keyBytes: containerElementBytes(Key.self), + valueBytes: containerElementBytes(Value.self) + ) + } + + @inline(__always) + func 
reserveFieldMapMemory( + key _: KeyCodec.Type, + value _: ValueCodec.Type, + count: Int + ) throws { + try reserveMapMemory( + count: count, + keyBytes: fieldElementBytes(KeyCodec.self), + valueBytes: fieldElementBytes(ValueCodec.self) + ) + } + + @inline(__always) + func reserveReferenceMapMemory(count: Int) throws { + try reserveMapMemory(count: count, keyBytes: Self.referenceBytes, valueBytes: Self.referenceBytes) + } + + @inline(__always) + private func reserveArrayMemory(count: Int, elementBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return + } + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, + elementBytes: elementBytes + ) + } + + @inline(__always) + private func reserveSetMemory(count: Int, elementBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return + } + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, + elementBytes: elementBytes + Self.collectionEntryOverheadBytes + Self.referenceBytes * 2 + ) + } + + @inline(__always) + private func reserveMapMemory(count: Int, keyBytes: Int, valueBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return + } + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes * 2, + elementBytes: keyBytes + valueBytes + Self.mapEntryOverheadBytes + Self.referenceBytes + ) + } + + @inline(__always) + private func reserveContainerMemory(_ bytes: Int) throws { + if bytes > remainingContainerMemoryBytes { + try throwContainerMemoryExceeded(bytes: bytes) + } + remainingContainerMemoryBytes -= bytes + } + + @inline(__always) + private func reserveCountedContainerMemory( + count: Int, + fixedBytes: Int, + elementBytes: Int + ) throws { + if count > (Int.max - fixedBytes) / elementBytes { + try 
throwContainerMemoryOverflow() + } + try reserveContainerMemory(count * elementBytes + fixedBytes) + } + + @inline(__always) + private func containerElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) + } + + @inline(__always) + private func fieldElementBytes(_ codec: ElementCodec.Type) -> Int { + codec.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) + } + + @inline(never) + private func throwContainerMemoryOverflow() throws -> Never { + throw ForyError.invalidData("container memory estimate overflows") + } + + @inline(never) + private func throwContainerMemoryExceeded(bytes: Int) throws -> Never { + let message = + "estimated container memory request \(bytes) bytes exceeds maxContainerMemoryBytes " + + "remaining budget \(remainingContainerMemoryBytes) bytes" + throw ForyError.invalidData(message) + } + @inline(__always) func enterDynamicAnyDepth() throws { if maxDepth < 0 { diff --git a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift new file mode 100644 index 0000000000..3650f55f48 --- /dev/null +++ b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation +import Testing +@testable import Fory + +@ForyStruct +private final class BudgetNode { + var id: Int32 = 0 + + required init() {} + + init(id: Int32) { + self.id = id + } +} + +@ForyStruct +private struct BudgetSiblings { + var left: [BudgetNode] = [] + var right: [BudgetNode] = [] +} + +@ForyStruct +private struct BudgetDenseHolder: Equatable { + var text: String = "" + var data: Data = Data() + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + +private func makeBudgetFory(maxContainerMemoryBytes: Int64 = -1) -> Fory { + let fory = Fory(config: .init( + trackRef: false, + compatible: false, + maxContainerMemoryBytes: maxContainerMemoryBytes + )) + fory.register(BudgetNode.self, id: 9801) + fory.register(BudgetSiblings.self, id: 9802) + fory.register(BudgetDenseHolder.self, id: 9803) + return fory +} + +private func elementBytes(_ type: Element.Type) -> Int { + type.isRefType ? 
ReadContext.referenceBytes : max(1, MemoryLayout.stride) +} + +private func arrayBudget(_ type: Element.Type, count: Int) -> Int { + if count == 0 { + return ReadContext.containerFixedBytes + } + return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes + + count * elementBytes(type) +} + +private func mapBudget( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + if count == 0 { + return ReadContext.containerFixedBytes + } + return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes * 2 + + count * ( + elementBytes(key) + elementBytes(value) + + ReadContext.mapEntryOverheadBytes + ReadContext.referenceBytes + ) +} + +private func expectInvalidData(_ body: () throws -> Void) { + do { + try body() + Issue.record("expected invalid data") + } catch ForyError.invalidData { + } catch { + Issue.record("expected invalid data, got \(error)") + } +} + +@Test +func knownLengthAutoBudgetRejectsNestedEmptyArrays() throws { + let count = 6_000 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let autoLimit = bytes.count * 8 + ReadContext.knownContainerBudgetSlackBytes + let required = arrayBudget([String].self, count: count) + + count * arrayBudget(String.self, count: 0) + #expect(required > autoLimit) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory().deserialize(bytes) + } + + let decoded: [[String]] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded.count == count) +} + +@Test +func byteBufferRootUsesKnownLengthAutoBudget() throws { + let count = 6_000 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let buffer = ByteBuffer(data: bytes) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory().deserialize(from: buffer) + } +} + +@Test +func explicitConfigOverridesAutoBudget() throws { + let values = (0..<16).map(Int32.init) + let bytes = try 
makeBudgetFory().serialize(values) + let required = arrayBudget(Int32.self, count: values.count) + + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) + } + let decoded: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded == values) +} + +@Test +func siblingContainersShareOneBudget() throws { + let value = BudgetSiblings( + left: (0..<16).map { BudgetNode(id: Int32($0)) }, + right: (16..<32).map { BudgetNode(id: Int32($0)) } + ) + let bytes = try makeBudgetFory().serialize(value) + let oneList = arrayBudget(BudgetNode.self, count: 16) + + expectInvalidData { + let _: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList)).deserialize(bytes) + } + let decoded: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList * 2)).deserialize(bytes) + #expect(decoded.left.count == 16) + #expect(decoded.right.count == 16) +} + +@Test +func mapBudgetIsCharged() throws { + let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] + let bytes = try makeBudgetFory().serialize(value) + let required = mapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded == value) +} + +@Test +func referenceAndInlineValueArraysAreCharged() throws { + let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } + let nodeBytes = try makeBudgetFory().serialize(nodes) + let nodeBudget = arrayBudget(BudgetNode.self, count: nodes.count) + expectInvalidData { + let _: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: Int64(nodeBudget - 1)).deserialize(nodeBytes) + } + let decodedNodes: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: 
Int64(nodeBudget)).deserialize(nodeBytes) + #expect(decodedNodes.count == nodes.count) + + let ints: [Int32] = [1, 2, 3, 4] + let intBytes = try makeBudgetFory().serialize(ints) + let intBudget = arrayBudget(Int32.self, count: ints.count) + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget - 1)).deserialize(intBytes) + } + #expect(try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget)).deserialize(intBytes) as [Int32] == ints) +} + +@Test +func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { + let value = BudgetDenseHolder( + text: "budget", + data: Data([1, 2, 3]), + dense: [1, 2, 3] + ) + let bytes = try makeBudgetFory().serialize(value) + + let decoded: BudgetDenseHolder = try makeBudgetFory(maxContainerMemoryBytes: 1).deserialize(bytes) + #expect(decoded == value) +} + +@Test +func dynamicAnyEmptyMapChargesFixedCost() throws { + let value = [:] as [AnyHashable: Any] + let bytes = try makeBudgetFory().serialize(value as Any) + let required = ReadContext.containerFixedBytes * 3 + + expectInvalidData { + let _: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect((decoded as? 
[String: Any])?.isEmpty == true) +} + +@Test +func byteAvailabilityCheckStillRejectsLargeLength() throws { + let buffer = ByteBuffer() + buffer.writeVarUInt32(64) + buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: buffer, + typeResolver: TypeResolver(config: config), + config: config + ) + + expectInvalidData { + let _: [String] = try [String].foryReadData(context) + } +} diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index cde6b8a5a1..c81b70b30f 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -382,6 +382,7 @@ func namedInitializerBuildsConfig() { #expect(defaultConfig.config.compatible == true) #expect(defaultConfig.config.checkClassVersion == false) #expect(defaultConfig.config.maxDepth == 5) + #expect(defaultConfig.config.maxContainerMemoryBytes == -1) #expect(defaultConfig.config.maxTypeFields == 512) #expect(defaultConfig.config.maxTypeMetaBytes == 4096) #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) @@ -391,6 +392,7 @@ func namedInitializerBuildsConfig() { ref: true, compatible: true, maxDepth: 7, + maxContainerMemoryBytes: 65_536, maxTypeFields: 31, maxTypeMetaBytes: 1234, maxSchemaVersionsPerType: 12, @@ -400,6 +402,7 @@ func namedInitializerBuildsConfig() { #expect(explicitConfig.config.compatible == true) #expect(explicitConfig.config.checkClassVersion == false) #expect(explicitConfig.config.maxDepth == 7) + #expect(explicitConfig.config.maxContainerMemoryBytes == 65_536) #expect(explicitConfig.config.maxTypeFields == 31) #expect(explicitConfig.config.maxTypeMetaBytes == 1234) #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) @@ -410,6 +413,7 @@ func namedInitializerBuildsConfig() { trackRef: false, compatible: true, maxDepth: 9, + maxContainerMemoryBytes: 131_072, maxTypeFields: 41, 
maxTypeMetaBytes: 2048, maxSchemaVersionsPerType: 14, @@ -419,6 +423,7 @@ func namedInitializerBuildsConfig() { #expect(configInit.config.compatible == true) #expect(configInit.config.checkClassVersion == false) #expect(configInit.config.maxDepth == 9) + #expect(configInit.config.maxContainerMemoryBytes == 131_072) #expect(configInit.config.maxTypeFields == 41) #expect(configInit.config.maxTypeMetaBytes == 2048) #expect(configInit.config.maxSchemaVersionsPerType == 14) From db6513bfd0fd35f0ea3f0d2e806e92c7e09de626 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Sat, 27 Jun 2026 02:56:59 +0800 Subject: [PATCH 2/2] fix: repair container memory budget CI --- go/fory/buffer.go | 11 +++++++++++ go/fory/container_memory_budget_test.go | 18 ++++++++++++++++++ go/fory/fory.go | 9 +++++++++ .../CompatibleDifferentSchemaExample.java | 2 +- python/pyfory/buffer.pxi | 9 +++++++++ python/pyfory/context.py | 3 +-- .../scala/ForySerializerDerivationTest.scala | 2 +- 7 files changed, 50 insertions(+), 4 deletions(-) diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 89e29f938d..1a1a067b7e 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -482,6 +482,17 @@ func (b *ByteBuffer) ReaderIndex() int { return b.readerIndex } +func (b *ByteBuffer) readableBytes() int { + end := b.writerIndex + if len(b.data) > end { + end = len(b.data) + } + if b.readerIndex >= end { + return 0 + } + return end - b.readerIndex +} + func (b *ByteBuffer) SetReaderIndex(index int) { b.readerIndex = index } diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go index 16959b3d0a..e83e8d7a03 100644 --- a/go/fory/container_memory_budget_test.go +++ b/go/fory/container_memory_budget_test.go @@ -85,6 +85,24 @@ func TestContainerMemoryBudgetKnownVsStreamRoot(t *testing.T) { require.Len(t, fromStream, len(values)) } +func TestContainerMemoryBudgetBufferRoots(t *testing.T) { + writer := New(WithCompatible(false)) + value := []string{"a", "b"} + data, err := 
writer.Serialize(value) + require.NoError(t, err) + + reader := New(WithCompatible(false)) + var fromCallback []string + err = reader.DeserializeWithCallbackBuffers(NewByteBuffer(data), &fromCallback, nil) + require.NoError(t, err) + require.Equal(t, value, fromCallback) + + var fromBuffer []string + err = reader.DeserializeFrom(NewByteBuffer(data), &fromBuffer) + require.NoError(t, err) + require.Equal(t, value, fromBuffer) +} + func TestContainerMemoryBudgetExplicitOverride(t *testing.T) { writer := New(WithCompatible(false)) values := make([]any, 12000) diff --git a/go/fory/fory.go b/go/fory/fory.go index 7bfb9867ef..3b6360aece 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -666,6 +666,11 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = buf + f.readCtx.initContainerMemoryBudget(buf.readableBytes(), false) + if f.readCtx.HasError() { + f.readCtx.buffer = origBuffer + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -761,6 +766,10 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers f.readCtx.buffer = nil f.readCtx.outOfBandBuffers = nil }() + f.readCtx.initContainerMemoryBudget(buffer.readableBytes(), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } // Set up out-of-band buffers if provided if buffers != nil { f.readCtx.outOfBandBuffers = buffers diff --git a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java index 41ddae5cda..c70d016deb 100644 --- a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java +++ b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java @@ -89,7 +89,7 @@ private static Serializer 
readSerializerForTarget( MemoryBuffer buffer = MemoryUtils.wrap(bytes); buffer.readByte(); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); try { readContext.getRefReader().tryPreserveRefId(buffer); TypeInfo typeInfo = fory.getTypeResolver().readTypeInfo(readContext, targetClass); diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 4fa77e9e68..53ff441f25 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -335,6 +335,15 @@ cdef class Buffer: f"Address range {offset, offset + length} out of bound {0, size_}", ) + cpdef inline ensure_readable(self, int32_t length): + if length < 0: + raise_fory_error(CErrorCode.InvalidData, f"Readable byte count {length} is negative") + if length == 0: + return + if not self.c_buffer.ensure_readable(length, self._error): + if not self._error.ok(): + self._raise_if_error() + cpdef inline write_bool(self, c_bool value): self.c_buffer.write_uint8(value) diff --git a/python/pyfory/context.py b/python/pyfory/context.py index a923731c4b..8b620629da 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -527,8 +527,7 @@ def check_readable_bytes(self, length): raise ValueError(f"Readable byte count {length} is negative") if length == 0: return - reader_index = self.buffer.get_reader_index() - self.buffer.check_bound(reader_index, length) + self.buffer.ensure_readable(length) def prepare( self, diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index 7cf598f381..22e7044d23 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -571,7 +571,7 @@ import org.apache.fory.scala.ForyScala 
buffer.readVarUInt32() shouldBe 0 buffer.readerIndex(0) val readContext = fory.getReadContext - readContext.prepare(buffer, null, false) + readContext.prepare(buffer, null, false, buffer.remaining(), false) try serializer.read(readContext) shouldBe value finally readContext.reset() }