diff --git a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/IoWorker.java b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/IoWorker.java index c41f0add..429d6d01 100644 --- a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/IoWorker.java +++ b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/IoWorker.java @@ -1,6 +1,7 @@ package io.github.dfa1.vortex.cli.tui; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; @@ -57,23 +58,22 @@ public void submit(Runnable task) { /// @param task task to execute /// @throws InterruptedException if the calling thread is interrupted while waiting public void runAndAwait(Runnable task) throws InterruptedException { - Object signal = new Object(); - boolean[] done = {false}; - submit(() -> { + if (closed) { + return; + } + // Count down only after pending is decremented, so a caller that reads pending() + // immediately after this returns never observes the in-flight task still counted. + CountDownLatch done = new CountDownLatch(1); + pending.incrementAndGet(); + queue.add(() -> { try { task.run(); } finally { - synchronized (signal) { - done[0] = true; - signal.notifyAll(); - } + pending.decrementAndGet(); + done.countDown(); } }); - synchronized (signal) { - while (!done[0]) { - signal.wait(); - } - } + done.await(); } /// Number of submitted tasks that have not yet finished. diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoderTest.java index 5f51e05b..567f2fc7 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoderTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoderTest.java @@ -2,12 +2,15 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.encoding.DTypes; import io.github.dfa1.vortex.encoding.EncodingId; import io.github.dfa1.vortex.reader.ReadRegistry; import io.github.dfa1.vortex.reader.array.Array; +import io.github.dfa1.vortex.reader.array.ByteArray; import io.github.dfa1.vortex.reader.array.IntArray; import io.github.dfa1.vortex.reader.array.LongArray; +import io.github.dfa1.vortex.reader.array.ShortArray; import io.github.dfa1.vortex.reader.decode.ArrayNode; import io.github.dfa1.vortex.reader.decode.DecodeContext; import io.github.dfa1.vortex.reader.decode.PatchedEncodingDecoder; @@ -16,6 +19,7 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; import java.lang.foreign.Arena; @@ -23,6 +27,7 @@ import java.util.List; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; class PatchedEncodingEncoderTest { @@ -119,6 +124,47 @@ void encode_roundTrip_i32_multipleChunks() { } } + @ParameterizedTest + @EnumSource(value = PType.class, names = {"I8", "U8", "I16", "U16", "I32", "U32", "I64", "U64"}) + void encode_roundTrip_allIntegerPTypes(PType ptype) { + // Given: 100 small values + 1 outlier — exercises toLongs/PTypeIO.set for every width + int n = 100; + DType dtype = dtypeOf(ptype); + Object data = buildData(ptype, n); + + // When + EncodeResult result = SUT.encode(dtype, data, EncodeTestHelper.testCtx()); + Array decoded = decode(result, dtype, n); + + // Then + for (int i = 0; i < n; i++) { + assertThat(readLong(decoded, i, ptype)).as("ptype %s index %d", ptype, i) + .isEqualTo(expectedAt(ptype, i, n)); + } + } + + @Test + void encode_throws_onUnsupportedPType() { + // Given: a primitive float dtype slips past the instanceof check but has no integer mapping + double[] data = {1.0, 2.0, 3.0}; + + // When / Then + assertThatThrownBy(() -> SUT.encode(DTypes.F64, data, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class) + .hasMessageContaining("unsupported ptype"); + } + + @Test + void encode_throws_onNonPrimitiveDtype() { + // Given + long[] data = {1L, 2L, 3L}; + + // When / Then: non-primitive dtype is a programmer error on the encode path + assertThatThrownBy(() -> SUT.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected primitive"); + } + @ParameterizedTest @ValueSource(ints = {10, 100, 1023, 1024, 1025, 2048}) void encode_roundTrip_i64_variousLengths(int n) { @@ -212,5 +258,122 @@ void encodeCascade_notApplicable_whenEmpty() { // Then assertThat(step.applicable()).isFalse(); } + + @Test + void encodeCascade_notApplicable_whenNonPrimitiveDtype() { + // Given + long[] data = {1L, 2L, 3L}; + + // When + CascadeStep step = SUT.encodeCascade(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + + // Then + assertThat(step.applicable()).isFalse(); + } + + @Test + void encodeCascade_notApplicable_whenNoBitWidthGain() { + // Given: every value needs the full 64 bits, so no target width beats the raw cost + long[] data = new long[100]; + for (int i = 0; i < data.length; i++) { + data[i] = Long.MIN_VALUE | i; // bit 63 set → width 64 for all + } + + // When + CascadeStep step = SUT.encodeCascade(DTypes.I64, data, EncodeTestHelper.testCtx()); + + // Then + assertThat(step.applicable()).isFalse(); + } + + @ParameterizedTest + @EnumSource(value = PType.class, names = {"I8", "U8", "I16", "U16", "I32", "U32"}) + void encodeCascade_applicable_smallIntegerPTypes(PType ptype) { + // Given: 100 small values + 1 outlier — exercises fromLongs for every narrow width + int n = 100; + DType dtype = dtypeOf(ptype); + Object data = buildData(ptype, n); + + // When + CascadeStep step = SUT.encodeCascade(dtype, data, EncodeTestHelper.testCtx()); + + // Then + assertThat(step.applicable()).isTrue(); + assertThat(step.openChildren()).hasSize(4); + assertThat(step.openChildren().get(0).childDtype()).isEqualTo(dtype); + assertThat(step.openChildren().get(3).childDtype()).isEqualTo(dtype); + } + } + + private static DType dtypeOf(PType ptype) { + return new DType.Primitive(ptype, false); + } + + /// Builds an array of the given primitive type: small values everywhere except a + /// single high outlier at the midpoint, which forces a patch. + private static Object buildData(PType ptype, int n) { + return switch (ptype) { + case I8, U8 -> { + byte[] a = new byte[n]; + for (int i = 0; i < n; i++) { + a[i] = (byte) (i & 3); + } + a[n / 2] = (byte) (ptype == PType.U8 ? 0xC8 : 0x7F); + yield a; + } + case I16, U16 -> { + short[] a = new short[n]; + for (int i = 0; i < n; i++) { + a[i] = (short) (i & 3); + } + a[n / 2] = (short) (ptype == PType.U16 ? 0xFFFF : 0x7FFF); + yield a; + } + case I32, U32 -> { + int[] a = new int[n]; + for (int i = 0; i < n; i++) { + a[i] = i & 3; + } + a[n / 2] = ptype == PType.U32 ? 0xFFFF_FFFF : 0x7FFF_FFFF; + yield a; + } + case I64, U64 -> { + long[] a = new long[n]; + for (int i = 0; i < n; i++) { + a[i] = i & 3; + } + a[n / 2] = ptype == PType.U64 ? -1L : Long.MAX_VALUE; + yield a; + } + default -> throw new IllegalArgumentException("unsupported: " + ptype); + }; + } + + /// Expected raw element value for [#buildData], used to compare against the decoded array. + private static long expectedAt(PType ptype, int i, int n) { + if (i != n / 2) { + return i & 3; + } + return switch (ptype) { + case I8 -> 0x7F; + case U8 -> (byte) 0xC8; + case I16 -> 0x7FFF; + case U16 -> (short) 0xFFFF; + case I32 -> 0x7FFF_FFFF; + case U32 -> 0xFFFF_FFFF; + case I64 -> Long.MAX_VALUE; + case U64 -> -1L; + default -> throw new IllegalArgumentException("unsupported: " + ptype); + }; + } + + private static long readLong(Array array, int i, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((ByteArray) array).getByte(i); + case I16, U16 -> ((ShortArray) array).getShort(i); + case I32, U32 -> ((IntArray) array).getInt(i); + case I64, U64 -> ((LongArray) array).getLong(i); + default -> throw new IllegalArgumentException("unsupported: " + ptype); + }; } }