Skip to content

Commit 2a9dc3c

Browse files
dfa1claude
andcommitted
feat(pco): Phase 2 — Classic/None/I64 tANS decode
- PcoBin: bin record (weight, lower, offsetBits) - PcoTansDecoder: spread_state_symbols + tANS node table; decodePage writes raw U64 latents to MemorySegment - PcoEncoding.Decoder: reads chunk meta bits, builds tANS table per chunk, decodes pages; converts U64→I64 via sign-bit flip (from_latent_ordered); Phase 3+ ptypes and delta=Consecutive throw clear "not yet implemented" Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d616df5 commit 2a9dc3c

5 files changed

Lines changed: 445 additions & 22 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package io.github.dfa1.vortex.encoding;
2+
3+
/// One bin in a pco latent variable: a numerical range [lower, lower + 2^offsetBits).
4+
///
5+
/// {@code weight} is the bin's count in the tANS table (sum of weights == table size).
6+
/// {@code lower} is the raw unsigned lower bound (U64 for 64-bit latents).
7+
/// {@code offsetBits} is the log2 of the range size (0 = single value).
8+
record PcoBin(int weight, long lower, int offsetBits) {
9+
}

core/src/main/java/io/github/dfa1/vortex/encoding/PcoEncoding.java

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
package io.github.dfa1.vortex.encoding;
22

33
import com.google.protobuf.InvalidProtocolBufferException;
4+
import io.github.dfa1.vortex.core.ArrayStats;
45
import io.github.dfa1.vortex.core.DType;
6+
import io.github.dfa1.vortex.core.PType;
57
import io.github.dfa1.vortex.core.VortexException;
68
import io.github.dfa1.vortex.core.array.Array;
9+
import io.github.dfa1.vortex.core.array.LongArray;
710
import io.github.dfa1.vortex.proto.EncodingProtos;
811

12+
import java.lang.foreign.MemorySegment;
13+
import java.lang.foreign.ValueLayout;
914
import java.nio.ByteBuffer;
15+
import java.nio.ByteOrder;
1016

