Skip to content

Commit 2956635

Browse files
dfa1claude
andcommitted
feat(writer): zone-map MIN/MAX for dict-encoded columns
Rust computes zone-map stats on the logical column dtype, independent of the dict encoding, so dict columns carry MIN/MAX/NULL_COUNT like any other. The Java dict path emitted NULL_COUNT only — a parity gap. Compute per-chunk min/max on each chunk's logical values at dict-build time (reusing PrimitiveEncodingEncoder.minMaxStats / VarBinEncodingEncoder .minMaxStats, now exposed so the dict and flat paths stay identical) and carry them on DictColRef. Unify zone-map emission: both the flat and dict loops feed per-zone min/max scalar bytes + null counts through one emitZoneMap helper (replacing the dict-only NULL_COUNT writer), and the stat-column builders now take the scalar bytes directly. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 1d447ac commit 2956635

4 files changed

Lines changed: 188 additions & 99 deletions

File tree

writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java

Lines changed: 101 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -701,70 +701,71 @@ private void flushZoneMaps() throws IOException {
701701
if (chunks.isEmpty()) {
702702
continue;
703703
}
704-
int nZones = chunks.size();
705-
boolean[] allValid = new boolean[nZones];
706-
java.util.Arrays.fill(allValid, true);
707-
708-
// NULL_COUNT is computable for every column type; MIN/MAX whenever every chunk carries
709-
// stats. Fixed-width primitives store min/max as that primitive (extension columns
710-
// unwrap to their storage primitive); Utf8 stores them as full strings. Field/bit
711-
// order follows ZonedStatsSchema: MAX(3), MIN(4), NULL_COUNT(6); each stat nullable, and
712-
// MAX/MIN carry a trailing `_is_truncated` Bool (always false — we never truncate).
713-
DType colDtype = schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
714-
DType minMaxDtype = zoneMinMaxDtype(colDtype);
715-
boolean hasMinMax = minMaxDtype != null
716-
&& chunks.stream().allMatch(ChunkRef::hasStats);
717-
718-
List<String> names = new java.util.ArrayList<>();
719-
List<DType> types = new java.util.ArrayList<>();
720-
List<Object> fields = new java.util.ArrayList<>();
721-
if (hasMinMax) {
722-
boolean[] notTruncated = new boolean[nZones];
723-
names.add("max");
724-
types.add(minMaxDtype);
725-
fields.add(new NullableData(zoneStatValues(minMaxDtype, chunks, true), allValid.clone()));
726-
names.add("max_is_truncated");
727-
types.add(new DType.Bool(false));
728-
fields.add(notTruncated);
729-
names.add("min");
730-
types.add(minMaxDtype);
731-
fields.add(new NullableData(zoneStatValues(minMaxDtype, chunks, false), allValid.clone()));
732-
names.add("min_is_truncated");
733-
types.add(new DType.Bool(false));
734-
fields.add(notTruncated.clone());
735-
}
736-
long[] nullCounts = new long[nZones];
737-
for (int i = 0; i < nZones; i++) {
704+
DType minMaxDtype = zoneMinMaxDtype(columnDtype(colName));
705+
boolean hasMinMax = minMaxDtype != null && chunks.stream().allMatch(ChunkRef::hasStats);
706+
long[] nullCounts = new long[chunks.size()];
707+
for (int i = 0; i < chunks.size(); i++) {
738708
nullCounts[i] = chunks.get(i).nullCount();
739709
}
740-
names.add("null_count");
741-
types.add(new DType.Primitive(PType.U64, true));
742-
fields.add(new NullableData(nullCounts, allValid.clone()));
743-
744-
DType.Struct statsDtype = new DType.Struct(List.copyOf(names), List.copyOf(types), false);
745-
int zonesSegIdx = writeSegment(statsDtype, new StructData(fields), new StructEncodingEncoder());
746-
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), hasMinMax));
710+
emitZoneMap(colName, hasMinMax ? minMaxDtype : null,
711+
chunks.stream().map(ChunkRef::statsMin).toList(),
712+
chunks.stream().map(ChunkRef::statsMax).toList(),
713+
nullCounts);
747714
}
748-
// Dict-encoded columns live in a separate path (one zone per code chunk); they carry
749-
// NULL_COUNT only for now (no dict-level MIN/MAX yet). Matches Rust, which zone-maps dict
750-
// columns (vortex.stats wrapping vortex.dict).
715+
// Dict-encoded columns (one zone per code chunk). MIN/MAX come from each chunk's logical
716+
// values (computed at dict-build time); NULL_COUNT always. Matches Rust, whose zone-map
717+
// stats are computed on the logical column dtype, independent of the dict encoding.
751718
for (Map.Entry<String, DictColRef> e : dictColRefs.entrySet()) {
752-
// A dict column always has at least one code chunk, so null counts are non-empty.
753-
long[] nullCounts = e.getValue().chunkNullCounts().stream().mapToLong(Long::longValue).toArray();
754-
writeNullCountZoneMap(e.getKey(), nullCounts);
719+
DictColRef ref = e.getValue();
720+
DType minMaxDtype = zoneMinMaxDtype(columnDtype(e.getKey()));
721+
boolean hasMinMax = minMaxDtype != null
722+
&& ref.chunkStatsMin().stream().allMatch(java.util.Objects::nonNull)
723+
&& ref.chunkStatsMax().stream().allMatch(java.util.Objects::nonNull);
724+
long[] nullCounts = ref.chunkNullCounts().stream().mapToLong(Long::longValue).toArray();
725+
emitZoneMap(e.getKey(), hasMinMax ? minMaxDtype : null,
726+
ref.chunkStatsMin(), ref.chunkStatsMax(), nullCounts);
755727
}
756728
}
757729

