Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 64 additions & 17 deletions table/sorting.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (

var (
ErrInvalidSortOrderID = errors.New("invalid sort order ID")
ErrInvalidSortSourceID = errors.New("invalid sort source ID")
ErrInvalidTransform = errors.New("invalid transform, must be a valid transform string or a transform object")
ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'")
ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'")
Expand Down Expand Up @@ -133,10 +134,10 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
return fmt.Errorf("%w: failed to unmarshal sort field", err)
}

if _, ok := raw["source-id"]; ok {
if _, ok := raw["source-ids"]; ok {
return errors.New("sort field cannot contain both source-id and source-ids")
}
_, hasSourceID := raw["source-id"]
_, hasSourceIDs := raw["source-ids"]
if hasSourceID && hasSourceIDs {
return errors.New("sort field cannot contain both source-id and source-ids")
}

aux := struct {
Expand All @@ -154,10 +155,12 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
s.Direction = aux.Direction
s.NullOrder = aux.NullOrder

if len(aux.SourceIDs) > 0 {
if hasSourceIDs {
s.SourceIDs = aux.SourceIDs
} else {
} else if hasSourceID {
s.SourceIDs = []int{aux.SourceID}
} else {
s.SourceIDs = nil
}

var err error
Expand All @@ -180,6 +183,28 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
return nil
}

// validateSortSourceID returns an error when id is negative and nil
// otherwise; zero is treated as a valid source ID.
func validateSortSourceID(id int) error {
	if id >= 0 {
		return nil
	}

	return fmt.Errorf("source ID must be non-negative: %d", id)
}

// validateSortSourceIDs rejects a nil or empty ID list and otherwise returns
// the first error reported by validateSortSourceID, scanning ids in order.
// It returns nil when every ID passes.
func validateSortSourceIDs(ids []int) error {
	if len(ids) < 1 {
		return errors.New("source-ids must not be empty")
	}

	for i := range ids {
		if idErr := validateSortSourceID(ids[i]); idErr != nil {
			// Stop at the first offending ID; later entries are not inspected.
			return idErr
		}
	}

	return nil
}

const (
InitialSortOrderID = 1
UnsortedSortOrderID = 0
Expand Down Expand Up @@ -242,7 +267,7 @@ func (s *SortOrder) UnmarshalJSON(b []byte) error {
aux.OrderID = InitialSortOrderID
}

newOrder, err := NewSortOrder(aux.OrderID, aux.Fields)
newOrder, err := newSortOrder(aux.OrderID, aux.Fields, false)
if err != nil {
return err
}
Expand All @@ -256,8 +281,13 @@ func (s *SortOrder) UnmarshalJSON(b []byte) error {
//
// The orderID must be greater than or equal to 0.
// If orderID is 0, no fields can be passed, this is equal to UnsortedSortOrder.
// Fields need to have non-nil Transform, valid Direction and NullOrder values.
// Fields need to have non-nil Transform, valid Direction and NullOrder values,
// and non-empty source IDs.
func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) {
	// The exported constructor always asks newSortOrder to validate source IDs.
	const validateSourceIDs = true

	return newSortOrder(orderID, fields, validateSourceIDs)
}

func newSortOrder(orderID int, fields []SortField, validateSourceIDs bool) (SortOrder, error) {
if orderID < 0 {
return SortOrder{}, fmt.Errorf("%w: sort order ID %d must be a non-negative integer",
ErrInvalidSortOrderID, orderID)
Expand All @@ -280,6 +310,12 @@ func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) {
if field.NullOrder != NullsFirst && field.NullOrder != NullsLast {
return SortOrder{}, fmt.Errorf("%w: sort field at index %d", ErrInvalidNullOrder, idx)
}
if validateSourceIDs {
if err := validateSortSourceIDs(field.SourceIDs); err != nil {
return SortOrder{}, fmt.Errorf("%w: sort field at index %d has invalid source IDs: %v",
ErrInvalidSortSourceID, idx, err)
}
}
}

return SortOrder{orderID, fields}, nil
Expand All @@ -295,21 +331,32 @@ func (s *SortOrder) CheckCompatibility(schema *iceberg.Schema) error {
}

for _, field := range s.fields {
f, ok := schema.FindFieldByID(field.SourceID())
if !ok {
return fmt.Errorf("sort field with source id %d not found in schema", field.SourceID())
if field.Transform == nil {
return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID())
}

if _, ok := f.Type.(iceberg.PrimitiveType); !ok {
return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type())
if err := validateSortSourceIDs(field.SourceIDs); err != nil {
return fmt.Errorf("%w: sort field has invalid source IDs: %v", ErrInvalidSortSourceID, err)
}

if field.Transform == nil {
return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID())
var firstField iceberg.NestedField
for idx, sourceID := range field.SourceIDs {
f, ok := schema.FindFieldByID(sourceID)
if !ok {
return fmt.Errorf("sort field with source id %d not found in schema", sourceID)
}

if _, ok := f.Type.(iceberg.PrimitiveType); !ok {
return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type())
}

if idx == 0 {
firstField = f
}
}

if !field.Transform.CanTransform(f.Type) {
return fmt.Errorf("invalid source type %s for transform %s", f.Type.Type(), field.Transform)
if !field.Transform.CanTransform(firstField.Type) {
return fmt.Errorf("invalid source type %s for transform %s", firstField.Type.Type(), field.Transform)
}
}

Expand Down
198 changes: 198 additions & 0 deletions table/sorting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,60 @@ func TestNewSortOrderRejectsNilTransform(t *testing.T) {
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidTransform)
assert.Contains(t, err.Error(), "has no transform")

_, err = table.NewSortOrder(1, []table.SortField{{
NullOrder: table.NullsFirst,
Direction: table.SortASC,
}})
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidTransform)
assert.Contains(t, err.Error(), "has no transform")
}

// TestNewSortOrderRejectsInvalidSourceIDs checks that table.NewSortOrder
// fails with table.ErrInvalidSortSourceID when a field's SourceIDs is nil,
// empty, or contains a negative value.
func TestNewSortOrderRejectsInvalidSourceIDs(t *testing.T) {
	cases := []struct {
		name      string
		sourceIDs []int
	}{
		{name: "missing", sourceIDs: nil},
		{name: "empty", sourceIDs: []int{}},
		{name: "negative", sourceIDs: []int{-1}},
		{name: "multi arg with negative", sourceIDs: []int{1, -1}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Everything except SourceIDs is valid, so the source IDs are
			// the only possible reason for rejection.
			field := table.SortField{
				SourceIDs: tc.sourceIDs,
				Transform: iceberg.IdentityTransform{},
				NullOrder: table.NullsFirst,
				Direction: table.SortASC,
			}

			_, err := table.NewSortOrder(1, []table.SortField{field})
			require.Error(t, err)
			assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
		})
	}
}

// TestNewSortOrderAcceptsZeroSourceID checks that a single sort field whose
// only source ID is 0 is accepted by table.NewSortOrder and yields a sort
// order of length 1.
func TestNewSortOrderAcceptsZeroSourceID(t *testing.T) {
	zeroIDField := table.SortField{
		SourceIDs: []int{0},
		Transform: iceberg.IdentityTransform{},
		NullOrder: table.NullsFirst,
		Direction: table.SortASC,
	}

	order, err := table.NewSortOrder(1, []table.SortField{zeroIDField})
	require.NoError(t, err)
	assert.Equal(t, 1, order.Len())
}

func TestNewSortOrderAcceptsValidTransform(t *testing.T) {
Expand Down Expand Up @@ -92,6 +146,68 @@ func TestSortOrderCheckCompatibilityWithValidTransform(t *testing.T) {
require.NoError(t, sortOrder.CheckCompatibility(schema))
}

// TestSortOrderCheckCompatibilityAcceptsZeroSourceIDInSchema builds a schema
// whose only field has ID 0 and checks that a sort order referencing source
// ID 0 passes both construction and CheckCompatibility.
func TestSortOrderCheckCompatibilityAcceptsZeroSourceIDInSchema(t *testing.T) {
	idColumn := iceberg.NestedField{ID: 0, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true}
	schema := iceberg.NewSchema(0, idColumn)

	fields := []table.SortField{{
		SourceIDs: []int{0},
		Transform: iceberg.IdentityTransform{},
		NullOrder: table.NullsFirst,
		Direction: table.SortASC,
	}}

	order, err := table.NewSortOrder(1, fields)
	require.NoError(t, err)

	compatErr := order.CheckCompatibility(schema)
	require.NoError(t, compatErr)
}

// TestSortOrderCheckCompatibilityRejectsInvalidSourceIDs decodes sort orders
// from JSON (which must succeed) and then checks that CheckCompatibility
// against a schema containing only field ID 19 fails with the expected
// message. Cases flagged wantInvalidSourceID must also match
// table.ErrInvalidSortSourceID.
func TestSortOrderCheckCompatibilityRejectsInvalidSourceIDs(t *testing.T) {
	schema := iceberg.NewSchema(0,
		iceberg.NestedField{ID: 19, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true},
	)

	cases := []struct {
		name                string
		jsonData            string
		wantErr             string
		wantInvalidSourceID bool
	}{
		{
			name:                "missing",
			jsonData:            `{"order-id": 1, "fields": [{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			wantErr:             "source-ids must not be empty",
			wantInvalidSourceID: true,
		},
		{
			name:                "empty",
			jsonData:            `{"order-id": 1, "fields": [{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			wantErr:             "source-ids must not be empty",
			wantInvalidSourceID: true,
		},
		{
			name:                "negative",
			jsonData:            `{"order-id": 1, "fields": [{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			wantErr:             "source ID must be non-negative: -1",
			wantInvalidSourceID: true,
		},
		{
			// Not flagged as wantInvalidSourceID: only the message is checked.
			name:     "multi arg with nonexistent source id",
			jsonData: `{"order-id": 1, "fields": [{"source-ids": [19, 999], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			wantErr:  "sort field with source id 999 not found in schema",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var decoded table.SortOrder
			unmarshalErr := json.Unmarshal([]byte(tc.jsonData), &decoded)
			require.NoError(t, unmarshalErr)

			err := decoded.CheckCompatibility(schema)
			require.Error(t, err)
			assert.ErrorContains(t, err, tc.wantErr)

			if !tc.wantInvalidSourceID {
				return
			}
			assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
		})
	}
}

func TestUnmarshalSortOrderDefaults(t *testing.T) {
var order table.SortOrder
require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order))
Expand All @@ -101,6 +217,47 @@ func TestUnmarshalSortOrderDefaults(t *testing.T) {
assert.Equal(t, table.InitialSortOrderID, order.OrderID())
}

// TestUnmarshalSortOrderAllowsLenientSourceIDs checks that decoding a sort
// order from JSON succeeds when a field's source IDs are missing, zero,
// negative, or an empty list, and that the decoded field's SourceIDs equals
// the expected value for each case.
func TestUnmarshalSortOrderAllowsLenientSourceIDs(t *testing.T) {
	cases := []struct {
		name      string
		jsonData  string
		sourceIDs []int
	}{
		{
			name:      "missing",
			jsonData:  `{"order-id": 1, "fields": [{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			sourceIDs: nil,
		},
		{
			name:      "zero",
			jsonData:  `{"order-id": 1, "fields": [{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			sourceIDs: []int{0},
		},
		{
			name:      "negative",
			jsonData:  `{"order-id": 1, "fields": [{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			sourceIDs: []int{-1},
		},
		{
			name:      "empty multi arg",
			jsonData:  `{"order-id": 1, "fields": [{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`,
			sourceIDs: []int{},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var decoded table.SortOrder
			unmarshalErr := json.Unmarshal([]byte(tc.jsonData), &decoded)
			require.NoError(t, unmarshalErr)
			require.Equal(t, 1, decoded.Len())

			// Len() is asserted to be 1 above, so the last value yielded by
			// the range is the only field.
			var got table.SortField
			for _, f := range decoded.Fields() {
				got = f
			}
			assert.Equal(t, tc.sourceIDs, got.SourceIDs)
		})
	}
}

func TestUnmarshalInvalidSortOrderID(t *testing.T) {
var order table.SortOrder
require.ErrorContains(t, json.Unmarshal([]byte(`{"order-id": 0, "fields": [{"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`), &order), "invalid sort order ID: sort order ID 0 is reserved for unsorted order")
Expand Down Expand Up @@ -169,6 +326,47 @@ func TestSortFieldMultiArgSourceIDs(t *testing.T) {
assert.Contains(t, err.Error(), "cannot contain both source-id and source-ids")
})

t.Run("unmarshal allows source ids through parse", func(t *testing.T) {
for _, tt := range []struct {
name string
jsonData string
sourceIDs []int
}{
{
name: "missing",
jsonData: `{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}`,
sourceIDs: nil,
},
{
name: "zero source-id",
jsonData: `{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`,
sourceIDs: []int{0},
},
{
name: "negative source-id",
jsonData: `{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`,
sourceIDs: []int{-1},
},
{
name: "empty source-ids",
jsonData: `{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`,
sourceIDs: []int{},
},
{
name: "source-ids with zero",
jsonData: `{"source-ids": [1, 0], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`,
sourceIDs: []int{1, 0},
},
} {
t.Run(tt.name, func(t *testing.T) {
var field table.SortField
err := json.Unmarshal([]byte(tt.jsonData), &field)
require.NoError(t, err)
assert.Equal(t, tt.sourceIDs, field.SourceIDs)
})
}
})

t.Run("marshal multi-arg round-trip", func(t *testing.T) {
field := table.SortField{
SourceIDs: []int{2, 3},
Expand Down
Loading