1117
/// Decoder for {@code vortex.pco} (pcodec numerical compression).
1218
///
@@ -19,20 +25,26 @@
1925
///
2026
/// <p>Wire format (pcodec layer, per chunk/page):
2127
/// <ul>
22-
/// <li>Chunk meta: mode nibble + extra mode bits + delta nibble + extra delta bits +
23-
/// per-latent: ans_size_log (4b), bin_count (15b), per-bin {weight-1, lower, offset_bits}</li>
24-
/// <li>Page: initial latent state (delta state_n + 4 tANS state indices) → byte align →
25-
/// per 256-batch: tANS-decoded bin indices + offset bits</li>
28+
/// <li>Chunk meta: [4b mode][extra mode bits][4b delta][extra delta bits]
29+
/// [per-latent: 4b ans_size_log, 15b n_bins, per-bin {weight-1, lower, offset_bits}]
30+
/// [0–7b alignment]</li>
31+
/// <li>Page: [4 × ans_size_log b initial states][0–7b alignment]
32+
/// [per 256-batch: ANS bits for all k, then offset bits for all k]</li>
2633
/// <li>All bit packing little-endian (LSB first)</li>
2734
/// </ul>
2835
///
29-
/// <p>Phase 1: skeleton only — parses metadata, validates header, dispatches on PType.
30-
/// Phase 2 adds Classic/None decode for I64; later phases extend to all ptypes and modes.
36+
/// <p>Supported (Phase 2): Classic mode, None delta, non-null, I64.
37+
/// Other modes/deltas/ptypes throw with a clear "not yet implemented" message.
3138
public final class PcoEncoding implements Encoding {
3239

3340
static final byte PCO_FORMAT_MAJOR = 0x04;
3441
static final byte PCO_FORMAT_MINOR = 0x01;
3542

43+
// bits needed to encode offset_bits field per latent type
44+
static final int BITS_TO_ENCODE_OFFSET_BITS_64 = 7; // log2(64) + 1
45+
static final int BITS_TO_ENCODE_OFFSET_BITS_32 = 6; // log2(32) + 1
46+
static final int BITS_TO_ENCODE_OFFSET_BITS_16 = 5; // log2(16) + 1
47+
3648
@Override
3749
public EncodingId encodingId() {
3850
return EncodingId.VORTEX_PCO;
@@ -58,12 +70,92 @@ static EncodeResult encode(DType dtype, Object data) {
5870

5971
static final class Decoder {
6072

73+
private static final ValueLayout.OfLong LE_LONG =
74+
ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
75+
6176
static Array decode(DecodeContext ctx) {
6277
EncodingProtos.PcoMetadata meta = parseMeta(ctx);
6378
validateHeader(meta);
64-
throw new VortexException(EncodingId.VORTEX_PCO,
65-
"pco decode not yet implemented — Phase 2 pending (chunks="
66-
+ meta.getChunksCount() + ")");
79+
80+
DType dtype = ctx.dtype();
81+
if (!(dtype instanceof DType.Primitive dt)) {
82+
throw new VortexException(EncodingId.VORTEX_PCO,
83+
"pco decode requires Primitive dtype, got: " + dtype);
84+
}
85+
PType ptype = dt.ptype();
86+
if (ptype != PType.I64) {
87+
throw new VortexException(EncodingId.VORTEX_PCO,
88+
"pco decode Phase 2: only I64 supported, got: " + ptype);
89+
}
90+
91+
long n = ctx.rowCount();
92+
MemorySegment out = ctx.arena().allocate(n * Long.BYTES);
93+
94+
int nChunks = meta.getChunksCount();
95+
int bufIdx = 0;
96+
long outByteOffset = 0L;
97+
98+
for (int c = 0; c < nChunks; c++) {
99+
EncodingProtos.PcoChunkInfo chunkInfo = meta.getChunks(c);
100+
MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++);
101+
102+
PcoChunkMeta chunkMeta = readChunkMeta(chunkMetaBuf);
103+
PcoTansDecoder tans = PcoTansDecoder.build(chunkMeta.ansSizeLog(), chunkMeta.bins());
104+
105+
int nPages = chunkInfo.getPagesCount();
106+
for (int p = 0; p < nPages; p++) {
107+
int pageN = chunkInfo.getPages(p).getNValues();
108+
MemorySegment pageBuf = ctx.buffer(bufIdx++);
109+
110+
LeBitReader pageReader = new LeBitReader(pageBuf);
111+
int[] stateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING];
112+
for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) {
113+
stateIdxs[i] = (int) pageReader.readBits(chunkMeta.ansSizeLog());
114+
}
115+
pageReader.alignToByte();
116+
117+
tans.decodePage(pageReader, stateIdxs, pageN, out, outByteOffset);
118+
outByteOffset += (long) pageN * Long.BYTES;
119+
}
120+
}
121+
122+
// Convert U64 latents → I64: flip sign bit (from_latent_ordered for signed types)
123+
for (long i = 0; i < n; i++) {
124+
long byteOff = i * Long.BYTES;
125+
out.set(LE_LONG, byteOff, out.get(LE_LONG, byteOff) ^ Long.MIN_VALUE);
126+
}
127+
128+
return new LongArray(dtype, n, out, ArrayStats.empty());
129+
}
130+
131+
private static PcoChunkMeta readChunkMeta(MemorySegment buf) {
132+
LeBitReader r = new LeBitReader(buf);
133+
134+
int modeNibble = (int) r.readBits(4);
135+
if (modeNibble != 0) {
136+
throw new VortexException(EncodingId.VORTEX_PCO,
137+
"pco mode " + modeNibble + " not yet implemented (only Classic=0)");
138+
}
139+
int deltaNibble = (int) r.readBits(4);
140+
if (deltaNibble != 0) {
141+
throw new VortexException(EncodingId.VORTEX_PCO,
142+
"pco delta " + deltaNibble + " not yet implemented (only None=0)");
143+
}
144+
145+
// One primary latent variable for Classic + None delta.
146+
int ansSizeLog = (int) r.readBits(4);
147+
int nBins = (int) r.readBits(15);
148+
149+
PcoBin[] bins = new PcoBin[nBins];
150+
for (int b = 0; b < nBins; b++) {
151+
int weight = (int) r.readBits(ansSizeLog) + 1;
152+
long lower = r.readBits(64); // dtype_size = 64 for I64/U64
153+
int offsetBits = (int) r.readBits(BITS_TO_ENCODE_OFFSET_BITS_64);
154+
bins[b] = new PcoBin(weight, lower, offsetBits);
155+
}
156+
r.alignToByte(); // drain padding at end of chunk meta
157+
158+
return new PcoChunkMeta(ansSizeLog, bins);
67159
}
68160

69161
private static EncodingProtos.PcoMetadata parseMeta(DecodeContext ctx) {
@@ -74,7 +166,8 @@ private static EncodingProtos.PcoMetadata parseMeta(DecodeContext ctx) {
74166
try {
75167
return EncodingProtos.PcoMetadata.parseFrom(raw.duplicate());
76168
} catch (InvalidProtocolBufferException e) {
77-
throw new VortexException(EncodingId.VORTEX_PCO, "invalid PcoMetadata: " + e.getMessage());
169+
throw new VortexException(EncodingId.VORTEX_PCO,
170+
"invalid PcoMetadata: " + e.getMessage());
78171
}
79172
}
80173

@@ -92,4 +185,7 @@ private static void validateHeader(EncodingProtos.PcoMetadata meta) {
92185
}
93186
}
94187
}
188+
189+
private record PcoChunkMeta(int ansSizeLog, PcoBin[] bins) {
190+
}
95191
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package io.github.dfa1.vortex.encoding;
2+
3+
import java.lang.foreign.MemorySegment;
4+
import java.lang.foreign.ValueLayout;
5+
import java.nio.ByteOrder;
6+
7+
/// 4-way interleaved tANS decoder for one pco latent variable.
8+
///
9+
/// Build via {@link #build(int, PcoBin[])}; then call {@link #decodePage} once per page.
10+
/// Port of {@code pco/src/ans/spec.rs} (spread) and {@code pco/src/ans/decoding.rs} (nodes).
11+
final class PcoTansDecoder {
12+
13+
private static final ValueLayout.OfLong LE_LONG =
14+
ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
15+
16+
static final int BATCH_N = 256;
17+
static final int ANS_INTERLEAVING = 4;
18+
19+
// All arrays indexed by state index in [0, tableSize).
20+
private final int[] nextStateIdxBase; // = (symbolXs[sym] << bitsToRead) - tableSize
21+
private final int[] bitsToRead; // bits consumed from bit stream per ANS step
22+
private final int[] nodeOffsetBits; // offset bits for this bin (stored per-state for cache locality)
23+
private final long[] stateLowers; // bin.lower for each state
24+
25+
private PcoTansDecoder(int[] nextStateIdxBase, int[] bitsToRead,
26+
int[] nodeOffsetBits, long[] stateLowers) {
27+
this.nextStateIdxBase = nextStateIdxBase;
28+
this.bitsToRead = bitsToRead;
29+
this.nodeOffsetBits = nodeOffsetBits;
30+
this.stateLowers = stateLowers;
31+
}
32+
33+
/// Build the decode table from chunk latent-var metadata.
34+
///
35+
/// Port of {@code Spec::from_weights} + {@code Decoder::new} from pcodec.
36+
static PcoTansDecoder build(int ansSizeLog, PcoBin[] bins) {
37+
if (bins.length == 0) {
38+
// Degenerate: no bins → 1-state table, all offsets zero.
39+
return new PcoTansDecoder(new int[]{0}, new int[]{0}, new int[]{0}, new long[]{0L});
40+
}
41+
42+
int tableSize = 1 << ansSizeLog;
43+
int[] weights = new int[bins.length];
44+
for (int i = 0; i < bins.length; i++) {
45+
weights[i] = bins[i].weight();
46+
}
47+
48+
int[] stateSymbols = spreadStateSymbols(ansSizeLog, weights, tableSize);
49+
50+
int[] symbolXs = weights.clone();
51+
int[] nextStateIdxBase = new int[tableSize];
52+
int[] bitsToRead = new int[tableSize];
53+
int[] nodeOffsetBits = new int[tableSize];
54+
long[] stateLowers = new long[tableSize];
55+
56+
for (int s = 0; s < tableSize; s++) {
57+
int sym = stateSymbols[s];
58+
int xs = symbolXs[sym];
59+
int btr = Integer.numberOfLeadingZeros(xs) - Integer.numberOfLeadingZeros(tableSize);
60+
int nextBase = xs << btr;
61+
nextStateIdxBase[s] = nextBase - tableSize;
62+
bitsToRead[s] = btr;
63+
nodeOffsetBits[s] = sym < bins.length ? bins[sym].offsetBits() : 0;
64+
stateLowers[s] = sym < bins.length ? bins[sym].lower() : 0L;
65+
symbolXs[sym]++;
66+
}
67+
68+
return new PcoTansDecoder(nextStateIdxBase, bitsToRead, nodeOffsetBits, stateLowers);
69+
}
70+
71+
/// Port of {@code Spec::spread_state_symbols} from pcodec.
72+
///
73+
/// Spreads symbols across the table with a stride of ~3/5 * tableSize (odd).
74+
static int[] spreadStateSymbols(int ansSizeLog, int[] weights, int tableSize) {
75+
int[] stateSymbols = new int[tableSize];
76+
int stride = (3 * tableSize) / 5;
77+
if (stride % 2 == 0) {
78+
stride++;
79+
}
80+
int modMask = tableSize - 1;
81+
int step = 0;
82+
for (int sym = 0; sym < weights.length; sym++) {
83+
for (int k = 0; k < weights[sym]; k++) {
84+
stateSymbols[(stride * step) & modMask] = sym;
85+
step++;
86+
}
87+
}
88+
return stateSymbols;
89+
}
90+
91+
/// Decode {@code n} raw latent values (U64) from {@code reader} into {@code out}.
92+
///
93+
/// Caller must have already read 4 initial ANS state indices and called
94+
/// {@link LeBitReader#alignToByte()} before this call.
95+
/// {@code ansStateIdxs} is modified in place and not valid after return.
96+
void decodePage(LeBitReader reader, int[] ansStateIdxs, int n,
97+
MemorySegment out, long outByteOffset) {
98+
int remaining = n;
99+
long pos = outByteOffset;
100+
long[] batchLowers = new long[BATCH_N];
101+
int[] batchOffsetBits = new int[BATCH_N];
102+
103+
while (remaining > 0) {
104+
int batchN = Math.min(remaining, BATCH_N);
105+
106+
// Phase 1 — read all ANS bin indices for this batch (sequential bit stream).
107+
for (int i = 0; i < batchN; i++) {
108+
int si = ansStateIdxs[i % ANS_INTERLEAVING];
109+
batchLowers[i] = stateLowers[si];
110+
batchOffsetBits[i] = nodeOffsetBits[si];
111+
long ansVal = reader.readBits(bitsToRead[si]);
112+
ansStateIdxs[i % ANS_INTERLEAVING] = nextStateIdxBase[si] + (int) ansVal;
113+
}
114+
115+
// Phase 2 — read all offsets and reconstruct latents.
116+
for (int i = 0; i < batchN; i++) {
117+
long offset = reader.readBits(batchOffsetBits[i]);
118+
out.set(LE_LONG, pos, batchLowers[i] + offset);
119+
pos += Long.BYTES;
120+
}
121+
122+
remaining -= batchN;
123+
}
124+
}
125+
}

0 commit comments

Comments
 (0)