758-
/// Emits a NULL_COUNT-only `vortex.stats` zone-map (one zone per chunk) for `colName`.
759-
private void writeNullCountZoneMap(String colName, long[] nullCounts) throws IOException {
730+
private DType columnDtype(String colName) {
731+
return schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
732+
}
733+
734+
/// Writes one `vortex.stats` zone-map for `colName`: one zone per chunk, with NULL_COUNT always
735+
/// and MAX/MIN (plus always-false `_is_truncated` flags) when `minMaxDtype` is non-null.
736+
/// `minBytes`/`maxBytes` hold each zone's serialised min/max scalar — read only when
737+
/// `minMaxDtype` is set. Field/bit order follows ZonedStatsSchema: MAX(3), MIN(4), NULL_COUNT(6).
738+
private void emitZoneMap(String colName, DType minMaxDtype,
739+
List<byte[]> minBytes, List<byte[]> maxBytes, long[] nullCounts) throws IOException {
760740
int nZones = nullCounts.length;
761741
boolean[] allValid = new boolean[nZones];
762742
java.util.Arrays.fill(allValid, true);
763-
DType.Struct statsDtype = new DType.Struct(
764-
List.of("null_count"), List.of(new DType.Primitive(PType.U64, true)), false);
765-
StructData sd = new StructData(List.of(new NullableData(nullCounts, allValid)));
766-
int zonesSegIdx = writeSegment(statsDtype, sd, new StructEncodingEncoder());
767-
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), false));
743+
744+
List<String> names = new java.util.ArrayList<>();
745+
List<DType> types = new java.util.ArrayList<>();
746+
List<Object> fields = new java.util.ArrayList<>();
747+
if (minMaxDtype != null) {
748+
boolean[] notTruncated = new boolean[nZones];
749+
names.add("max");
750+
types.add(minMaxDtype);
751+
fields.add(new NullableData(zoneStatValues(minMaxDtype, maxBytes), allValid.clone()));
752+
names.add("max_is_truncated");
753+
types.add(new DType.Bool(false));
754+
fields.add(notTruncated);
755+
names.add("min");
756+
types.add(minMaxDtype);
757+
fields.add(new NullableData(zoneStatValues(minMaxDtype, minBytes), allValid.clone()));
758+
names.add("min_is_truncated");
759+
types.add(new DType.Bool(false));
760+
fields.add(notTruncated.clone());
761+
}
762+
names.add("null_count");
763+
types.add(new DType.Primitive(PType.U64, true));
764+
fields.add(new NullableData(nullCounts, allValid.clone()));
765+
766+
DType.Struct statsDtype = new DType.Struct(List.copyOf(names), List.copyOf(types), false);
767+
int zonesSegIdx = writeSegment(statsDtype, new StructData(fields), new StructEncodingEncoder());
768+
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), minMaxDtype != null));
768769
}
769770

