diff --git a/partitions.go b/partitions.go index 0feb6c377..334805609 100644 --- a/partitions.go +++ b/partitions.go @@ -371,7 +371,7 @@ func (ps *PartitionSpec) CompatibleWith(other *PartitionSpec) bool { return slices.EqualFunc(ps.fields, other.fields, func(left, right PartitionField) bool { return slices.Equal(left.SourceIDs, right.SourceIDs) && left.Name == right.Name && - left.Transform == right.Transform + left.Transform.Equals(right.Transform) }) } diff --git a/partitions_test.go b/partitions_test.go index 05b0546bf..e06764d3b 100644 --- a/partitions_test.go +++ b/partitions_test.go @@ -19,6 +19,7 @@ package iceberg_test import ( "encoding/json" + "slices" "testing" "github.com/apache/iceberg-go" @@ -26,6 +27,18 @@ import ( "github.com/stretchr/testify/require" ) +type nonComparableTransform struct { + iceberg.IdentityTransform + // values makes this transform non-comparable, which would have panicked with ==. + values []int +} + +func (t nonComparableTransform) Equals(other iceberg.Transform) bool { + o, ok := other.(nonComparableTransform) + + return ok && slices.Equal(t.values, o.values) +} + func TestPartitionSpec(t *testing.T) { assert.Equal(t, 999, iceberg.UnpartitionedSpec.LastAssignedFieldID()) @@ -61,6 +74,71 @@ func TestPartitionSpec(t *testing.T) { assert.Equal(t, 1002, spec3.LastAssignedFieldID()) } +func TestPartitionSpecCompatibleWithUsesTransformEquals(t *testing.T) { + spec := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceIDs: []int{1}, FieldID: 1001, Name: "id", + Transform: nonComparableTransform{values: []int{1, 2}}, + }) + sameTransformSpec := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceIDs: []int{1}, FieldID: 1002, Name: "id", + Transform: nonComparableTransform{values: []int{1, 2}}, + }) + differentTransformSpec := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceIDs: []int{1}, FieldID: 1003, Name: "id", + Transform: nonComparableTransform{values: []int{2, 3}}, + }) + + require.NotPanics(t, func() { + assert.True(t, spec.CompatibleWith(&sameTransformSpec)) + assert.False(t, spec.CompatibleWith(&differentTransformSpec)) + }) + + tests := []struct { + name string + left iceberg.Transform + right iceberg.Transform + compatible bool + }{ + { + name: "identical bucket transforms are compatible", + left: iceberg.BucketTransform{NumBuckets: 16}, + right: iceberg.BucketTransform{NumBuckets: 16}, + compatible: true, + }, + { + name: "different bucket transforms are incompatible", + left: iceberg.BucketTransform{NumBuckets: 16}, + right: iceberg.BucketTransform{NumBuckets: 32}, + compatible: false, + }, + { + name: "identical truncate transforms are compatible", + left: iceberg.TruncateTransform{Width: 4}, + right: iceberg.TruncateTransform{Width: 4}, + compatible: true, + }, + { + name: "different truncate transforms are incompatible", + left: iceberg.TruncateTransform{Width: 4}, + right: iceberg.TruncateTransform{Width: 8}, + compatible: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + left := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceIDs: []int{1}, FieldID: 1001, Name: "id", Transform: tt.left, + }) + right := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceIDs: []int{1}, FieldID: 1002, Name: "id", Transform: tt.right, + }) + + assert.Equal(t, tt.compatible, left.CompatibleWith(&right)) + }) + } +} + func TestUnpartitionedWithVoidField(t *testing.T) { spec := iceberg.NewPartitionSpec(iceberg.PartitionField{ SourceIDs: []int{3}, FieldID: 1001, Name: "void", Transform: iceberg.VoidTransform{},