diff --git a/table/sorting.go b/table/sorting.go index 09c198f10..9b833bc10 100644 --- a/table/sorting.go +++ b/table/sorting.go @@ -44,6 +44,7 @@ const ( var ( ErrInvalidSortOrderID = errors.New("invalid sort order ID") + ErrInvalidSortSourceID = errors.New("invalid sort source ID") ErrInvalidTransform = errors.New("invalid transform, must be a valid transform string or a transform object") ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'") ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'") @@ -133,10 +134,10 @@ func (s *SortField) UnmarshalJSON(b []byte) error { return fmt.Errorf("%w: failed to unmarshal sort field", err) } - if _, ok := raw["source-id"]; ok { - if _, ok := raw["source-ids"]; ok { - return errors.New("sort field cannot contain both source-id and source-ids") - } + _, hasSourceID := raw["source-id"] + _, hasSourceIDs := raw["source-ids"] + if hasSourceID && hasSourceIDs { + return errors.New("sort field cannot contain both source-id and source-ids") } aux := struct { @@ -154,10 +155,12 @@ func (s *SortField) UnmarshalJSON(b []byte) error { s.Direction = aux.Direction s.NullOrder = aux.NullOrder - if len(aux.SourceIDs) > 0 { + if hasSourceIDs { s.SourceIDs = aux.SourceIDs - } else { + } else if hasSourceID { s.SourceIDs = []int{aux.SourceID} + } else { + s.SourceIDs = nil } var err error @@ -180,6 +183,28 @@ func (s *SortField) UnmarshalJSON(b []byte) error { return nil } +func validateSortSourceID(id int) error { + if id < 0 { + return fmt.Errorf("source ID must be non-negative: %d", id) + } + + return nil +} + +func validateSortSourceIDs(ids []int) error { + if len(ids) == 0 { + return errors.New("source-ids must not be empty") + } + + for _, id := range ids { + if err := validateSortSourceID(id); err != nil { + return err + } + } + + return nil +} + const ( InitialSortOrderID = 1 UnsortedSortOrderID = 0 @@ -242,7 +267,7 @@ func (s *SortOrder) UnmarshalJSON(b []byte) error { aux.OrderID = InitialSortOrderID } - newOrder, err := NewSortOrder(aux.OrderID, aux.Fields) + newOrder, err := newSortOrder(aux.OrderID, aux.Fields, false) if err != nil { return err } @@ -256,8 +281,13 @@ func (s *SortOrder) UnmarshalJSON(b []byte) error { // // The orderID must be greater than or equal to 0. // If orderID is 0, no fields can be passed, this is equal to UnsortedSortOrder. -// Fields need to have non-nil Transform, valid Direction and NullOrder values. +// Fields need to have non-nil Transform, valid Direction and NullOrder values, +// and non-empty source IDs. func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) { + return newSortOrder(orderID, fields, true) +} + +func newSortOrder(orderID int, fields []SortField, validateSourceIDs bool) (SortOrder, error) { if orderID < 0 { return SortOrder{}, fmt.Errorf("%w: sort order ID %d must be a non-negative integer", ErrInvalidSortOrderID, orderID) @@ -280,6 +310,12 @@ func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) { if field.NullOrder != NullsFirst && field.NullOrder != NullsLast { return SortOrder{}, fmt.Errorf("%w: sort field at index %d", ErrInvalidNullOrder, idx) } + if validateSourceIDs { + if err := validateSortSourceIDs(field.SourceIDs); err != nil { + return SortOrder{}, fmt.Errorf("%w: sort field at index %d has invalid source IDs: %v", + ErrInvalidSortSourceID, idx, err) + } + } } return SortOrder{orderID, fields}, nil @@ -295,21 +331,32 @@ func (s *SortOrder) CheckCompatibility(schema *iceberg.Schema) error { } for _, field := range s.fields { - f, ok := schema.FindFieldByID(field.SourceID()) - if !ok { - return fmt.Errorf("sort field with source id %d not found in schema", field.SourceID()) + if field.Transform == nil { + return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID()) } - if _, ok := f.Type.(iceberg.PrimitiveType); !ok { - return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type()) + if err := validateSortSourceIDs(field.SourceIDs); err != nil { + return fmt.Errorf("%w: sort field has invalid source IDs: %v", ErrInvalidSortSourceID, err) } - if field.Transform == nil { - return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID()) + var firstField iceberg.NestedField + for idx, sourceID := range field.SourceIDs { + f, ok := schema.FindFieldByID(sourceID) + if !ok { + return fmt.Errorf("sort field with source id %d not found in schema", sourceID) + } + + if _, ok := f.Type.(iceberg.PrimitiveType); !ok { + return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type()) + } + + if idx == 0 { + firstField = f + } } - if !field.Transform.CanTransform(f.Type) { - return fmt.Errorf("invalid source type %s for transform %s", f.Type.Type(), field.Transform) + if !field.Transform.CanTransform(firstField.Type) { + return fmt.Errorf("invalid source type %s for transform %s", firstField.Type.Type(), field.Transform) } } diff --git a/table/sorting_test.go b/table/sorting_test.go index baeae0936..24cc7fbef 100644 --- a/table/sorting_test.go +++ b/table/sorting_test.go @@ -64,6 +64,60 @@ func TestNewSortOrderRejectsNilTransform(t *testing.T) { require.Error(t, err) assert.ErrorIs(t, err, table.ErrInvalidTransform) assert.Contains(t, err.Error(), "has no transform") + + _, err = table.NewSortOrder(1, []table.SortField{{ + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidTransform) + assert.Contains(t, err.Error(), "has no transform") +} + +func TestNewSortOrderRejectsInvalidSourceIDs(t *testing.T) { + for _, tt := range []struct { + name string + sourceIDs []int + }{ + { + name: "missing", + sourceIDs: nil, + }, + { + name: "empty", + sourceIDs: []int{}, + }, + { + name: "negative", + sourceIDs: []int{-1}, + }, + { + name: "multi arg with negative", + sourceIDs: []int{1, -1}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + _, err := table.NewSortOrder(1, []table.SortField{{ + SourceIDs: tt.sourceIDs, + Transform: iceberg.IdentityTransform{}, + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + } +} + +func TestNewSortOrderAcceptsZeroSourceID(t *testing.T) { + sortOrder, err := table.NewSortOrder(1, []table.SortField{{ + SourceIDs: []int{0}, + Transform: iceberg.IdentityTransform{}, + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.NoError(t, err) + assert.Equal(t, 1, sortOrder.Len()) } func TestNewSortOrderAcceptsValidTransform(t *testing.T) { @@ -92,6 +146,68 @@ func TestSortOrderCheckCompatibilityWithValidTransform(t *testing.T) { require.NoError(t, sortOrder.CheckCompatibility(schema)) } +func TestSortOrderCheckCompatibilityAcceptsZeroSourceIDInSchema(t *testing.T) { + schema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 0, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + ) + sortOrder, err := table.NewSortOrder(1, []table.SortField{{ + SourceIDs: []int{0}, + Transform: iceberg.IdentityTransform{}, + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.NoError(t, err) + require.NoError(t, sortOrder.CheckCompatibility(schema)) +} + +func TestSortOrderCheckCompatibilityRejectsInvalidSourceIDs(t *testing.T) { + schema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 19, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + ) + for _, tt := range []struct { + name string + jsonData string + wantErr string + wantInvalidSourceID bool + }{ + { + name: "missing", + jsonData: `{"order-id": 1, "fields": [{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + wantErr: "source-ids must not be empty", + wantInvalidSourceID: true, + }, + { + name: "empty", + jsonData: `{"order-id": 1, "fields": [{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + wantErr: "source-ids must not be empty", + wantInvalidSourceID: true, + }, + { + name: "negative", + jsonData: `{"order-id": 1, "fields": [{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + wantErr: "source ID must be non-negative: -1", + wantInvalidSourceID: true, + }, + { + name: "multi arg with nonexistent source id", + jsonData: `{"order-id": 1, "fields": [{"source-ids": [19, 999], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + wantErr: "sort field with source id 999 not found in schema", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var sortOrder table.SortOrder + require.NoError(t, json.Unmarshal([]byte(tt.jsonData), &sortOrder)) + + err := sortOrder.CheckCompatibility(schema) + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + if tt.wantInvalidSourceID { + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + } + }) + } +} + func TestUnmarshalSortOrderDefaults(t *testing.T) { var order table.SortOrder require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order)) @@ -101,6 +217,47 @@ func TestUnmarshalSortOrderDefaults(t *testing.T) { assert.Equal(t, table.InitialSortOrderID, order.OrderID()) } +func TestUnmarshalSortOrderAllowsLenientSourceIDs(t *testing.T) { + for _, tt := range []struct { + name string + jsonData string + sourceIDs []int + }{ + { + name: "missing", + jsonData: `{"order-id": 1, "fields": [{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + sourceIDs: nil, + }, + { + name: "zero", + jsonData: `{"order-id": 1, "fields": [{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + sourceIDs: []int{0}, + }, + { + name: "negative", + jsonData: `{"order-id": 1, "fields": [{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + sourceIDs: []int{-1}, + }, + { + name: "empty multi arg", + jsonData: `{"order-id": 1, "fields": [{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`, + sourceIDs: []int{}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var order table.SortOrder + require.NoError(t, json.Unmarshal([]byte(tt.jsonData), &order)) + require.Equal(t, 1, order.Len()) + + var field table.SortField + for _, sortField := range order.Fields() { + field = sortField + } + assert.Equal(t, tt.sourceIDs, field.SourceIDs) + }) + } +} + func TestUnmarshalInvalidSortOrderID(t *testing.T) { var order table.SortOrder require.ErrorContains(t, json.Unmarshal([]byte(`{"order-id": 0, "fields": [{"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`), &order), "invalid sort order ID: sort order ID 0 is reserved for unsorted order") @@ -169,6 +326,47 @@ func TestSortFieldMultiArgSourceIDs(t *testing.T) { assert.Contains(t, err.Error(), "cannot contain both source-id and source-ids") }) + t.Run("unmarshal allows source ids through parse", func(t *testing.T) { + for _, tt := range []struct { + name string + jsonData string + sourceIDs []int + }{ + { + name: "missing", + jsonData: `{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}`, + sourceIDs: nil, + }, + { + name: "zero source-id", + jsonData: `{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`, + sourceIDs: []int{0}, + }, + { + name: "negative source-id", + jsonData: `{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`, + sourceIDs: []int{-1}, + }, + { + name: "empty source-ids", + jsonData: `{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`, + sourceIDs: []int{}, + }, + { + name: "source-ids with zero", + jsonData: `{"source-ids": [1, 0], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`, + sourceIDs: []int{1, 0}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var field table.SortField + err := json.Unmarshal([]byte(tt.jsonData), &field) + require.NoError(t, err) + assert.Equal(t, tt.sourceIDs, field.SourceIDs) + }) + } + }) + t.Run("marshal multi-arg round-trip", func(t *testing.T) { field := table.SortField{ SourceIDs: []int{2, 3},