770771
/// Wraps a column's data layout in a `vortex.stats` (zoned) layout when a zone-map was
@@ -820,94 +821,92 @@ private static DType zoneMinMaxDtype(DType dtype) {
820821
}
821822

822823
/// Builds the per-zone min (or max) values array for the resolved min/max `dtype`, decoding each
823-
/// chunk's serialised [ScalarValue] stat into the array shape its encoder expects.
824-
private static Object zoneStatValues(DType minMaxDtype, List<ChunkRef> chunks, boolean max) throws IOException {
824+
/// zone's serialised [ScalarValue] stat into the array shape its encoder expects.
825+
private static Object zoneStatValues(DType minMaxDtype, List<byte[]> statBytes) throws IOException {
825826
return switch (minMaxDtype) {
826-
case DType.Primitive p -> statColumn(p.ptype(), chunks, max);
827-
case DType.Utf8 ignored -> statStringColumn(chunks, max);
827+
case DType.Primitive p -> statColumn(p.ptype(), statBytes);
828+
case DType.Utf8 ignored -> statStringColumn(statBytes);
828829
default -> throw new IllegalStateException("no zone stat values for " + minMaxDtype);
829830
};
830831
}
831832

832-
/// Builds the per-zone min (or max) string array, decoding each chunk's serialised string
833-
/// [ScalarValue] stat. Used for Utf8 columns whose `vortex.varbin` encoder records full
834-
/// string min/max scalars.
835-
private static String[] statStringColumn(List<ChunkRef> chunks, boolean max) throws IOException {
836-
String[] out = new String[chunks.size()];
833+
/// Builds the per-zone string array by decoding each zone's serialised string [ScalarValue]
834+
/// stat. Used for Utf8 columns whose `vortex.varbin` encoder records full string min/max scalars.
835+
private static String[] statStringColumn(List<byte[]> statBytes) throws IOException {
836+
String[] out = new String[statBytes.size()];
837837
for (int i = 0; i < out.length; i++) {
838-
ChunkRef cr = chunks.get(i);
839-
out[i] = decodeScalar(max ? cr.statsMax() : cr.statsMin()).string_value();
838+
out[i] = decodeScalar(statBytes.get(i)).string_value();
840839
}
841840
return out;
842841
}
843842

844-
/// Builds the per-zone min (or max) values array in the storage shape the primitive encoder
845-
/// expects, decoding each chunk's serialised [ScalarValue] stat.
846-
private static Object statColumn(PType ptype, List<ChunkRef> chunks, boolean max) throws IOException {
847-
int n = chunks.size();
843+
/// Builds the per-zone values array in the storage shape the primitive encoder expects, decoding
844+
/// each zone's serialised [ScalarValue] stat.
845+
private static Object statColumn(PType ptype, List<byte[]> statBytes) throws IOException {
846+
int n = statBytes.size();
848847
return switch (ptype) {
849848
case I8, U8 -> {
850849
byte[] a = new byte[n];
851850
for (int i = 0; i < n; i++) {
852-
a[i] = (byte) scalarLong(chunks.get(i), max);
851+
a[i] = (byte) scalarLong(statBytes.get(i));
853852
}
854853
yield a;
855854
}
856855
case I16, U16 -> {
857856
short[] a = new short[n];
858857
for (int i = 0; i < n; i++) {
859-
a[i] = (short) scalarLong(chunks.get(i), max);
858+
a[i] = (short) scalarLong(statBytes.get(i));
860859
}
861860
yield a;
862861
}
863862
case I32, U32 -> {
864863
int[] a = new int[n];
865864
for (int i = 0; i < n; i++) {
866-
a[i] = (int) scalarLong(chunks.get(i), max);
865+
a[i] = (int) scalarLong(statBytes.get(i));
867866
}
868867
yield a;
869868
}
870869
case I64, U64 -> {
871870
long[] a = new long[n];
872871
for (int i = 0; i < n; i++) {
873-
a[i] = scalarLong(chunks.get(i), max);
872+
a[i] = scalarLong(statBytes.get(i));
874873
}
875874
yield a;
876875
}
877876
case F32 -> {
878877
float[] a = new float[n];
879878
for (int i = 0; i < n; i++) {
880-
a[i] = (float) scalarDouble(chunks.get(i), max);
879+
a[i] = (float) scalarDouble(statBytes.get(i));
881880
}
882881
yield a;
883882
}
884883
case F64 -> {
885884
double[] a = new double[n];
886885
for (int i = 0; i < n; i++) {
887-
a[i] = scalarDouble(chunks.get(i), max);
886+
a[i] = scalarDouble(statBytes.get(i));
888887
}
889888
yield a;
890889
}
891890
case F16 -> {
892891
// F16 min/max are serialised as f32 scalars; re-pack to float16 storage.
893892
short[] a = new short[n];
894893
for (int i = 0; i < n; i++) {
895-
a[i] = Float.floatToFloat16((float) scalarDouble(chunks.get(i), max));
894+
a[i] = Float.floatToFloat16((float) scalarDouble(statBytes.get(i)));
896895
}
897896
yield a;
898897
}
899898
};
900899
}
901900

902-
private static long scalarLong(ChunkRef cr, boolean max) throws IOException {
901+
private static long scalarLong(byte[] bytes) throws IOException {
903902
// Integer columns serialise min/max as int64 (signed) or uint64 (unsigned).
904-
ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
903+
ScalarValue sv = decodeScalar(bytes);
905904
return sv.int64_value() != null ? sv.int64_value() : sv.uint64_value();
906905
}
907906

908-
private static double scalarDouble(ChunkRef cr, boolean max) throws IOException {
907+
private static double scalarDouble(byte[] bytes) throws IOException {
909908
// Float columns serialise min/max as f64 (F64) or f32 (F32).
910-
ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
909+
ScalarValue sv = decodeScalar(bytes);
911910
return sv.f64_value() != null ? sv.f64_value() : sv.f32_value();
912911
}
913912

@@ -1098,16 +1097,25 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List<O
10981097
List<Integer> codesSegIdxes = new ArrayList<>();
10991098
List<Long> chunkRowCounts = new ArrayList<>();
11001099
List<Long> chunkNullCounts = new ArrayList<>();
1100+
List<byte[]> chunkStatsMin = new ArrayList<>();
1101+
List<byte[]> chunkStatsMax = new ArrayList<>();
11011102
for (Object chunk : chunks) {
11021103
int len = primitiveArrayLen(chunk, ptype);
11031104
Object codesArr = buildCodesArray(chunk, ptype, valueMap, codePType, len);
11041105
codesSegIdxes.add(writeSegment(codesDtype, codesArr));
11051106
chunkRowCounts.add((long) len);
11061107
chunkNullCounts.add(chunk instanceof NullableData nd ? countNulls(nd.validity()) : 0L);
1108+
// Per-zone min/max over the chunk's logical values (matches the flat primitive path:
1109+
// computed on nd.values(), placeholders included). Lets the dict zone-map prune like a
1110+
// plain primitive column.
1111+
Object values = chunk instanceof NullableData nd ? nd.values() : chunk;
1112+
byte[][] mm = PrimitiveEncodingEncoder.minMaxStats(ptype, values);
1113+
chunkStatsMin.add(mm != null ? mm[0] : null);
1114+
chunkStatsMax.add(mm != null ? mm[1] : null);
11071115
}
11081116

1109-
dictColRefs.put(colName,
1110-
new DictColRef(valuesSegIdx, dictSize, codesSegIdxes, chunkRowCounts, chunkNullCounts));
1117+
dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
1118+
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
11111119
}
11121120

11131121
private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Object> chunks)
@@ -1144,15 +1152,22 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Ob
11441152
List<Integer> codesSegIdxes = new ArrayList<>();
11451153
List<Long> chunkRowCounts = new ArrayList<>();
11461154
List<Long> chunkNullCounts = new ArrayList<>();
1155+
List<byte[]> chunkStatsMin = new ArrayList<>();
1156+
List<byte[]> chunkStatsMax = new ArrayList<>();
11471157
for (Object chunk : chunks) {
11481158
String[] strs = (String[]) chunk;
11491159
Object codesArr = buildUtf8CodesArray(strs, valueMap, codePType);
11501160
codesSegIdxes.add(writeSegment(codesDtype, codesArr));
11511161
chunkRowCounts.add((long) strs.length);
11521162
chunkNullCounts.add(0L); // global-dict Utf8 columns are dense (non-nullable)
1163+
// Per-zone string min/max over the chunk's values (matches the flat varbin path), so the
1164+
// dict zone-map prunes like a plain Utf8 column.
1165+
byte[][] mm = VarBinEncodingEncoder.minMaxStats(strs);
1166+
chunkStatsMin.add(mm != null ? mm[0] : null);
1167+
chunkStatsMax.add(mm != null ? mm[1] : null);
11531168
}
1154-
dictColRefs.put(colName,
1155-
new DictColRef(valuesSegIdx, dictSize, codesSegIdxes, chunkRowCounts, chunkNullCounts));
1169+
dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
1170+
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
11561171
}
11571172

11581173
private static Object buildUtf8CodesArray(String[] strs, Map<String, Integer> valueMap, PType codePType) {
@@ -1357,7 +1372,8 @@ private record ZoneMapRef(int zonesSegIdx, long nZones, long zoneLen, boolean ha
13571372
}
13581373

13591374
private record DictColRef(int valuesSegIdx, long valuesLen, List<Integer> codesSegIdxes,
1360-
List<Long> chunkRowCounts, List<Long> chunkNullCounts) {
1375+
List<Long> chunkRowCounts, List<Long> chunkNullCounts,
1376+
List<byte[]> chunkStatsMin, List<byte[]> chunkStatsMax) {
13611377
long totalRows() {
13621378
return chunkRowCounts.stream().mapToLong(Long::longValue).sum();
13631379
}

writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
3232
MemorySegment seg = encodePrimitive(ptype, data, ctx.arena());
3333
byte[] min = null;
3434
byte[] max = null;
35-
byte[][] stats = computeStats(ptype, data);
35+
byte[][] stats = minMaxStats(ptype, data);
3636
if (stats != null) {
3737
min = stats[0];
3838
max = stats[1];
@@ -86,7 +86,15 @@ private static MemorySegment encodePrimitive(PType ptype, Object data, Arena are
8686
};
8787
}
8888

89-
private static byte[][] computeStats(PType ptype, Object data) {
89+
/// Computes the serialised min/max [io.github.dfa1.vortex.proto.ScalarValue] pair for a raw
90+
/// primitive array, in the same signed/unsigned/float shape the per-segment stats use. Returns
91+
/// `null` for an empty array. Shared so the dictionary zone-map path computes per-chunk min/max
92+
/// identically to the flat path.
93+
///
94+
/// @param ptype the primitive type of `data`
95+
/// @param data the raw primitive array (e.g. `long[]`, `int[]`, `String`-free)
96+
/// @return a two-element `{min, max}` array of encoded scalars, or `null` if `data` is empty
97+
public static byte[][] minMaxStats(PType ptype, Object data) {
9098
return switch (ptype) {
9199
case I8 -> {
92100
byte[] arr = (byte[]) data;

0 commit comments

Comments
 (0)