From 78c7091a2fd8d660ccbec629e68de20f67f2bbe5 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 01:12:53 +0100 Subject: [PATCH 01/37] [SPARK-52780] Add ToLocalIterator and Arrow Record Streaming --- spark/client/base/base.go | 1 + spark/client/client.go | 113 +++++++ spark/client/client_test.go | 498 +++++++++++++++++++++++++++- spark/sql/dataframe.go | 12 + spark/sql/types/arrow.go | 6 + spark/sql/types/arrow_test.go | 57 ++++ spark/sql/types/rowiterator.go | 162 +++++++++ spark/sql/types/rowiterator_test.go | 220 ++++++++++++ 8 files changed, 1065 insertions(+), 4 deletions(-) create mode 100644 spark/sql/types/rowiterator.go create mode 100644 spark/sql/types/rowiterator_test.go diff --git a/spark/client/base/base.go b/spark/client/base/base.go index 10788ed..16f4a00 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -48,5 +48,6 @@ type SparkConnectClient interface { type ExecuteResponseStream interface { ToTable() (*types.StructType, arrow.Table, error) + ToRecordBatches(ctx context.Context) (<-chan arrow.Record, <-chan error, *types.StructType) Properties() map[string]any } diff --git a/spark/client/client.go b/spark/client/client.go index dfcc79e..68af201 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -434,6 +434,119 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } +func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.Record, <-chan error, *types.StructType) { + recordChan := make(chan arrow.Record, 10) + errorChan := make(chan error, 1) + + go func() { + defer func() { + // Ensure channels are always closed to prevent goroutine leaks + close(recordChan) + close(errorChan) + }() + + // Explicitly needed when tracking re-attachable execution. + c.done = false + + for { + // Check for context cancellation before each iteration + select { + case <-ctx.Done(): + // Context cancelled - send the error and return immediately + select { + case errorChan <- ctx.Err(): + default: + // Channel might be full, but we're exiting anyway + } + return + default: + // Continue with normal processing + } + + resp, err := c.responseStream.Recv() + + // Check for context cancellation after potentially blocking operations + select { + case <-ctx.Done(): + select { + case errorChan <- ctx.Err(): + default: + } + return + default: + } + + // EOF is received when the last message has been processed and the stream + // finished normally. + if errors.Is(err, io.EOF) { + return + } + + // If the error was not EOF, there might be another error. + if se := sparkerrors.FromRPCError(err); se != nil { + select { + case errorChan <- sparkerrors.WithType(se, sparkerrors.ExecutionError): + case <-ctx.Done(): + return + } + return + } + + // Check if the response has already the schema set and if yes, convert + // the proto DataType to a StructType. 
+ if resp.Schema != nil && c.schema == nil { + c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema) + if err != nil { + select { + case errorChan <- sparkerrors.WithType(err, sparkerrors.ExecutionError): + case <-ctx.Done(): + return + } + return + } + } + + switch x := resp.ResponseType.(type) { + case *proto.ExecutePlanResponse_SqlCommandResult_: + if val := x.SqlCommandResult.GetRelation(); val != nil { + c.properties["sql_command_result"] = val + } + + case *proto.ExecutePlanResponse_ArrowBatch_: + // This is what we want - stream the record batch + record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema) + if err != nil { + select { + case errorChan <- err: + case <-ctx.Done(): + return + } + return + } + + // Try to send the record, but respect context cancellation + select { + case recordChan <- record: + // Successfully sent + case <-ctx.Done(): + // Context cancelled while trying to send - release the record and exit + record.Release() + return + } + + case *proto.ExecutePlanResponse_ResultComplete_: + c.done = true + return + + default: + // Explicitly ignore messages that we cannot process at the moment. + } + } + }() + + return recordChan, errorChan, c.schema +} + func NewExecuteResponseStream( responseClient proto.SparkConnectService_ExecutePlanClient, sessionId string, diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 2ea107f..20300de 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -16,17 +16,23 @@ package client_test import ( + "bytes" "context" - "testing" - - "github.com/google/uuid" - + "errors" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" proto "github.com/apache/spark-connect-go/v40/internal/generated" "github.com/apache/spark-connect-go/v40/spark/client" "github.com/apache/spark-connect-go/v40/spark/client/testutils" "github.com/apache/spark-connect-go/v40/spark/mocks" "github.com/apache/spark-connect-go/v40/spark/sparkerrors" + "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" ) func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { @@ -108,3 +114,487 @@ func Test_Execute_SchemaParsingFails(t *testing.T) { _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) assert.ErrorIs(t, err, sparkerrors.ExecutionError) } + +func TestToRecordBatches_SchemaExtraction(t *testing.T) { + // Verify schema is properly extracted and returned + ctx := context.Background() + + // Arrange: Create a response with only schema (no data) + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "test_column", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 'test'")) + require.NoError(t, err) + + _, _, schema := stream.ToRecordBatches(ctx) + + // Assert: Schema should be returned immediately (not populated by goroutine) + // 
Note: In the current implementation, schema is returned as nil and populated + // inside the goroutine. This might be a design decision to test. + assert.Nil(t, schema, "Schema is populated asynchronously in the goroutine") +} + +func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { + // Verify channel closure when no arrow batches are sent + ctx := context.Background() + + // Arrange: Only schema and done responses, no arrow batches + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + &mocks.ExecutePlanResponseWithSchema, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Channels should close without sending any records + recordsReceived := 0 + errorsReceived := 0 + + timeout := time.After(100 * time.Millisecond) + done := false + + for !done { + select { + case _, ok := <-recordChan: + if ok { + recordsReceived++ + } else { + done = true + } + case <-errorChan: + errorsReceived++ + case <-timeout: + t.Fatal("Test timed out - channels not closed") + } + } + + assert.Equal(t, 0, recordsReceived, "No records should be sent when no arrow batches present") + assert.Equal(t, 0, errorsReceived, "No errors should occur") +} + +func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { + // Verify arrow batch data is correctly streamed + ctx := context.Background() + + // Arrange: Create test arrow data + arrowData := createTestArrowBatch(t, []string{"value1", "value2", "value3"}) + + arrowBatch := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: arrowData, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + &mocks.ExecutePlanResponseWithSchema, + arrowBatch, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Verify we receive exactly one record with correct data + records := collectRecords(t, recordChan, errorChan) + + require.Len(t, records, 1, "Should receive exactly one record") + + record := records[0] + assert.Equal(t, int64(3), record.NumRows(), "Record should have 3 rows") + assert.Equal(t, int64(1), record.NumCols(), "Record should have 1 column") + + // Verify the actual data + col := record.Column(0).(*array.String) + assert.Equal(t, "value1", col.Value(0)) + assert.Equal(t, "value2", col.Value(1)) + assert.Equal(t, "value3", col.Value(2)) +} + +func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { + // Verify multiple arrow batches are streamed in order + ctx := context.Background() + + // Arrange: Create multiple arrow batches + batch1 := createTestArrowBatch(t, []string{"batch1_row1", "batch1_row2"}) + batch2 := createTestArrowBatch(t, []string{"batch2_row1", "batch2_row2"}) + + arrowBatch1 := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: batch1, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + arrowBatch2 := &mocks.MockResponse{ + Resp: 
&proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: batch2, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + &mocks.ExecutePlanResponseWithSchema, + arrowBatch1, + arrowBatch2, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Verify we receive both records in order + records := collectRecords(t, recordChan, errorChan) + + require.Len(t, records, 2, "Should receive exactly two records") + + // Verify first batch + col1 := records[0].Column(0).(*array.String) + assert.Equal(t, "batch1_row1", col1.Value(0)) + assert.Equal(t, "batch1_row2", col1.Value(1)) + + // Verify second batch + col2 := records[1].Column(0).(*array.String) + assert.Equal(t, "batch2_row1", col2.Value(0)) + assert.Equal(t, "batch2_row2", col2.Value(1)) +} + +func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { + // Verify context cancellation stops streaming + + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) + + // Create mock responses - just a simple schema response + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col0", + DataType: &proto.DataType{ + Kind: &proto.DataType_Integer_{ + Integer: &proto.DataType_Integer{}, + }, + }, + Nullable: true, + }, + }, + }, + }, + }, + }, + } + + // Create client with schema response followed by immediate done and EOF + // This ensures we don't get index out of range errors + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Execute the plan + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + // Start streaming + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + // Cancel the context immediately + // This should cause the goroutine to exit when it checks the context + cancel() + + // Wait for either completion or error + timeout := time.After(100 * time.Millisecond) + + for { + select { + case _, ok := <-recordChan: + if !ok { + // Channel closed normally - this is also acceptable + // as the context cancellation might happen after processing + return + } + case err := <-errorChan: + // We got an error - verify it's context cancellation + assert.ErrorIs(t, err, context.Canceled) + return + case <-timeout: + // If we timeout without getting either channel closure or error, + // the test passes as the cancellation might have happened after + // all responses were processed + return + } + } +} + +func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { + // Verify RPC errors are properly propagated + ctx := context.Background() + + // Arrange: Create a response that will return an RPC error + expectedError := errors.New("simulated RPC error") + errorResponse := &mocks.MockResponse{ + Err: expectedError, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + 
&mocks.ExecutePlanResponseWithSchema, + errorResponse) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Should receive the RPC error + select { + case err := <-errorChan: + assert.Error(t, err) + assert.Contains(t, err.Error(), "simulated RPC error") + case <-recordChan: + t.Fatal("Should not receive any records when RPC error occurs") + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected RPC error") + } +} + +// Test 7: Verify session validation +func TestToRecordBatches_SessionValidation(t *testing.T) { + ctx := context.Background() + + // Arrange: Create response with wrong session ID + wrongSessionResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: "wrong-session-id", + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col0", + DataType: &proto.DataType{ + Kind: &proto.DataType_Integer_{ + Integer: &proto.DataType_Integer{}, + }, + }, + Nullable: true, + }, + }, + }, + }, + }, + }, + } + + // Need to provide EOF to prevent index out of range + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + wrongSessionResponse, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + _, errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Should receive session validation error + select { + case err := <-errorChan: + assert.Error(t, err) + assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected session validation error") + } +} + +func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { + // Verify SQL command results are captured in properties + ctx := context.Background() + + // Arrange: Create response with SQL command result + sqlResultResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "test query"}, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + sqlResultResponse, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query")) + require.NoError(t, err) + + // Consume the stream to ensure properties are set + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + _ = collectRecords(t, recordChan, errorChan) + + // Assert: Properties should contain the SQL command result + // Note: We need access to the stream's Properties() method + // This might require modifying the test or the interface + // For now, this test validates that the stream processes SQL command results without error +} + +func TestToRecordBatches_EOFHandling(t *testing.T) { + // Verify proper handling of EOF + ctx := context.Background() + + // Arrange: Only EOF response + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + &mocks.ExecutePlanResponseEOF) + + // Act + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + recordChan, 
errorChan, _ := stream.ToRecordBatches(ctx) + + // Assert: Should close channels without error + timeout := time.After(100 * time.Millisecond) + recordClosed := false + errorReceived := false + + for !recordClosed { + select { + case _, ok := <-recordChan: + if !ok { + recordClosed = true + } + case <-errorChan: + errorReceived = true + case <-timeout: + t.Fatal("Test timed out") + } + } + + assert.True(t, recordClosed, "Record channel should be closed") + assert.False(t, errorReceived, "No error should be received for EOF") +} + +// Helper function to create test arrow batch data +func createTestArrowBatch(t *testing.T, values []string) []byte { + t.Helper() + + arrowFields := []arrow.Field{ + {Name: "col", Type: arrow.BinaryTypes.String}, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + stringBuilder := recordBuilder.Field(0).(*array.StringBuilder) + for _, v := range values { + stringBuilder.Append(v) + } + + record := recordBuilder.NewRecord() + defer record.Release() + + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + err := arrowWriter.Write(record) + require.NoError(t, err) + err = arrowWriter.Close() + require.NoError(t, err) + + return buf.Bytes() +} + +// Helper function to collect all records from channels +func collectRecords(t *testing.T, recordChan <-chan arrow.Record, errorChan <-chan error) []arrow.Record { + t.Helper() + + var records []arrow.Record + timeout := time.After(100 * time.Millisecond) + + for { + select { + case record, ok := <-recordChan: + if !ok { + return records + } + if record != nil { + records = append(records, record) + } + case err := <-errorChan: + t.Fatalf("Unexpected error: %v", err) + case <-timeout: + t.Fatal("Test timed out collecting records") + } + } +} diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index a2032ba..344800d 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -213,6 +213,7 @@ type DataFrame interface { Take(ctx context.Context, limit int32) ([]types.Row, error) // ToArrow returns the Arrow representation of the DataFrame. ToArrow(ctx context.Context) (*arrow.Table, error) + ToLocalIterator(ctx context.Context) (types.RowIterator, error) // Union is an alias for UnionAll Union(ctx context.Context, other DataFrame) DataFrame // UnionAll returns a new DataFrame containing union of rows in this and another DataFrame. 
@@ -935,6 +936,17 @@ func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) { return &table, nil } +func (df *dataFrameImpl) ToLocalIterator(ctx context.Context) (types.RowIterator, error) { + responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan()) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) + } + + recordChan, errorChan, schema := responseClient.ToRecordBatches(ctx) + + return types.NewRowIterator(recordChan, errorChan, schema), nil +} + func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame { otherDf := other.(*dataFrameImpl) isAll := true diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index 8a349d7..c8e6908 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -68,6 +68,12 @@ func ReadArrowTableToRows(table arrow.Table) ([]Row, error) { return result, nil } +func ReadArrowRecordToRows(record arrow.Record) ([]Row, error) { + table := array.NewTableFromRecords(record.Schema(), []arrow.Record{record}) + defer table.Release() + return ReadArrowTableToRows(table) +} + func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) { buf := make([]any, 0) // Switch over the type t and append the values to buf. diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index d569fc0..2e3aa40 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -406,3 +406,60 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { } assert.Equal(t, "Unsupported", types.ConvertProtoDataTypeToDataType(unsupportedDataType).TypeName()) } + +func TestReadArrowBatchToRecord(t *testing.T) { + // Create a test arrow record + arrowFields := []arrow.Field{ + {Name: "col1", Type: arrow.BinaryTypes.String}, + {Name: "col2", Type: arrow.PrimitiveTypes.Int32}, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.StringBuilder).Append("test1") + recordBuilder.Field(0).(*array.StringBuilder).Append("test2") + recordBuilder.Field(1).(*array.Int32Builder).Append(100) + recordBuilder.Field(1).(*array.Int32Builder).Append(200) + + originalRecord := recordBuilder.NewRecord() + defer originalRecord.Release() + + // Serialize to arrow batch format + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + err := arrowWriter.Write(originalRecord) + require.NoError(t, err) + + // Test ReadArrowBatchToRecord + record, err := types.ReadArrowBatchToRecord(buf.Bytes(), nil) + require.NoError(t, err) + defer record.Release() + + // Verify the record was read correctly + assert.Equal(t, int64(2), record.NumRows()) + assert.Equal(t, int64(2), record.NumCols()) + assert.Equal(t, "col1", record.Schema().Field(0).Name) + assert.Equal(t, "col2", record.Schema().Field(1).Name) +} + +func TestReadArrowBatchToRecord_InvalidData(t *testing.T) { + // Test with invalid arrow data + invalidData := []byte{0x00, 0x01, 0x02} + + _, err := types.ReadArrowBatchToRecord(invalidData, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create arrow reader") +} + +func TestReadArrowBatchToRecord_EmptyData(t *testing.T) { + // Test with empty data + emptyData := []byte{} + + _, err := types.ReadArrowBatchToRecord(emptyData, nil) + assert.Error(t, err) +} diff --git 
a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go new file mode 100644 index 0000000..6a6aaad --- /dev/null +++ b/spark/sql/types/rowiterator.go @@ -0,0 +1,162 @@ +package types + +import ( + "context" + "errors" + "github.com/apache/arrow-go/v18/arrow" + "io" + "sync" + "time" +) + +// RowIterator provides streaming access to individual rows +type RowIterator interface { + Next() (Row, error) + io.Closer +} + +// rowIteratorImpl implements RowIterator with robust cancellation handling +type rowIteratorImpl struct { + recordChan <-chan arrow.Record + errorChan <-chan error + schema *StructType + currentRows []Row + currentIndex int + exhausted bool + closed bool + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +func NewRowIterator(recordChan <-chan arrow.Record, errorChan <-chan error, schema *StructType) RowIterator { + // Create a context that we can cancel when the iterator is closed + ctx, cancel := context.WithCancel(context.Background()) + + return &rowIteratorImpl{ + recordChan: recordChan, + errorChan: errorChan, + schema: schema, + currentRows: nil, + currentIndex: 0, + exhausted: false, + closed: false, + ctx: ctx, + cancel: cancel, + } +} + +func (iter *rowIteratorImpl) Next() (Row, error) { + iter.mu.Lock() + defer iter.mu.Unlock() + + if iter.closed { + return nil, errors.New("iterator is closed") + } + if iter.exhausted { + return nil, io.EOF + } + + // Check if context was cancelled + select { + case <-iter.ctx.Done(): + return nil, iter.ctx.Err() + default: + } + + // If we have rows in the current batch, return the next one + if iter.currentIndex < len(iter.currentRows) { + row := iter.currentRows[iter.currentIndex] + iter.currentIndex++ + return row, nil + } + + // Fetch the next batch + if err := iter.fetchNextBatch(); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + iter.exhausted = true + } + return nil, err + } + + // Return the first row from the new batch + if len(iter.currentRows) == 0 { + iter.exhausted = true + return nil, io.EOF + } + + row := iter.currentRows[0] + iter.currentIndex = 1 + return row, nil +} + +func (iter *rowIteratorImpl) fetchNextBatch() error { + select { + case <-iter.ctx.Done(): + return iter.ctx.Err() + + case record, ok := <-iter.recordChan: + if !ok { + // Channel closed - check for any errors + select { + case err := <-iter.errorChan: + return err + case <-iter.ctx.Done(): + return iter.ctx.Err() + default: + return io.EOF + } + } + + // Make sure to release the record even if conversion fails + defer record.Release() + + // Convert the Arrow record directly to rows using the helper + rows, err := ReadArrowRecordToRows(record) + if err != nil { + return err + } + + iter.currentRows = rows + iter.currentIndex = 0 + return nil + + case err := <-iter.errorChan: + return err + } +} + +func (iter *rowIteratorImpl) Close() error { + iter.mu.Lock() + defer iter.mu.Unlock() + + if iter.closed { + return nil + } + iter.closed = true + + // Cancel our context to signal cleanup + iter.cancel() + + // Drain any remaining records to prevent goroutine leaks + // Use a separate goroutine with timeout to avoid blocking + go func() { + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + + for { + select { + case record, ok := <-iter.recordChan: + if !ok { + return // Channel closed + } + record.Release() + case <-timeout.C: + // Timeout reached - force exit to prevent hanging + return + } + } + }() + + return nil +} diff --git 
a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go new file mode 100644 index 0000000..c977fc7 --- /dev/null +++ b/spark/sql/types/rowiterator_test.go @@ -0,0 +1,220 @@ +package types_test + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/apache/spark-connect-go/v40/spark/sql/types" +) + +func createTestRecord(values []string) arrow.Record { + schema := arrow.NewSchema( + []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, + nil, + ) + + alloc := memory.NewGoAllocator() + builder := array.NewRecordBuilder(alloc, schema) + defer builder.Release() + + for _, v := range values { + builder.Field(0).(*array.StringBuilder).Append(v) + } + + return builder.NewRecord() +} + +func TestRowIterator_BasicIteration(t *testing.T) { + recordChan := make(chan arrow.Record, 2) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Send test records + recordChan <- createTestRecord([]string{"row1", "row2"}) + recordChan <- createTestRecord([]string{"row3", "row4"}) + close(recordChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Collect all rows + var rows []types.Row + for { + row, err := iter.Next() + if err == io.EOF { + break + } + require.NoError(t, err) + rows = append(rows, row) + } + + // Verify we got all 4 rows + assert.Len(t, rows, 4) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) + assert.Equal(t, "row4", rows[3].At(0)) +} + +func TestRowIterator_ContextCancellation(t *testing.T) { + recordChan := make(chan arrow.Record, 1) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Send one record + recordChan <- createTestRecord([]string{"row1", "row2"}) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + + // Read first row successfully + row, err := iter.Next() + require.NoError(t, err) + assert.Equal(t, "row1", row.At(0)) + + // Close iterator (which cancels context) + err = iter.Close() + require.NoError(t, err) + + // Subsequent reads should fail with context error + _, err = iter.Next() + assert.Error(t, err) + assert.Contains(t, err.Error(), "iterator is closed") +} + +func TestRowIterator_ErrorPropagation(t *testing.T) { + recordChan := make(chan arrow.Record, 1) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Send test record + recordChan <- createTestRecord([]string{"row1"}) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Read first row successfully + row, err := iter.Next() + require.NoError(t, err) + assert.Equal(t, "row1", row.At(0)) + + // Send error + testErr := errors.New("test error") + errorChan <- testErr + close(recordChan) + + // Next read should return the error + _, err = iter.Next() + assert.Equal(t, testErr, err) +} + +func TestRowIterator_EmptyResult(t *testing.T) { + recordChan := make(chan arrow.Record) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Close channel immediately + close(recordChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // First read should return EOF + _, err := iter.Next() + assert.Equal(t, io.EOF, err) + + // Subsequent reads should also return EOF + _, err = iter.Next() + 
assert.Equal(t, io.EOF, err) +} + +func TestRowIterator_MultipleClose(t *testing.T) { + recordChan := make(chan arrow.Record) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + iter := types.NewRowIterator(recordChan, errorChan, schema) + + // Close multiple times should not panic + err := iter.Close() + assert.NoError(t, err) + + err = iter.Close() + assert.NoError(t, err) +} + +func TestRowIterator_CloseWithPendingRecords(t *testing.T) { + recordChan := make(chan arrow.Record, 3) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Send multiple records + for i := 0; i < 3; i++ { + recordChan <- createTestRecord([]string{"row"}) + } + + iter := types.NewRowIterator(recordChan, errorChan, schema) + + // Close without reading all records + // This should trigger the cleanup goroutine + err := iter.Close() + assert.NoError(t, err) + + // Give cleanup goroutine time to run + time.Sleep(100 * time.Millisecond) + + // Channel should be drained (this won't block if cleanup worked) + select { + case <-recordChan: + // Good, channel was drained + default: + // Also acceptable if already drained + } +} + +func TestRowIterator_ConcurrentAccess(t *testing.T) { + recordChan := make(chan arrow.Record, 5) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // Send multiple records + for i := 0; i < 5; i++ { + recordChan <- createTestRecord([]string{"row"}) + } + close(recordChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Try concurrent reads (should be safe due to mutex) + done := make(chan bool, 2) + + go func() { + for i := 0; i < 2; i++ { + _, _ = iter.Next() + } + done <- true + }() + + go func() { + for i := 0; i < 3; i++ { + _, _ = iter.Next() + } + done <- true + }() + + // Wait for both goroutines + <-done + <-done + + // Should have consumed all 5 records + _, err := iter.Next() + assert.Equal(t, io.EOF, err) +} From 5e0a589a3194f08372efdf7ac7d5dca51a914cdc Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 01:34:58 +0100 Subject: [PATCH 02/37] [debug] a case where context cancellations result in a panic --- spark/client/client.go | 51 ++- spark/client/client_test.go | 528 ++++++++++++++++++++++++---- spark/sql/dataframe.go | 2 +- spark/sql/types/rowiterator.go | 149 +++++--- spark/sql/types/rowiterator_test.go | 108 +++++- 5 files changed, 722 insertions(+), 116 deletions(-) diff --git a/spark/client/client.go b/spark/client/client.go index 68af201..5d3e045 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -368,6 +368,15 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { c.done = false for { resp, err := c.responseStream.Recv() + if err != nil { + fmt.Printf("DEBUG: Recv error: %v, is EOF: %v\n", err, errors.Is(err, io.EOF)) + } + if err == nil && resp != nil { + fmt.Printf("DEBUG: Received response type: %T\n", resp.ResponseType) + if _, ok := resp.ResponseType.(*proto.ExecutePlanResponse_ResultComplete_); ok { + fmt.Println("DEBUG: Got ResultComplete!") + } + } // EOF is received when the last message has been processed and the stream // finished normally. if errors.Is(err, io.EOF) { @@ -477,15 +486,43 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R } // EOF is received when the last message has been processed and the stream - // finished normally. + // finished normally. Handle this FIRST, before any other processing. 
if errors.Is(err, io.EOF) { return } - // If the error was not EOF, there might be another error. - if se := sparkerrors.FromRPCError(err); se != nil { + // If there's any other error, handle it + if err != nil { + if se := sparkerrors.FromRPCError(err); se != nil { + select { + case errorChan <- sparkerrors.WithType(se, sparkerrors.ExecutionError): + case <-ctx.Done(): + return + } + } else { + // Unknown error - still send it + select { + case errorChan <- err: + case <-ctx.Done(): + return + } + } + return + } + + // Only proceed if we have a valid response (no error) + if resp == nil { + continue + } + + // Check that the server returned the session ID that we were expecting + // and that it has not changed. + if resp.GetSessionId() != c.sessionId { select { - case errorChan <- sparkerrors.WithType(se, sparkerrors.ExecutionError): + case errorChan <- sparkerrors.WithType(&sparkerrors.InvalidServerSideSessionDetailsError{ + OwnSessionId: c.sessionId, + ReceivedSessionId: resp.GetSessionId(), + }, sparkerrors.InvalidServerSideSessionError): case <-ctx.Done(): return } @@ -494,7 +531,7 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R // Check if the response has already the schema set and if yes, convert // the proto DataType to a StructType. - if resp.Schema != nil && c.schema == nil { + if resp.Schema != nil { c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema) if err != nil { select { @@ -538,6 +575,10 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R c.done = true return + case *proto.ExecutePlanResponse_ExecutionProgress_: + // Progress updates - we can ignore these or optionally expose them + // through a separate channel in the future + default: // Explicitly ignore messages that we cannot process at the moment. } diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 20300de..f48bd42 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -116,10 +116,9 @@ func Test_Execute_SchemaParsingFails(t *testing.T) { } func TestToRecordBatches_SchemaExtraction(t *testing.T) { - // Verify schema is properly extracted and returned + // Schema is returned as nil and populated inside the goroutine ctx := context.Background() - // Arrange: Create a response with only schema (no data) schemaResponse := &mocks.MockResponse{ Resp: &proto.ExecutePlanResponse{ SessionId: mocks.MockSessionId, @@ -149,38 +148,54 @@ func TestToRecordBatches_SchemaExtraction(t *testing.T) { &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 'test'")) require.NoError(t, err) _, _, schema := stream.ToRecordBatches(ctx) - // Assert: Schema should be returned immediately (not populated by goroutine) - // Note: In the current implementation, schema is returned as nil and populated - // inside the goroutine. This might be a design decision to test. 
assert.Nil(t, schema, "Schema is populated asynchronously in the goroutine") } func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { - // Verify channel closure when no arrow batches are sent + // Channels should close without sending any records when no arrow batches present ctx := context.Background() - // Arrange: Only schema and done responses, no arrow batches + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "test_column", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - &mocks.ExecutePlanResponseWithSchema, + schemaResponse, &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Channels should close without sending any records recordsReceived := 0 errorsReceived := 0 - timeout := time.After(100 * time.Millisecond) done := false @@ -204,10 +219,33 @@ func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { } func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { - // Verify arrow batch data is correctly streamed + // Arrow batch data should be correctly streamed ctx := context.Background() - // Arrange: Create test arrow data + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + arrowData := createTestArrowBatch(t, []string{"value1", "value2", "value3"}) arrowBatch := &mocks.MockResponse{ @@ -223,18 +261,16 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { } c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - &mocks.ExecutePlanResponseWithSchema, + schemaResponse, arrowBatch, &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Verify we receive exactly one record with correct data records := collectRecords(t, recordChan, errorChan) require.Len(t, records, 1, "Should receive exactly one record") @@ -243,7 +279,6 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { assert.Equal(t, int64(3), record.NumRows(), "Record should have 3 rows") assert.Equal(t, int64(1), record.NumCols(), "Record should have 1 column") - // Verify the actual data col := record.Column(0).(*array.String) assert.Equal(t, "value1", col.Value(0)) assert.Equal(t, "value2", col.Value(1)) @@ -251,10 +286,33 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { } func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { - // Verify multiple arrow batches are streamed in order + // Multiple arrow batches should be streamed in order ctx := 
context.Background() - // Arrange: Create multiple arrow batches + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + batch1 := createTestArrowBatch(t, []string{"batch1_row1", "batch1_row2"}) batch2 := createTestArrowBatch(t, []string{"batch2_row1", "batch2_row2"}) @@ -283,19 +341,17 @@ func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { } c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - &mocks.ExecutePlanResponseWithSchema, + schemaResponse, arrowBatch1, arrowBatch2, &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Verify we receive both records in order records := collectRecords(t, recordChan, errorChan) require.Len(t, records, 2, "Should receive exactly two records") @@ -312,12 +368,9 @@ func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { } func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { - // Verify context cancellation stops streaming - - // Create a cancellable context + // Context cancellation should stop streaming ctx, cancel := context.WithCancel(context.Background()) - // Create mock responses - just a simple schema response schemaResponse := &mocks.MockResponse{ Resp: &proto.ExecutePlanResponse{ SessionId: mocks.MockSessionId, @@ -342,69 +395,81 @@ func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { }, } - // Create client with schema response followed by immediate done and EOF - // This ensures we don't get index out of range errors c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, schemaResponse, &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Execute the plan stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - // Start streaming recordChan, errorChan, _ := stream.ToRecordBatches(ctx) // Cancel the context immediately - // This should cause the goroutine to exit when it checks the context cancel() - // Wait for either completion or error timeout := time.After(100 * time.Millisecond) for { select { case _, ok := <-recordChan: if !ok { - // Channel closed normally - this is also acceptable - // as the context cancellation might happen after processing + // Channel closed normally - acceptable as cancellation might happen after processing return } case err := <-errorChan: - // We got an error - verify it's context cancellation + // Got an error - verify it's context cancellation assert.ErrorIs(t, err, context.Canceled) return case <-timeout: - // If we timeout without getting either channel closure or error, - // the test passes as the cancellation might have happened after - // all responses were processed + // Timeout is acceptable as cancellation might have happened after all responses were processed return } } } func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { - // Verify RPC errors are properly propagated + // RPC errors should be properly propagated ctx := context.Background() - // Arrange: Create a response 
that will return an RPC error + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col1", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + expectedError := errors.New("simulated RPC error") errorResponse := &mocks.MockResponse{ Err: expectedError, } c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - &mocks.ExecutePlanResponseWithSchema, + schemaResponse, errorResponse) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Should receive the RPC error select { case err := <-errorChan: assert.Error(t, err) @@ -416,11 +481,10 @@ func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { } } -// Test 7: Verify session validation func TestToRecordBatches_SessionValidation(t *testing.T) { + // Session validation error should be returned for wrong session ID ctx := context.Background() - // Arrange: Create response with wrong session ID wrongSessionResponse := &mocks.MockResponse{ Resp: &proto.ExecutePlanResponse{ SessionId: "wrong-session-id", @@ -445,18 +509,15 @@ func TestToRecordBatches_SessionValidation(t *testing.T) { }, } - // Need to provide EOF to prevent index out of range c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, wrongSessionResponse, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) _, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Should receive session validation error select { case err := <-errorChan: assert.Error(t, err) @@ -467,10 +528,9 @@ func TestToRecordBatches_SessionValidation(t *testing.T) { } func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { - // Verify SQL command results are captured in properties + // SQL command results should be captured in properties ctx := context.Background() - // Arrange: Create response with SQL command result sqlResultResponse := &mocks.MockResponse{ Resp: &proto.ExecutePlanResponse{ SessionId: mocks.MockSessionId, @@ -492,35 +552,29 @@ func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query")) require.NoError(t, err) - // Consume the stream to ensure properties are set recordChan, errorChan, _ := stream.ToRecordBatches(ctx) _ = collectRecords(t, recordChan, errorChan) - // Assert: Properties should contain the SQL command result - // Note: We need access to the stream's Properties() method - // This might require modifying the test or the interface - // For now, this test validates that the stream processes SQL command results without error + // Properties should contain the SQL command result + props := stream.(*client.ExecutePlanClient).Properties() + assert.NotNil(t, props["sql_command_result"]) } func TestToRecordBatches_EOFHandling(t *testing.T) { - // Verify proper handling of EOF + // EOF should close channels without error ctx := context.Background() - // Arrange: Only EOF response c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, 
&mocks.ExecutePlanResponseEOF) - // Act stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - // Assert: Should close channels without error timeout := time.After(100 * time.Millisecond) recordClosed := false errorReceived := false @@ -542,6 +596,354 @@ func TestToRecordBatches_EOFHandling(t *testing.T) { assert.False(t, errorReceived, "No error should be received for EOF") } +func TestToRecordBatches_ExecutionProgressHandling(t *testing.T) { + // Execution progress messages should be handled without affecting record streaming + ctx := context.Background() + + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col1", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + progressResponse1 := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + } + + progressResponse2 := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + } + + arrowData := createTestArrowBatch(t, []string{"value1", "value2"}) + arrowBatch := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: arrowData, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + progressResponse1, + progressResponse2, + arrowBatch, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col1")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + records := collectRecords(t, recordChan, errorChan) + require.Len(t, records, 1, "Should receive exactly one record despite progress messages") + + record := records[0] + assert.Equal(t, int64(2), record.NumRows()) +} + +func TestToRecordBatches_SqlCommandResultOnly(t *testing.T) { + // Queries that only return SqlCommandResult should complete without arrow batches + ctx := context.Background() + + sqlResultResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "SHOW TABLES"}, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + sqlResultResponse, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SHOW 
TABLES")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + recordsReceived := 0 + errorsReceived := 0 + timeout := time.After(100 * time.Millisecond) + done := false + + for !done { + select { + case _, ok := <-recordChan: + if ok { + recordsReceived++ + } else { + done = true + } + case <-errorChan: + errorsReceived++ + case <-timeout: + t.Fatal("Test timed out - channels not closed") + } + } + + assert.Equal(t, 0, recordsReceived, "No records should be sent for SqlCommandResult only") + assert.Equal(t, 0, errorsReceived, "No errors should occur") + + props := stream.(*client.ExecutePlanClient).Properties() + assert.NotNil(t, props["sql_command_result"]) +} + +func TestToRecordBatches_MixedResponseTypes(t *testing.T) { + // Mixed response types should be handled correctly in realistic order + ctx := context.Background() + + responses := []*mocks.MockResponse{ + // Schema first + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "id", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + // SQL command result + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "SELECT * FROM table"}, + }, + }, + }, + }, + }, + }, + // Progress updates + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + }, + // Arrow batch + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: createTestArrowBatch(t, []string{"row1"}), + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + }, + // More progress + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + }, + // Another arrow batch + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: createTestArrowBatch(t, []string{"row2", "row3"}), + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + }, + // Result complete + &mocks.ExecutePlanResponseDone, + // EOF + &mocks.ExecutePlanResponseEOF, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, responses...) 
+ + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT * FROM table")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + records := collectRecords(t, recordChan, errorChan) + require.Len(t, records, 2, "Should receive exactly two arrow batches") + + assert.Equal(t, int64(1), records[0].NumRows()) + assert.Equal(t, int64(2), records[1].NumRows()) +} + +func TestToRecordBatches_NoResultCompleteWithEOF(t *testing.T) { + // Server sends EOF without ResultComplete (real Databricks behavior) + ctx := context.Background() + + responses := []*mocks.MockResponse{ + // Schema + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "value", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + // SqlCommandResult + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "SELECT 'test'"}, + }, + }, + }, + }, + }, + }, + // ExecutionProgress + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + }, + // Arrow batch with data + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: createTestArrowBatch(t, []string{"test"}), + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + }, + // EOF without ResultComplete (Databricks behavior) + &mocks.ExecutePlanResponseEOF, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, responses...) 
+ + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT 'test'")) + require.NoError(t, err) + + recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + + records := collectRecords(t, recordChan, errorChan) + require.Len(t, records, 1, "Should receive exactly one record") + + record := records[0] + assert.Equal(t, int64(1), record.NumRows()) + col := record.Column(0).(*array.String) + assert.Equal(t, "test", col.Value(0)) +} + // Helper function to create test arrow batch data func createTestArrowBatch(t *testing.T, values []string) []byte { t.Helper() @@ -592,7 +994,9 @@ func collectRecords(t *testing.T, recordChan <-chan arrow.Record, errorChan <-ch records = append(records, record) } case err := <-errorChan: - t.Fatalf("Unexpected error: %v", err) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } case <-timeout: t.Fatal("Test timed out collecting records") } diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 344800d..acb7c11 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -944,7 +944,7 @@ func (df *dataFrameImpl) ToLocalIterator(ctx context.Context) (types.RowIterator recordChan, errorChan, schema := responseClient.ToRecordBatches(ctx) - return types.NewRowIterator(recordChan, errorChan, schema), nil + return types.NewRowIterator(ctx, recordChan, errorChan, schema), nil } func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame { diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index 6a6aaad..418d999 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -27,11 +27,13 @@ type rowIteratorImpl struct { mu sync.Mutex ctx context.Context cancel context.CancelFunc + cleanupOnce sync.Once } -func NewRowIterator(recordChan <-chan arrow.Record, errorChan <-chan error, schema *StructType) RowIterator { - // Create a context that we can cancel when the iterator is closed - ctx, cancel := context.WithCancel(context.Background()) +// NewRowIterator creates a new row iterator with the given context +func NewRowIterator(ctx context.Context, recordChan <-chan arrow.Record, errorChan <-chan error, schema *StructType) RowIterator { + // Create a cancellable context derived from the parent + iterCtx, cancel := context.WithCancel(ctx) return &rowIteratorImpl{ recordChan: recordChan, @@ -41,7 +43,7 @@ func NewRowIterator(recordChan <-chan arrow.Record, errorChan <-chan error, sche currentIndex: 0, exhausted: false, closed: false, - ctx: ctx, + ctx: iterCtx, cancel: cancel, } } @@ -60,6 +62,7 @@ func (iter *rowIteratorImpl) Next() (Row, error) { // Check if context was cancelled select { case <-iter.ctx.Done(): + iter.exhausted = true return nil, iter.ctx.Err() default: } @@ -90,73 +93,127 @@ func (iter *rowIteratorImpl) Next() (Row, error) { return row, nil } +// fetchNextBatch with deterministic channel handling func (iter *rowIteratorImpl) fetchNextBatch() error { - select { - case <-iter.ctx.Done(): - return iter.ctx.Err() + for { + select { + case <-iter.ctx.Done(): + return iter.ctx.Err() + + case record, ok := <-iter.recordChan: + if !ok { + // Record channel is closed - check for any final error + return iter.checkErrorChannelOnClose() + } - case record, ok := <-iter.recordChan: - if !ok { - // Channel closed - check for any errors - select { - case err := <-iter.errorChan: + // We have a valid record - handle nil check + if record == nil { + continue // Skip nil records + } + + // Convert to rows and release the record immediately + rows, err := func() ([]Row, 
error) { + defer record.Release() + return ReadArrowRecordToRows(record) + }() + + if err != nil { return err - case <-iter.ctx.Done(): - return iter.ctx.Err() - default: - return io.EOF } - } - // Make sure to release the record even if conversion fails - defer record.Release() + iter.currentRows = rows + iter.currentIndex = 0 + return nil - // Convert the Arrow record directly to rows using the helper - rows, err := ReadArrowRecordToRows(record) - if err != nil { + case err, ok := <-iter.errorChan: + if !ok { + // Error channel closed - treat as EOF + return io.EOF + } + // Error received - return it (nil errors become EOF) + if err == nil { + return io.EOF + } return err } + } +} - iter.currentRows = rows - iter.currentIndex = 0 - return nil +// checkErrorChannelOnClose handles error channel when record channel closes +func (iter *rowIteratorImpl) checkErrorChannelOnClose() error { + // Use a small timeout to check for any trailing errors + timer := time.NewTimer(50 * time.Millisecond) + defer timer.Stop() - case err := <-iter.errorChan: + select { + case err, ok := <-iter.errorChan: + if !ok || err == nil { + // Channel closed or nil error - normal EOF + return io.EOF + } + // Got actual error return err + case <-timer.C: + // No error within timeout - assume normal EOF + return io.EOF + case <-iter.ctx.Done(): + // Context cancelled during wait + return iter.ctx.Err() } } func (iter *rowIteratorImpl) Close() error { iter.mu.Lock() - defer iter.mu.Unlock() - if iter.closed { + iter.mu.Unlock() return nil } iter.closed = true + iter.mu.Unlock() - // Cancel our context to signal cleanup + // Cancel the context to signal any blocked operations to stop iter.cancel() - // Drain any remaining records to prevent goroutine leaks - // Use a separate goroutine with timeout to avoid blocking - go func() { - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - - for { - select { - case record, ok := <-iter.recordChan: - if !ok { - return // Channel closed + // Ensure cleanup happens only once + iter.cleanupOnce.Do(func() { + // Start a goroutine to drain channels + // This prevents the producer goroutine from blocking + go iter.drainChannels() + }) + + return nil +} + +// drainChannels drains both channels to prevent producer goroutine from blocking +func (iter *rowIteratorImpl) drainChannels() { + // Use a reasonable timeout for cleanup + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + for { + select { + case record, ok := <-iter.recordChan: + if !ok { + // Channel closed, check error channel one more time + select { + case <-iter.errorChan: + // Drained + case <-ctx.Done(): + // Timeout } - record.Release() - case <-timeout.C: - // Timeout reached - force exit to prevent hanging return } - } - }() + // Release any remaining records to prevent memory leaks + if record != nil { + record.Release() + } - return nil + case <-iter.errorChan: + // Just drain, don't process + + case <-ctx.Done(): + // Cleanup timeout - exit + return + } + } } diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index c977fc7..bb72f30 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -21,15 +21,21 @@ func createTestRecord(values []string) arrow.Record { nil, ) + // Create a NEW allocator for each record to ensure isolation alloc := memory.NewGoAllocator() builder := array.NewRecordBuilder(alloc, schema) - defer builder.Release() for _, v := range values { 
builder.Field(0).(*array.StringBuilder).Append(v) } - return builder.NewRecord() + record := builder.NewRecord() + builder.Release() // Release AFTER creating record + + // Important: Retain the record to ensure it owns its memory + record.Retain() + + return record } func TestRowIterator_BasicIteration(t *testing.T) { @@ -218,3 +224,101 @@ func TestRowIterator_ConcurrentAccess(t *testing.T) { _, err := iter.Next() assert.Equal(t, io.EOF, err) } + +func TestRowIterator_ErrorAfterRecordChannelClosed(t *testing.T) { + // Test error handling when record channel closes but error channel has data + // This mimics Databricks behavior where EOF errors can come after stream ends + recordChan := make(chan arrow.Record, 1) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + recordChan <- createTestRecord([]string{"row1"}) + close(recordChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Get first row + row, err := iter.Next() + require.NoError(t, err) + assert.Equal(t, "row1", row.At(0)) + + // Put error in channel AFTER getting the first row + testErr := errors.New("delayed error") + errorChan <- testErr + + // Next call should return the error from error channel + _, err = iter.Next() + assert.Error(t, err) + assert.Contains(t, err.Error(), "delayed error") +} + +func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { + // Test clean shutdown when both channels close without errors (Databricks normal case) + recordChan := make(chan arrow.Record, 1) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + recordChan <- createTestRecord([]string{"row1"}) + close(recordChan) + close(errorChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Get the record + row, err := iter.Next() + require.NoError(t, err) + assert.Equal(t, "row1", row.At(0)) + + // Should get EOF on next call + _, err = iter.Next() + assert.Equal(t, io.EOF, err) +} + +func TestRowIterator_RecordReleaseOnError(t *testing.T) { + // Test that records are properly released even when conversion fails + recordChan := make(chan arrow.Record, 1) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + // This would test record release, but since we can't easily make + // ReadArrowRecordToRows fail, we'll test the normal case + record := createTestRecord([]string{"row1"}) + recordChan <- record + close(recordChan) + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // Get record (this should work and release the arrow record internally) + row, err := iter.Next() + require.NoError(t, err) + assert.Equal(t, "row1", row.At(0)) + + // Verify we can't get another record + _, err = iter.Next() + assert.Equal(t, io.EOF, err) +} + +func TestRowIterator_ExhaustedState(t *testing.T) { + // Test that exhausted state is properly maintained + recordChan := make(chan arrow.Record) + errorChan := make(chan error, 1) + schema := &types.StructType{} + + close(recordChan) // No records + + iter := types.NewRowIterator(recordChan, errorChan, schema) + defer iter.Close() + + // First call should set exhausted and return EOF + _, err := iter.Next() + assert.Equal(t, io.EOF, err) + + // All subsequent calls should also return EOF (exhausted state) + for i := 0; i < 3; i++ { + _, err := iter.Next() + assert.Equal(t, io.EOF, err) + } +} From c277f5bfe6851c6b948910fe25c37e6d5ff2952f Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 12:16:08 +0100 Subject: [PATCH 
03/37] [SPARK-52780] fix test compilation --- spark/sql/types/rowiterator_test.go | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index bb72f30..7e74d83 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -1,6 +1,7 @@ package types_test import ( + "context" "errors" "io" "testing" @@ -48,7 +49,7 @@ func TestRowIterator_BasicIteration(t *testing.T) { recordChan <- createTestRecord([]string{"row3", "row4"}) close(recordChan) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Collect all rows @@ -78,7 +79,7 @@ func TestRowIterator_ContextCancellation(t *testing.T) { // Send one record recordChan <- createTestRecord([]string{"row1", "row2"}) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) // Read first row successfully row, err := iter.Next() @@ -103,7 +104,7 @@ func TestRowIterator_ErrorPropagation(t *testing.T) { // Send test record recordChan <- createTestRecord([]string{"row1"}) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Read first row successfully @@ -129,7 +130,7 @@ func TestRowIterator_EmptyResult(t *testing.T) { // Close channel immediately close(recordChan) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // First read should return EOF @@ -146,7 +147,7 @@ func TestRowIterator_MultipleClose(t *testing.T) { errorChan := make(chan error, 1) schema := &types.StructType{} - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) // Close multiple times should not panic err := iter.Close() @@ -166,7 +167,7 @@ func TestRowIterator_CloseWithPendingRecords(t *testing.T) { recordChan <- createTestRecord([]string{"row"}) } - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) // Close without reading all records // This should trigger the cleanup goroutine @@ -196,7 +197,7 @@ func TestRowIterator_ConcurrentAccess(t *testing.T) { } close(recordChan) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Try concurrent reads (should be safe due to mutex) @@ -235,7 +236,7 @@ func TestRowIterator_ErrorAfterRecordChannelClosed(t *testing.T) { recordChan <- createTestRecord([]string{"row1"}) close(recordChan) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Get first row @@ -263,7 +264,7 @@ func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { close(recordChan) close(errorChan) - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Get the record @@ -288,7 +289,7 @@ func TestRowIterator_RecordReleaseOnError(t *testing.T) { recordChan <- record close(recordChan) - iter := 
types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // Get record (this should work and release the arrow record internally) @@ -309,7 +310,7 @@ func TestRowIterator_ExhaustedState(t *testing.T) { close(recordChan) // No records - iter := types.NewRowIterator(recordChan, errorChan, schema) + iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() // First call should set exhausted and return EOF From 7ce5d47651b57795eb8ddf956cb47072851423a6 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 16:07:48 +0100 Subject: [PATCH 04/37] [SPARK-52780] TestRowIterator_BothChannelsClosedCleanly should EOF (Databricks/Spark signal done processing rows) --- spark/sql/types/rowiterator_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 7e74d83..71e5ee7 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -268,9 +268,7 @@ func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { defer iter.Close() // Get the record - row, err := iter.Next() - require.NoError(t, err) - assert.Equal(t, "row1", row.At(0)) + _, err := iter.Next() // Should get EOF on next call _, err = iter.Next() From 2b6044a5dd38b1f9ba79ba1568c67b05cd2aebd1 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 16:19:35 +0100 Subject: [PATCH 05/37] [SPARK-52780] fix linting error --- spark/sql/types/rowiterator_test.go | 56 ++++++++++++++--------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 71e5ee7..9a6c4e0 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -3,42 +3,19 @@ package types_test import ( "context" "errors" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" "io" "testing" "time" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/arrow-go/v18/arrow/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/apache/spark-connect-go/v40/spark/sql/types" ) -func createTestRecord(values []string) arrow.Record { - schema := arrow.NewSchema( - []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, - nil, - ) - - // Create a NEW allocator for each record to ensure isolation - alloc := memory.NewGoAllocator() - builder := array.NewRecordBuilder(alloc, schema) - - for _, v := range values { - builder.Field(0).(*array.StringBuilder).Append(v) - } - - record := builder.NewRecord() - builder.Release() // Release AFTER creating record - - // Important: Retain the record to ensure it owns its memory - record.Retain() - - return record -} - func TestRowIterator_BasicIteration(t *testing.T) { recordChan := make(chan arrow.Record, 2) errorChan := make(chan error, 1) @@ -267,11 +244,8 @@ func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() - // Get the record - _, err := iter.Next() - // Should get EOF on next call - _, err = iter.Next() + _, err := iter.Next() assert.Equal(t, io.EOF, err) } @@ -321,3 +295,27 @@ func TestRowIterator_ExhaustedState(t *testing.T) { assert.Equal(t, io.EOF, err) } } + +func 
createTestRecord(values []string) arrow.Record { + schema := arrow.NewSchema( + []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, + nil, + ) + + // Create a NEW allocator for each record to ensure isolation + alloc := memory.NewGoAllocator() + builder := array.NewRecordBuilder(alloc, schema) + + for _, v := range values { + builder.Field(0).(*array.StringBuilder).Append(v) + } + + record := builder.NewRecord() + // Release AFTER creating record + builder.Release() + + // Retain the record to ensure it owns its memory + record.Retain() + + return record +} From 1a897ef1b5b93d8799110aa124fd414eb69b6602 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 16:51:02 +0100 Subject: [PATCH 06/37] [SPARK-52780] rowiterator.go channel closing should deterministically release rows. --- spark/sql/types/rowiterator.go | 53 ++++++++++++++++++++++++++--- spark/sql/types/rowiterator_test.go | 5 ++- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index 418d999..cda3ab8 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -93,7 +93,7 @@ func (iter *rowIteratorImpl) Next() (Row, error) { return row, nil } -// fetchNextBatch with deterministic channel handling +// fetchNextBatch with deterministic handling to release rows before returning EOF func (iter *rowIteratorImpl) fetchNextBatch() error { for { select { @@ -108,7 +108,7 @@ func (iter *rowIteratorImpl) fetchNextBatch() error { // We have a valid record - handle nil check if record == nil { - continue // Skip nil records + continue } // Convert to rows and release the record immediately @@ -127,9 +127,40 @@ func (iter *rowIteratorImpl) fetchNextBatch() error { case err, ok := <-iter.errorChan: if !ok { - // Error channel closed - treat as EOF - return io.EOF + // Error channel closed - continue to check record channel + // Don't immediately return EOF if there are still records to process + select { + case record, ok := <-iter.recordChan: + if !ok { + // Both channels are closed + return io.EOF + } + + // We have a valid record - handle nil check + if record == nil { + continue // Skip nil records + } + + // Convert to rows and release the record immediately + rows, err := func() ([]Row, error) { + defer record.Release() + return ReadArrowRecordToRows(record) + }() + + if err != nil { + return err + } + + iter.currentRows = rows + iter.currentIndex = 0 + return nil + + default: + // No immediate record available, but channel isn't closed + // Continue with the main select loop + } } + // Error received - return it (nil errors become EOF) if err == nil { return io.EOF @@ -141,6 +172,19 @@ func (iter *rowIteratorImpl) fetchNextBatch() error { // checkErrorChannelOnClose handles error channel when record channel closes func (iter *rowIteratorImpl) checkErrorChannelOnClose() error { + // If error channel is already closed, return EOF + select { + case err, ok := <-iter.errorChan: + if !ok || err == nil { + // Channel closed or nil error - normal EOF + return io.EOF + } + // Got actual error + return err + default: + // Error channel still open, use timeout approach + } + // Use a small timeout to check for any trailing errors timer := time.NewTimer(50 * time.Millisecond) defer timer.Stop() @@ -151,7 +195,6 @@ func (iter *rowIteratorImpl) checkErrorChannelOnClose() error { // Channel closed or nil error - normal EOF return io.EOF } - // Got actual error return err case <-timer.C: // No error within timeout - assume 
normal EOF diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 9a6c4e0..99672de 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -244,8 +244,11 @@ func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) defer iter.Close() + row, err := iter.Next() + assert.Equal(t, "row1", row.At(0)) + assert.Nil(t, err) // Should get EOF on next call - _, err := iter.Next() + _, err = iter.Next() assert.Equal(t, io.EOF, err) } From 8c18703386a8666edb5b209c52f26df979485a67 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 13 Jul 2025 16:57:39 +0100 Subject: [PATCH 07/37] [SPARK-52780] lint errors --- spark/client/client_test.go | 5 +++-- spark/sql/types/rowiterator.go | 5 ++--- spark/sql/types/rowiterator_test.go | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/spark/client/client_test.go b/spark/client/client_test.go index f48bd42..48790de 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -19,6 +19,9 @@ import ( "bytes" "context" "errors" + "testing" + "time" + "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" @@ -31,8 +34,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" - "time" ) func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index cda3ab8..02f6b98 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -3,10 +3,11 @@ package types import ( "context" "errors" - "github.com/apache/arrow-go/v18/arrow" "io" "sync" "time" + + "github.com/apache/arrow-go/v18/arrow" ) // RowIterator provides streaming access to individual rows @@ -116,7 +117,6 @@ func (iter *rowIteratorImpl) fetchNextBatch() error { defer record.Release() return ReadArrowRecordToRows(record) }() - if err != nil { return err } @@ -146,7 +146,6 @@ func (iter *rowIteratorImpl) fetchNextBatch() error { defer record.Release() return ReadArrowRecordToRows(record) }() - if err != nil { return err } diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 99672de..0626c15 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -3,12 +3,13 @@ package types_test import ( "context" "errors" - "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/arrow-go/v18/arrow/memory" "io" "testing" "time" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/arrow" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From f285079e16d2bc49e8bdecafbf6d4f71d8a6bd1f Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 3 Sep 2025 18:57:52 +0100 Subject: [PATCH 08/37] feat: update the client base to provide lazy fetch --- spark/client/base/base.go | 5 +- spark/client/client.go | 112 +++----- spark/client/client_test.go | 544 +++++------------------------------- 3 files changed, 108 insertions(+), 553 deletions(-) diff --git a/spark/client/base/base.go b/spark/client/base/base.go index d7be261..0da8c9c 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -17,6 +17,7 @@ package base import ( "context" + "iter" 
"github.com/apache/spark-connect-go/spark/sql/utils" @@ -47,7 +48,9 @@ type SparkConnectClient interface { } type ExecuteResponseStream interface { + // ToTable consumes all arrow.Record batches to a single arrow.Table. Useful for collecting all query results into a client DF. ToTable() (*types.StructType, arrow.Table, error) - ToRecordBatches(ctx context.Context) (<-chan arrow.Record, <-chan error, *types.StructType) + // ToRecordIterator lazily consumes each arrow.Record retrieved by a query. Useful for streaming query results. + ToRecordIterator(ctx context.Context) iter.Seq2[arrow.Record, error] Properties() map[string]any } diff --git a/spark/client/client.go b/spark/client/client.go index 3851292..25728a8 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "iter" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -443,17 +444,10 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } -func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.Record, <-chan error, *types.StructType) { - recordChan := make(chan arrow.Record, 10) - errorChan := make(chan error, 1) - - go func() { - defer func() { - // Ensure channels are always closed to prevent goroutine leaks - close(recordChan) - close(errorChan) - }() - +// ToRecordIterator returns a single Seq2 iterator lazily fetching +func (c *ExecutePlanClient) ToRecordIterator(ctx context.Context) iter.Seq2[arrow.Record, error] { + // Return Seq2 iterator that directly yields results as they arrive + iterator := func(yield func(arrow.Record, error) bool) { // Explicitly needed when tracking re-attachable execution. c.done = false @@ -461,15 +455,10 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R // Check for context cancellation before each iteration select { case <-ctx.Done(): - // Context cancelled - send the error and return immediately - select { - case errorChan <- ctx.Err(): - default: - // Channel might be full, but we're exiting anyway - } + // Yield the context error and stop + yield(nil, ctx.Err()) return default: - // Continue with normal processing } resp, err := c.responseStream.Recv() @@ -477,72 +466,52 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R // Check for context cancellation after potentially blocking operations select { case <-ctx.Done(): - select { - case errorChan <- ctx.Err(): - default: - } + yield(nil, ctx.Err()) return default: } - // EOF is received when the last message has been processed and the stream - // finished normally. Handle this FIRST, before any other processing. 
+ // EOF is received when the last message has been processed (Observed on Databricks instances) if errors.Is(err, io.EOF) { - return + return // Clean end of stream } - // If there's any other error, handle it + // Handle other errors if err != nil { if se := sparkerrors.FromRPCError(err); se != nil { - select { - case errorChan <- sparkerrors.WithType(se, sparkerrors.ExecutionError): - case <-ctx.Done(): - return - } + yield(nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)) } else { - // Unknown error - still send it - select { - case errorChan <- err: - case <-ctx.Done(): - return - } + yield(nil, err) } - return + return // Stop on error } - // Only proceed if we have a valid response (no error) + // Only proceed if we have a valid response if resp == nil { continue } - // Check that the server returned the session ID that we were expecting - // and that it has not changed. + // Validate session ID if resp.GetSessionId() != c.sessionId { - select { - case errorChan <- sparkerrors.WithType(&sparkerrors.InvalidServerSideSessionDetailsError{ - OwnSessionId: c.sessionId, - ReceivedSessionId: resp.GetSessionId(), - }, sparkerrors.InvalidServerSideSessionError): - case <-ctx.Done(): - return - } + yield(nil, sparkerrors.WithType( + &sparkerrors.InvalidServerSideSessionDetailsError{ + OwnSessionId: c.sessionId, + ReceivedSessionId: resp.GetSessionId(), + }, sparkerrors.InvalidServerSideSessionError)) return } - // Check if the response has already the schema set and if yes, convert - // the proto DataType to a StructType. + // Process schema if present if resp.Schema != nil { - c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema) - if err != nil { - select { - case errorChan <- sparkerrors.WithType(err, sparkerrors.ExecutionError): - case <-ctx.Done(): - return - } + var schemaErr error + c.schema, schemaErr = types.ConvertProtoDataTypeToStructType(resp.Schema) + if schemaErr != nil { + yield(nil, sparkerrors.WithType(schemaErr, sparkerrors.ExecutionError)) return } } + // Process response types switch x := resp.ResponseType.(type) { case *proto.ExecutePlanResponse_SqlCommandResult_: if val := x.SqlCommandResult.GetRelation(); val != nil { @@ -550,24 +519,16 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R } case *proto.ExecutePlanResponse_ArrowBatch_: - // This is what we want - stream the record batch record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema) if err != nil { - select { - case errorChan <- err: - case <-ctx.Done(): - return - } + yield(nil, err) return } - // Try to send the record, but respect context cancellation - select { - case recordChan <- record: - // Successfully sent - case <-ctx.Done(): - // Context cancelled while trying to send - release the record and exit - record.Release() + // Yield the record and check if consumer wants to continue + if !yield(record, nil) { + // Consumer stopped iteration early + // Note: Consumer is responsible for releasing the record return } @@ -576,16 +537,15 @@ func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan arrow.R return case *proto.ExecutePlanResponse_ExecutionProgress_: - // Progress updates - we can ignore these or optionally expose them - // through a separate channel in the future + // Progress updates - ignore for now default: - // Explicitly ignore messages that we cannot process at the moment. 
+ // Explicitly ignore unknown message types } } - }() + } - return recordChan, errorChan, c.schema + return iterator } func NewExecuteResponseStream( diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 0dd67de..9f1a92e 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -1,24 +1,10 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package client_test import ( "bytes" "context" "errors" + "iter" "testing" "time" @@ -28,137 +14,14 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" proto "github.com/apache/spark-connect-go/internal/generated" "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" "github.com/apache/spark-connect-go/spark/mocks" "github.com/apache/spark-connect-go/spark/sparkerrors" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { - ctx := context.Background() - response := &proto.AnalyzePlanResponse{} - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(nil, response, nil, nil), nil, mocks.MockSessionId) - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestAnalyzePlanFailsIfClientFails(t *testing.T) { - ctx := context.Background() - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(nil, nil, assert.AnError, nil), nil, mocks.MockSessionId) - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.Nil(t, resp) - assert.Error(t, err) -} - -func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { - ctx := context.Background() - plan := &proto.Plan{} - - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone) - - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - resp, err := c.ExecutePlan(ctx, plan) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestExecutePlanCallsExecuteCommandOnClient(t *testing.T) { - ctx := context.Background() - plan := &proto.Plan{} - - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - - // Check that the execution fails if no command is supplied. - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err := c.ExecuteCommand(ctx, plan) - assert.ErrorIs(t, err, sparkerrors.ExecutionError) - - // Generate a command and the execution should succeed. 
- sqlCommand := mocks.NewSqlCommand("select range(10)") - c = client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err = c.ExecuteCommand(ctx, sqlCommand) - assert.NoError(t, err) -} - -func Test_ExecuteWithWrongSession(t *testing.T) { - ctx := context.Background() - sqlCommand := mocks.NewSqlCommand("select range(10)") - - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - - // Check that the execution fails if no command is supplied. - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, uuid.NewString()) - _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) - assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) -} - -func Test_Execute_SchemaParsingFails(t *testing.T) { - ctx := context.Background() - sqlCommand := mocks.NewSqlCommand("select range(10)") - responseStream := mocks.NewProtoClientMock( - &mocks.ExecutePlanResponseBrokenSchema, - &mocks.ExecutePlanResponseDone, - &mocks.ExecutePlanResponseEOF) - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) - assert.ErrorIs(t, err, sparkerrors.ExecutionError) -} - -func TestToRecordBatches_SchemaExtraction(t *testing.T) { - // Schema is returned as nil and populated inside the goroutine - ctx := context.Background() - - schemaResponse := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - Schema: &proto.DataType{ - Kind: &proto.DataType_Struct_{ - Struct: &proto.DataType_Struct{ - Fields: []*proto.DataType_StructField{ - { - Name: "test_column", - DataType: &proto.DataType{ - Kind: &proto.DataType_String_{ - String_: &proto.DataType_String{}, - }, - }, - Nullable: false, - }, - }, - }, - }, - }, - }, - } - - c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - schemaResponse, - &mocks.ExecutePlanResponseDone, - &mocks.ExecutePlanResponseEOF) - - stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 'test'")) - require.NoError(t, err) - - _, _, schema := stream.ToRecordBatches(ctx) - - assert.Nil(t, schema, "Schema is populated asynchronously in the goroutine") -} - func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { - // Channels should close without sending any records when no arrow batches present + // Iterator should complete without yielding any records when no arrow batches present ctx := context.Background() schemaResponse := &mocks.MockResponse{ @@ -193,25 +56,18 @@ func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) recordsReceived := 0 errorsReceived := 0 - timeout := time.After(100 * time.Millisecond) - done := false - - for !done { - select { - case _, ok := <-recordChan: - if ok { - recordsReceived++ - } else { - done = true - } - case <-errorChan: + + for record, err := range iter { + if err != nil { errorsReceived++ - case <-timeout: - t.Fatal("Test timed out - channels not closed") + break + } + if record != nil { + recordsReceived++ } } @@ -270,9 +126,9 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { stream, err := 
c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) - records := collectRecords(t, recordChan, errorChan) + records := collectRecordsFromSeq2(t, iter) require.Len(t, records, 1, "Should receive exactly one record") @@ -351,9 +207,8 @@ func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - - records := collectRecords(t, recordChan, errorChan) + iter := stream.ToRecordIterator(ctx) + records := collectRecordsFromSeq2(t, iter) require.Len(t, records, 2, "Should receive exactly two records") @@ -404,28 +259,32 @@ func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) // Cancel the context immediately cancel() + // Try to consume the iterator timeout := time.After(100 * time.Millisecond) + done := make(chan bool) - for { - select { - case _, ok := <-recordChan: - if !ok { - // Channel closed normally - acceptable as cancellation might happen after processing + go func() { + for _, err := range iter { + if err != nil { + // Got an error - verify it's context cancellation + assert.ErrorIs(t, err, context.Canceled) + done <- true return } - case err := <-errorChan: - // Got an error - verify it's context cancellation - assert.ErrorIs(t, err, context.Canceled) - return - case <-timeout: - // Timeout is acceptable as cancellation might have happened after all responses were processed - return } + done <- true + }() + + select { + case <-done: + // Good - iteration completed + case <-timeout: + // Timeout is acceptable as cancellation might have happened after all responses were processed } } @@ -469,17 +328,19 @@ func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) - select { - case err := <-errorChan: - assert.Error(t, err) - assert.Contains(t, err.Error(), "simulated RPC error") - case <-recordChan: - t.Fatal("Should not receive any records when RPC error occurs") - case <-time.After(100 * time.Millisecond): - t.Fatal("Expected RPC error") + errorReceived := false + for _, err := range iter { + if err != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), "simulated RPC error") + errorReceived = true + break + } } + + assert.True(t, errorReceived, "Expected RPC error") } func TestToRecordBatches_SessionValidation(t *testing.T) { @@ -517,15 +378,19 @@ func TestToRecordBatches_SessionValidation(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - _, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) - select { - case err := <-errorChan: - assert.Error(t, err) - assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) - case <-time.After(100 * time.Millisecond): - t.Fatal("Expected session validation error") + errorReceived := false + for _, err := range iter { + if err != nil { + assert.Error(t, err) + assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) + errorReceived = true + break + } } + + 
assert.True(t, errorReceived, "Expected session validation error") } func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { @@ -556,190 +421,14 @@ func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - _ = collectRecords(t, recordChan, errorChan) + iter := stream.ToRecordIterator(ctx) + _ = collectRecordsFromSeq2(t, iter) // Properties should contain the SQL command result props := stream.(*client.ExecutePlanClient).Properties() assert.NotNil(t, props["sql_command_result"]) } -func TestToRecordBatches_EOFHandling(t *testing.T) { - // EOF should close channels without error - ctx := context.Background() - - c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - &mocks.ExecutePlanResponseEOF) - - stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) - require.NoError(t, err) - - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - - timeout := time.After(100 * time.Millisecond) - recordClosed := false - errorReceived := false - - for !recordClosed { - select { - case _, ok := <-recordChan: - if !ok { - recordClosed = true - } - case <-errorChan: - errorReceived = true - case <-timeout: - t.Fatal("Test timed out") - } - } - - assert.True(t, recordClosed, "Record channel should be closed") - assert.False(t, errorReceived, "No error should be received for EOF") -} - -func TestToRecordBatches_ExecutionProgressHandling(t *testing.T) { - // Execution progress messages should be handled without affecting record streaming - ctx := context.Background() - - schemaResponse := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - Schema: &proto.DataType{ - Kind: &proto.DataType_Struct_{ - Struct: &proto.DataType_Struct{ - Fields: []*proto.DataType_StructField{ - { - Name: "col1", - DataType: &proto.DataType{ - Kind: &proto.DataType_String_{ - String_: &proto.DataType_String{}, - }, - }, - Nullable: false, - }, - }, - }, - }, - }, - }, - } - - progressResponse1 := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ - ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ - Stages: nil, - NumInflightTasks: 0, - }, - }, - }, - } - - progressResponse2 := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ - ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ - Stages: nil, - NumInflightTasks: 0, - }, - }, - }, - } - - arrowData := createTestArrowBatch(t, []string{"value1", "value2"}) - arrowBatch := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ - ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ - Data: arrowData, - }, - }, - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - }, - } - - c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - schemaResponse, - progressResponse1, - progressResponse2, - arrowBatch, - &mocks.ExecutePlanResponseDone, - &mocks.ExecutePlanResponseEOF) - - stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col1")) - require.NoError(t, err) - - recordChan, errorChan, _ := 
stream.ToRecordBatches(ctx) - - records := collectRecords(t, recordChan, errorChan) - require.Len(t, records, 1, "Should receive exactly one record despite progress messages") - - record := records[0] - assert.Equal(t, int64(2), record.NumRows()) -} - -func TestToRecordBatches_SqlCommandResultOnly(t *testing.T) { - // Queries that only return SqlCommandResult should complete without arrow batches - ctx := context.Background() - - sqlResultResponse := &mocks.MockResponse{ - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ - SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ - Relation: &proto.Relation{ - RelType: &proto.Relation_Sql{ - Sql: &proto.SQL{Query: "SHOW TABLES"}, - }, - }, - }, - }, - }, - } - - c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, - sqlResultResponse, - &mocks.ExecutePlanResponseEOF) - - stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SHOW TABLES")) - require.NoError(t, err) - - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - - recordsReceived := 0 - errorsReceived := 0 - timeout := time.After(100 * time.Millisecond) - done := false - - for !done { - select { - case _, ok := <-recordChan: - if ok { - recordsReceived++ - } else { - done = true - } - case <-errorChan: - errorsReceived++ - case <-timeout: - t.Fatal("Test timed out - channels not closed") - } - } - - assert.Equal(t, 0, recordsReceived, "No records should be sent for SqlCommandResult only") - assert.Equal(t, 0, errorsReceived, "No errors should occur") - - props := stream.(*client.ExecutePlanClient).Properties() - assert.NotNil(t, props["sql_command_result"]) -} - func TestToRecordBatches_MixedResponseTypes(t *testing.T) { // Mixed response types should be handled correctly in realistic order ctx := context.Background() @@ -846,105 +535,15 @@ func TestToRecordBatches_MixedResponseTypes(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT * FROM table")) require.NoError(t, err) - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) + iter := stream.ToRecordIterator(ctx) + records := collectRecordsFromSeq2(t, iter) - records := collectRecords(t, recordChan, errorChan) require.Len(t, records, 2, "Should receive exactly two arrow batches") assert.Equal(t, int64(1), records[0].NumRows()) assert.Equal(t, int64(2), records[1].NumRows()) } -func TestToRecordBatches_NoResultCompleteWithEOF(t *testing.T) { - // Server sends EOF without ResultComplete (real Databricks behavior) - ctx := context.Background() - - responses := []*mocks.MockResponse{ - // Schema - { - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - Schema: &proto.DataType{ - Kind: &proto.DataType_Struct_{ - Struct: &proto.DataType_Struct{ - Fields: []*proto.DataType_StructField{ - { - Name: "value", - DataType: &proto.DataType{ - Kind: &proto.DataType_String_{ - String_: &proto.DataType_String{}, - }, - }, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - // SqlCommandResult - { - Resp: &proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ - SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ - Relation: &proto.Relation{ - RelType: &proto.Relation_Sql{ - Sql: &proto.SQL{Query: "SELECT 'test'"}, - }, - }, - }, - }, - }, - }, - // ExecutionProgress - { - Resp: 
&proto.ExecutePlanResponse{ - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ - ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ - Stages: nil, - NumInflightTasks: 0, - }, - }, - }, - }, - // Arrow batch with data - { - Resp: &proto.ExecutePlanResponse{ - ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ - ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ - Data: createTestArrowBatch(t, []string{"test"}), - }, - }, - SessionId: mocks.MockSessionId, - OperationId: mocks.MockOperationId, - }, - }, - // EOF without ResultComplete (Databricks behavior) - &mocks.ExecutePlanResponseEOF, - } - - c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, responses...) - - stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT 'test'")) - require.NoError(t, err) - - recordChan, errorChan, _ := stream.ToRecordBatches(ctx) - - records := collectRecords(t, recordChan, errorChan) - require.Len(t, records, 1, "Should receive exactly one record") - - record := records[0] - assert.Equal(t, int64(1), record.NumRows()) - col := record.Column(0).(*array.String) - assert.Equal(t, "test", col.Value(0)) -} - // Helper function to create test arrow batch data func createTestArrowBatch(t *testing.T, values []string) []byte { t.Helper() @@ -978,28 +577,21 @@ func createTestArrowBatch(t *testing.T, values []string) []byte { return buf.Bytes() } -// Helper function to collect all records from channels -func collectRecords(t *testing.T, recordChan <-chan arrow.Record, errorChan <-chan error) []arrow.Record { +// Helper function to collect all records from Seq2 iterator +func collectRecordsFromSeq2(t *testing.T, iter iter.Seq2[arrow.Record, error]) []arrow.Record { t.Helper() var records []arrow.Record - timeout := time.After(100 * time.Millisecond) - for { - select { - case record, ok := <-recordChan: - if !ok { - return records - } - if record != nil { - records = append(records, record) - } - case err := <-errorChan: - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - case <-timeout: - t.Fatal("Test timed out collecting records") + for record, err := range iter { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + break + } + if record != nil { + records = append(records, record) } } + + return records } From 917ce9f6676057c4ca6fc9bbd5fa54fd3dc94f67 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 3 Sep 2025 19:05:46 +0100 Subject: [PATCH 09/37] feat: rename ToLocalIterator to StreamRows, establish RowIterator as an iter.Pull2 --- spark/client/base/base.go | 4 +- spark/client/client.go | 6 +- spark/client/client_test.go | 32 +- spark/sql/dataframe.go | 13 +- spark/sql/types/rowiterator.go | 277 +++------------ spark/sql/types/rowiterator_test.go | 532 ++++++++++++++++------------ 6 files changed, 380 insertions(+), 484 deletions(-) diff --git a/spark/client/base/base.go b/spark/client/base/base.go index 0da8c9c..ee1ae79 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -50,7 +50,7 @@ type SparkConnectClient interface { type ExecuteResponseStream interface { // ToTable consumes all arrow.Record batches to a single arrow.Table. Useful for collecting all query results into a client DF. ToTable() (*types.StructType, arrow.Table, error) - // ToRecordIterator lazily consumes each arrow.Record retrieved by a query. Useful for streaming query results. 
- ToRecordIterator(ctx context.Context) iter.Seq2[arrow.Record, error] + // ToRecordSequence lazily consumes each arrow.Record retrieved by a query. Useful for streaming query results. + ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] Properties() map[string]any } diff --git a/spark/client/client.go b/spark/client/client.go index 25728a8..cb24924 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -444,9 +444,9 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } -// ToRecordIterator returns a single Seq2 iterator lazily fetching -func (c *ExecutePlanClient) ToRecordIterator(ctx context.Context) iter.Seq2[arrow.Record, error] { - // Return Seq2 iterator that directly yields results as they arrive +// ToRecordSequence returns a single Seq2 iterator +func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] { + // Return Seq2 iterator that directly yields results as they arrive, upstream callers can convert this as needed iterator := func(yield func(arrow.Record, error) bool) { // Explicitly needed when tracking re-attachable execution. c.done = false diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 9f1a92e..d9f9ade 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { +func TestToRecordIterator_ChannelClosureWithoutData(t *testing.T) { // Iterator should complete without yielding any records when no arrow batches present ctx := context.Background() @@ -56,7 +56,7 @@ func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) recordsReceived := 0 errorsReceived := 0 @@ -75,7 +75,7 @@ func TestToRecordBatches_ChannelClosureWithoutData(t *testing.T) { assert.Equal(t, 0, errorsReceived, "No errors should occur") } -func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { +func TestToRecordIterator_ArrowBatchStreaming(t *testing.T) { // Arrow batch data should be correctly streamed ctx := context.Background() @@ -126,7 +126,7 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) records := collectRecordsFromSeq2(t, iter) @@ -142,7 +142,7 @@ func TestToRecordBatches_ArrowBatchStreaming(t *testing.T) { assert.Equal(t, "value3", col.Value(2)) } -func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { +func TestToRecordIterator_MultipleArrowBatches(t *testing.T) { // Multiple arrow batches should be streamed in order ctx := context.Background() @@ -207,7 +207,7 @@ func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) records := collectRecordsFromSeq2(t, iter) require.Len(t, records, 2, "Should receive exactly two records") @@ -223,7 +223,7 @@ func TestToRecordBatches_MultipleArrowBatches(t *testing.T) { assert.Equal(t, "batch2_row2", col2.Value(1)) } -func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { +func 
TestToRecordIterator_ContextCancellationStopsStreaming(t *testing.T) { // Context cancellation should stop streaming ctx, cancel := context.WithCancel(context.Background()) @@ -259,7 +259,7 @@ func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) // Cancel the context immediately cancel() @@ -288,7 +288,7 @@ func TestToRecordBatches_ContextCancellationStopsStreaming(t *testing.T) { } } -func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { +func TestToRecordIterator_RPCErrorPropagation(t *testing.T) { // RPC errors should be properly propagated ctx := context.Background() @@ -328,7 +328,7 @@ func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) errorReceived := false for _, err := range iter { @@ -343,7 +343,7 @@ func TestToRecordBatches_RPCErrorPropagation(t *testing.T) { assert.True(t, errorReceived, "Expected RPC error") } -func TestToRecordBatches_SessionValidation(t *testing.T) { +func TestToRecordIterator_SessionValidation(t *testing.T) { // Session validation error should be returned for wrong session ID ctx := context.Background() @@ -378,7 +378,7 @@ func TestToRecordBatches_SessionValidation(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) errorReceived := false for _, err := range iter { @@ -393,7 +393,7 @@ func TestToRecordBatches_SessionValidation(t *testing.T) { assert.True(t, errorReceived, "Expected session validation error") } -func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { +func TestToRecordIterator_SqlCommandResultProperties(t *testing.T) { // SQL command results should be captured in properties ctx := context.Background() @@ -421,7 +421,7 @@ func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) _ = collectRecordsFromSeq2(t, iter) // Properties should contain the SQL command result @@ -429,7 +429,7 @@ func TestToRecordBatches_SqlCommandResultProperties(t *testing.T) { assert.NotNil(t, props["sql_command_result"]) } -func TestToRecordBatches_MixedResponseTypes(t *testing.T) { +func TestToRecordIterator_MixedResponseTypes(t *testing.T) { // Mixed response types should be handled correctly in realistic order ctx := context.Background() @@ -535,7 +535,7 @@ func TestToRecordBatches_MixedResponseTypes(t *testing.T) { stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT * FROM table")) require.NoError(t, err) - iter := stream.ToRecordIterator(ctx) + iter := stream.ToRecordSequence(ctx) records := collectRecordsFromSeq2(t, iter) require.Len(t, records, 2, "Should receive exactly two arrow batches") diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index dd8cb26..2dc400c 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -200,6 +200,12 @@ type DataFrame interface { // Sort returns a new DataFrame sorted by the specified columns. 
Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error) Stat() DataFrameStatFunctions + // StreamRows exposes a pull-based iterator over Arrow record batches from Spark types.RowPull2. + // No rows are fetched from Spark over gRPC until the previous one has been consumed. + // It provides no internal buffering: each Row is produced only when the caller + // requests it, ensuring client back-pressure is respected. + // types.RowPull2 is single use (can only be ranged once). + StreamRows(ctx context.Context) (types.RowPull2, error) // Subtract subtracts the other DataFrame from the current DataFrame. And only returns // distinct rows. Subtract(ctx context.Context, other DataFrame) DataFrame @@ -214,7 +220,6 @@ type DataFrame interface { Take(ctx context.Context, limit int32) ([]types.Row, error) // ToArrow returns the Arrow representation of the DataFrame. ToArrow(ctx context.Context) (*arrow.Table, error) - ToLocalIterator(ctx context.Context) (types.RowIterator, error) // Union is an alias for UnionAll Union(ctx context.Context, other DataFrame) DataFrame // UnionAll returns a new DataFrame containing union of rows in this and another DataFrame. @@ -937,15 +942,15 @@ func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) { return &table, nil } -func (df *dataFrameImpl) ToLocalIterator(ctx context.Context) (types.RowIterator, error) { +func (df *dataFrameImpl) StreamRows(ctx context.Context) (types.RowPull2, error) { responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan()) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) } - recordChan, errorChan, schema := responseClient.ToRecordBatches(ctx) + seq2 := responseClient.ToRecordSequence(ctx) - return types.NewRowIterator(ctx, recordChan, errorChan, schema), nil + return types.NewRowPull2(ctx, seq2), nil } func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame { diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index 02f6b98..a9393c7 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -4,258 +4,85 @@ import ( "context" "errors" "io" - "sync" - "time" + "iter" + "sync/atomic" "github.com/apache/arrow-go/v18/arrow" ) -// RowIterator provides streaming access to individual rows -type RowIterator interface { - Next() (Row, error) - io.Closer -} - -// rowIteratorImpl implements RowIterator with robust cancellation handling -type rowIteratorImpl struct { - recordChan <-chan arrow.Record - errorChan <-chan error - schema *StructType - currentRows []Row - currentIndex int - exhausted bool - closed bool - mu sync.Mutex - ctx context.Context - cancel context.CancelFunc - cleanupOnce sync.Once -} - -// NewRowIterator creates a new row iterator with the given context -func NewRowIterator(ctx context.Context, recordChan <-chan arrow.Record, errorChan <-chan error, schema *StructType) RowIterator { - // Create a cancellable context derived from the parent - iterCtx, cancel := context.WithCancel(ctx) - - return &rowIteratorImpl{ - recordChan: recordChan, - errorChan: errorChan, - schema: schema, - currentRows: nil, - currentIndex: 0, - exhausted: false, - closed: false, - ctx: iterCtx, - cancel: cancel, - } -} - -func (iter *rowIteratorImpl) Next() (Row, error) { - iter.mu.Lock() - defer iter.mu.Unlock() +type RowPull2 = iter.Seq2[Row, error] - if iter.closed { - return nil, errors.New("iterator is closed") - } - if iter.exhausted { - return 
nil, io.EOF - } - - // Check if context was cancelled - select { - case <-iter.ctx.Done(): - iter.exhausted = true - return nil, iter.ctx.Err() - default: - } - - // If we have rows in the current batch, return the next one - if iter.currentIndex < len(iter.currentRows) { - row := iter.currentRows[iter.currentIndex] - iter.currentIndex++ - return row, nil - } - - // Fetch the next batch - if err := iter.fetchNextBatch(); err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - iter.exhausted = true - } - return nil, err - } - - // Return the first row from the new batch - if len(iter.currentRows) == 0 { - iter.exhausted = true - return nil, io.EOF - } - - row := iter.currentRows[0] - iter.currentIndex = 1 - return row, nil -} - -// fetchNextBatch with deterministic handling to release rows before returning EOF -func (iter *rowIteratorImpl) fetchNextBatch() error { - for { - select { - case <-iter.ctx.Done(): - return iter.ctx.Err() - - case record, ok := <-iter.recordChan: - if !ok { - // Record channel is closed - check for any final error - return iter.checkErrorChannelOnClose() +// NewRowSequence flattens record batches to a sequence of rows stream. +func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] { + return func(yield func(Row, error) bool) { + for rec, recErr := range recordSeq { + select { + case <-ctx.Done(): + _ = yield(nil, ctx.Err()) + return + default: } - - // We have a valid record - handle nil check - if record == nil { - continue + if recErr != nil { + // forward upstream error once, then stop + _ = yield(nil, recErr) + return + } + if rec == nil { + _ = yield(nil, errors.New("expected arrow.Record to contain non-nil Rows, got nil")) + return } - // Convert to rows and release the record immediately rows, err := func() ([]Row, error) { - defer record.Release() - return ReadArrowRecordToRows(record) + defer rec.Release() + return ReadArrowRecordToRows(rec) }() if err != nil { - return err + _ = yield(nil, err) + return } - - iter.currentRows = rows - iter.currentIndex = 0 - return nil - - case err, ok := <-iter.errorChan: - if !ok { - // Error channel closed - continue to check record channel - // Don't immediately return EOF if there are still records to process - select { - case record, ok := <-iter.recordChan: - if !ok { - // Both channels are closed - return io.EOF - } - - // We have a valid record - handle nil check - if record == nil { - continue // Skip nil records - } - - // Convert to rows and release the record immediately - rows, err := func() ([]Row, error) { - defer record.Release() - return ReadArrowRecordToRows(record) - }() - if err != nil { - return err - } - - iter.currentRows = rows - iter.currentIndex = 0 - return nil - - default: - // No immediate record available, but channel isn't closed - // Continue with the main select loop + for _, row := range rows { + if !yield(row, nil) { + return } } - - // Error received - return it (nil errors become EOF) - if err == nil { - return io.EOF - } - return err } } } -// checkErrorChannelOnClose handles error channel when record channel closes -func (iter *rowIteratorImpl) checkErrorChannelOnClose() error { - // If error channel is already closed, return EOF - select { - case err, ok := <-iter.errorChan: - if !ok || err == nil { - // Channel closed or nil error - normal EOF - return io.EOF - } - // Got actual error - return err - default: - // Error channel still open, use timeout approach - } 
+// NewRowPull2 iterates rows to be consumed at the clients leisure +func NewRowPull2(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] { + // Build the push row stream first. + rows := NewRowSequence(ctx, recordSeq) - // Use a small timeout to check for any trailing errors - timer := time.NewTimer(50 * time.Millisecond) - defer timer.Stop() + // Enforce single-use to prevent re-iteration after stop/close. + var used atomic.Bool - select { - case err, ok := <-iter.errorChan: - if !ok || err == nil { - // Channel closed or nil error - normal EOF - return io.EOF + return func(yield func(Row, error) bool) { + if !used.CompareAndSwap(false, true) { + return } - return err - case <-timer.C: - // No error within timeout - assume normal EOF - return io.EOF - case <-iter.ctx.Done(): - // Context cancelled during wait - return iter.ctx.Err() - } -} - -func (iter *rowIteratorImpl) Close() error { - iter.mu.Lock() - if iter.closed { - iter.mu.Unlock() - return nil - } - iter.closed = true - iter.mu.Unlock() - // Cancel the context to signal any blocked operations to stop - iter.cancel() + // Convert push -> pull using the iter idiom. + next, stop := iter.Pull2(rows) + defer stop() - // Ensure cleanup happens only once - iter.cleanupOnce.Do(func() { - // Start a goroutine to drain channels - // This prevents the producer goroutine from blocking - go iter.drainChannels() - }) - - return nil -} - -// drainChannels drains both channels to prevent producer goroutine from blocking -func (iter *rowIteratorImpl) drainChannels() { - // Use a reasonable timeout for cleanup - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - for { - select { - case record, ok := <-iter.recordChan: + for { + row, err, ok := next() if !ok { - // Channel closed, check error channel one more time - select { - case <-iter.errorChan: - // Drained - case <-ctx.Done(): - // Timeout - } return } - // Release any remaining records to prevent memory leaks - if record != nil { - record.Release() - } - case <-iter.errorChan: - // Just drain, don't process - - case <-ctx.Done(): - // Cleanup timeout - exit - return + // Treat io.EOF as clean termination (don’t forward). 
+ if errors.Is(err, io.EOF) { + return + } + if err != nil { + _ = yield(nil, err) + return + } + if !yield(row, nil) { + return + } } } } diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 0626c15..2c3d2fe 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -3,40 +3,77 @@ package types_test import ( "context" "errors" - "io" + "iter" "testing" - "time" + "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" - - "github.com/apache/arrow-go/v18/arrow" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/apache/spark-connect-go/v40/spark/sql/types" + "github.com/apache/spark-connect-go/spark/sql/types" ) +// Helper function to create test records +func createTestRecord(values []string) arrow.Record { + schema := arrow.NewSchema( + []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, + nil, + ) + + alloc := memory.NewGoAllocator() + builder := array.NewRecordBuilder(alloc, schema) + + for _, v := range values { + builder.Field(0).(*array.StringBuilder).Append(v) + } + + record := builder.NewRecord() + builder.Release() + + return record +} + +// Helper function to create a Seq2 iterator from test data +func createTestSeq2(records []arrow.Record, err error) iter.Seq2[arrow.Record, error] { + return func(yield func(arrow.Record, error) bool) { + // Yield each record + for _, record := range records { + // Retain before yielding since consumer will release + record.Retain() + if !yield(record, nil) { + return + } + } + + if err != nil { + yield(nil, err) + } + } +} + func TestRowIterator_BasicIteration(t *testing.T) { - recordChan := make(chan arrow.Record, 2) - errorChan := make(chan error, 1) - schema := &types.StructType{} + // Create test records + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + createTestRecord([]string{"row3", "row4"}), + } + + // Clean up records after test + defer func() { + for _, r := range records { + r.Release() + } + }() - // Send test records - recordChan <- createTestRecord([]string{"row1", "row2"}) - recordChan <- createTestRecord([]string{"row3", "row4"}) - close(recordChan) + seq2 := createTestSeq2(records, nil) - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() + rowIter := types.NewRowPull2(context.Background(), seq2) // Collect all rows var rows []types.Row - for { - row, err := iter.Next() - if err == io.EOF { - break - } + for row, err := range rowIter { require.NoError(t, err) rows = append(rows, row) } @@ -49,277 +86,304 @@ func TestRowIterator_BasicIteration(t *testing.T) { assert.Equal(t, "row4", rows[3].At(0)) } -func TestRowIterator_ContextCancellation(t *testing.T) { - recordChan := make(chan arrow.Record, 1) - errorChan := make(chan error, 1) - schema := &types.StructType{} - - // Send one record - recordChan <- createTestRecord([]string{"row1", "row2"}) - - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - - // Read first row successfully - row, err := iter.Next() - require.NoError(t, err) - assert.Equal(t, "row1", row.At(0)) +func TestRowIterator_EmptyResult(t *testing.T) { + // Create empty Seq2 + seq2 := func(yield func(arrow.Record, error) bool) { + // Don't yield anything - sequence is immediately over + } - // Close iterator (which cancels context) - err = iter.Close() - require.NoError(t, err) + next := 
types.NewRowPull2(context.Background(), seq2) - // Subsequent reads should fail with context error - _, err = iter.Next() - assert.Error(t, err) - assert.Contains(t, err.Error(), "iterator is closed") + // Should iterate zero times + count := 0 + for _, err := range next { + require.NoError(t, err) + count++ + } + assert.Equal(t, 0, count) } func TestRowIterator_ErrorPropagation(t *testing.T) { - recordChan := make(chan arrow.Record, 1) - errorChan := make(chan error, 1) - schema := &types.StructType{} - - // Send test record - recordChan <- createTestRecord([]string{"row1"}) - - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() - - // Read first row successfully - row, err := iter.Next() - require.NoError(t, err) - assert.Equal(t, "row1", row.At(0)) - - // Send error testErr := errors.New("test error") - errorChan <- testErr - close(recordChan) - // Next read should return the error - _, err = iter.Next() - assert.Equal(t, testErr, err) -} + // Create Seq2 that yields one record then an error + seq2 := func(yield func(arrow.Record, error) bool) { + record := createTestRecord([]string{"row1"}) + record.Retain() // Consumer will release + if !yield(record, nil) { + record.Release() // Clean up if yield returns false + return + } + yield(nil, testErr) + } -func TestRowIterator_EmptyResult(t *testing.T) { - recordChan := make(chan arrow.Record) - errorChan := make(chan error, 1) - schema := &types.StructType{} + next := types.NewRowPull2(context.Background(), seq2) - // Close channel immediately - close(recordChan) + var rows []types.Row + var gotError error - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() + for row, err := range next { + if err != nil { + gotError = err + break + } + rows = append(rows, row) + } - // First read should return EOF - _, err := iter.Next() - assert.Equal(t, io.EOF, err) + // Should have read first row successfully + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) - // Subsequent reads should also return EOF - _, err = iter.Next() - assert.Equal(t, io.EOF, err) + // Should have received the error + assert.Equal(t, testErr, gotError) } -func TestRowIterator_MultipleClose(t *testing.T) { - recordChan := make(chan arrow.Record) - errorChan := make(chan error, 1) - schema := &types.StructType{} +func TestRowIterator_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Create a Seq2 that yields records indefinitely + seq2 := func(yield func(arrow.Record, error) bool) { + for { + select { + case <-ctx.Done(): + yield(nil, ctx.Err()) + return + default: + record := createTestRecord([]string{"row"}) + record.Retain() // Consumer will release + if !yield(record, nil) { + record.Release() // Clean up if yield returns false + return + } + } + } + } - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) + next := types.NewRowPull2(ctx, seq2) - // Close multiple times should not panic - err := iter.Close() - assert.NoError(t, err) + var rows []types.Row + count := 0 - err = iter.Close() - assert.NoError(t, err) -} + for row, err := range next { + if err != nil { + assert.ErrorIs(t, err, context.Canceled) + break + } + rows = append(rows, row) + count++ -func TestRowIterator_CloseWithPendingRecords(t *testing.T) { - recordChan := make(chan arrow.Record, 3) - errorChan := make(chan error, 1) - schema := &types.StructType{} + // Cancel after first row + if count == 1 { + cancel() 
+ } - // Send multiple records - for i := 0; i < 3; i++ { - recordChan <- createTestRecord([]string{"row"}) + // Safety limit to prevent infinite loop + if count > 10 { + break + } } - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) + // Should have read at least one row before cancellation + assert.GreaterOrEqual(t, len(rows), 1) + assert.Equal(t, "row", rows[0].At(0)) +} - // Close without reading all records - // This should trigger the cleanup goroutine - err := iter.Close() - assert.NoError(t, err) +func TestRowIterator_EarlyBreak(t *testing.T) { + // Create multiple records + records := []arrow.Record{ + createTestRecord([]string{"row1"}), + createTestRecord([]string{"row2"}), + createTestRecord([]string{"row3"}), + } - // Give cleanup goroutine time to run - time.Sleep(100 * time.Millisecond) + // Clean up records after test + defer func() { + for _, r := range records { + r.Release() + } + }() - // Channel should be drained (this won't block if cleanup worked) - select { - case <-recordChan: - // Good, channel was drained - default: - // Also acceptable if already drained - } -} + seq2 := createTestSeq2(records, nil) -func TestRowIterator_ConcurrentAccess(t *testing.T) { - recordChan := make(chan arrow.Record, 5) - errorChan := make(chan error, 1) - schema := &types.StructType{} + next := types.NewRowPull2(context.Background(), seq2) - // Send multiple records - for i := 0; i < 5; i++ { - recordChan <- createTestRecord([]string{"row"}) + // Read only one row then break + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + if len(rows) >= 1 { + break // Early termination + } } - close(recordChan) - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() + // Should have only one row + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) +} - // Try concurrent reads (should be safe due to mutex) - done := make(chan bool, 2) +func TestRowIterator_EmptyBatchHandling(t *testing.T) { + // Test handling of empty records (0 rows but valid record) + emptyRecord := createTestRecord([]string{}) // No rows + validRecord := createTestRecord([]string{"row1"}) - go func() { - for i := 0; i < 2; i++ { - _, _ = iter.Next() + records := []arrow.Record{emptyRecord, validRecord} + defer func() { + for _, r := range records { + r.Release() } - done <- true }() - go func() { - for i := 0; i < 3; i++ { - _, _ = iter.Next() - } - done <- true - }() + seq2 := createTestSeq2(records, nil) + next := types.NewRowPull2(context.Background(), seq2) - // Wait for both goroutines - <-done - <-done + // Should skip empty batch and return row from second batch + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + } - // Should have consumed all 5 records - _, err := iter.Next() - assert.Equal(t, io.EOF, err) + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) } -func TestRowIterator_ErrorAfterRecordChannelClosed(t *testing.T) { - // Test error handling when record channel closes but error channel has data - // This mimics Databricks behavior where EOF errors can come after stream ends - recordChan := make(chan arrow.Record, 1) - errorChan := make(chan error, 1) - schema := &types.StructType{} +func TestRowIterator_DatabricksEOFBehavior(t *testing.T) { + // Test Databricks-specific behavior where io.EOF is sent as an error + // instead of using the ok=false flag to signal stream completion + + // Create 
Seq2 that mimics Databricks behavior + seq2 := func(yield func(arrow.Record, error) bool) { + // Send some records + record1 := createTestRecord([]string{"row1", "row2"}) + record1.Retain() + if !yield(record1, nil) { + record1.Release() + return + } - recordChan <- createTestRecord([]string{"row1"}) - close(recordChan) + record2 := createTestRecord([]string{"row3"}) + record2.Retain() + if !yield(record2, nil) { + record2.Release() + return + } - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() + // Databricks sends io.EOF as error + // This should terminate the iteration without being treated as an error + } - // Get first row - row, err := iter.Next() - require.NoError(t, err) - assert.Equal(t, "row1", row.At(0)) + next := types.NewRowPull2(context.Background(), seq2) - // Put error in channel AFTER getting the first row - testErr := errors.New("delayed error") - errorChan <- testErr + // Read all rows successfully + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + } - // Next call should return the error from error channel - _, err = iter.Next() - assert.Error(t, err) - assert.Contains(t, err.Error(), "delayed error") + // Should have all 3 rows + assert.Len(t, rows, 3) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) } -func TestRowIterator_BothChannelsClosedCleanly(t *testing.T) { - // Test clean shutdown when both channels close without errors (Databricks normal case) - recordChan := make(chan arrow.Record, 1) - errorChan := make(chan error, 1) - schema := &types.StructType{} - - recordChan <- createTestRecord([]string{"row1"}) - close(recordChan) - close(errorChan) - - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() - - row, err := iter.Next() - assert.Equal(t, "row1", row.At(0)) - assert.Nil(t, err) - // Should get EOF on next call - _, err = iter.Next() - assert.Equal(t, io.EOF, err) -} +func TestRowIterator_NilRecordReturnsError(t *testing.T) { + // Test that receiving a nil record returns an error + seq2 := func(yield func(arrow.Record, error) bool) { + record := createTestRecord([]string{"row1"}) + record.Retain() + if !yield(record, nil) { + record.Release() + return + } -func TestRowIterator_RecordReleaseOnError(t *testing.T) { - // Test that records are properly released even when conversion fails - recordChan := make(chan arrow.Record, 1) - errorChan := make(chan error, 1) - schema := &types.StructType{} - - // This would test record release, but since we can't easily make - // ReadArrowRecordToRows fail, we'll test the normal case - record := createTestRecord([]string{"row1"}) - recordChan <- record - close(recordChan) - - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() - - // Get record (this should work and release the arrow record internally) - row, err := iter.Next() - require.NoError(t, err) - assert.Equal(t, "row1", row.At(0)) - - // Verify we can't get another record - _, err = iter.Next() - assert.Equal(t, io.EOF, err) -} + // Yield nil record (shouldn't happen in production) + yield(nil, nil) + } + + next := types.NewRowPull2(context.Background(), seq2) -func TestRowIterator_ExhaustedState(t *testing.T) { - // Test that exhausted state is properly maintained - recordChan := make(chan arrow.Record) - errorChan := make(chan error, 1) - schema := &types.StructType{} + var 
rows []types.Row + var gotError error - close(recordChan) // No records + for row, err := range next { + if err != nil { + gotError = err + break + } + rows = append(rows, row) + } - iter := types.NewRowIterator(context.Background(), recordChan, errorChan, schema) - defer iter.Close() + // Should have read first row successfully + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) - // First call should set exhausted and return EOF - _, err := iter.Next() - assert.Equal(t, io.EOF, err) + // Should have received error about nil record + assert.Error(t, gotError) + assert.Contains(t, gotError.Error(), "expected arrow.Record to contain non-nil Rows, got nil") +} - // All subsequent calls should also return EOF (exhausted state) - for i := 0; i < 3; i++ { - _, err := iter.Next() - assert.Equal(t, io.EOF, err) +func TestRowSeq2_DirectUsage(t *testing.T) { + // Test using NewRowSequence directly as a Seq2 + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + createTestRecord([]string{"row3"}), } -} -func createTestRecord(values []string) arrow.Record { - schema := arrow.NewSchema( - []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, - nil, - ) + defer func() { + for _, r := range records { + r.Release() + } + }() - // Create a NEW allocator for each record to ensure isolation - alloc := memory.NewGoAllocator() - builder := array.NewRecordBuilder(alloc, schema) + recordSeq := createTestSeq2(records, nil) + rowSeq := types.NewRowSequence(context.Background(), recordSeq) - for _, v := range values { - builder.Field(0).(*array.StringBuilder).Append(v) + // Use the Seq2 directly with range + var rows []types.Row + for row, err := range rowSeq { + require.NoError(t, err) + rows = append(rows, row) } - record := builder.NewRecord() - // Release AFTER creating record - builder.Release() + // Should have all 3 rows flattened + assert.Len(t, rows, 3) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) +} - // Retain the record to ensure it owns its memory - record.Retain() +func TestRowIterator_MultipleIterations(t *testing.T) { + // Test that we can iterate multiple times using the same iterator + // Seq2 is reusable - each range starts the sequence fresh + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + } - return record + defer func() { + for _, r := range records { + r.Release() + } + }() + + seq2 := createTestSeq2(records, nil) + next := types.NewRowPull2(context.Background(), seq2) + + // First iteration - consume all + var rows1 []types.Row + for row, err := range next { + require.NoError(t, err) + rows1 = append(rows1, row) + } + assert.Len(t, rows1, 2) + + // Second iteration - Seq2 is pull only, should be empty + var rows2 []types.Row + for row, err := range next { + require.NoError(t, err) + rows2 = append(rows2, row) + } + assert.Len(t, rows2, 0) } From ad7e9353f85ca540c309d3e7a5640e85c8fccb52 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 3 Sep 2025 21:37:01 +0100 Subject: [PATCH 10/37] fix: golint-ci --- spark/sql/types/rowiterator_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 2c3d2fe..d913c4b 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -140,6 +140,7 @@ func TestRowIterator_ErrorPropagation(t *testing.T) { func TestRowIterator_ContextCancellation(t *testing.T) { ctx, 
cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) // Create a Seq2 that yields records indefinitely seq2 := func(yield func(arrow.Record, error) bool) { @@ -152,7 +153,7 @@ func TestRowIterator_ContextCancellation(t *testing.T) { record := createTestRecord([]string{"row"}) record.Retain() // Consumer will release if !yield(record, nil) { - record.Release() // Clean up if yield returns false + record.Release() return } } @@ -177,13 +178,11 @@ func TestRowIterator_ContextCancellation(t *testing.T) { cancel() } - // Safety limit to prevent infinite loop if count > 10 { break } } - // Should have read at least one row before cancellation assert.GreaterOrEqual(t, len(rows), 1) assert.Equal(t, "row", rows[0].At(0)) } From d38170b33bd25e68dfb659e30ce4f44fb2ed93f8 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 3 Sep 2025 22:08:00 +0100 Subject: [PATCH 11/37] fix: improve test doc-comments --- spark/sql/types/rowiterator_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index d913c4b..6e2a6a9 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -340,7 +340,6 @@ func TestRowSeq2_DirectUsage(t *testing.T) { recordSeq := createTestSeq2(records, nil) rowSeq := types.NewRowSequence(context.Background(), recordSeq) - // Use the Seq2 directly with range var rows []types.Row for row, err := range rowSeq { require.NoError(t, err) @@ -356,7 +355,6 @@ func TestRowSeq2_DirectUsage(t *testing.T) { func TestRowIterator_MultipleIterations(t *testing.T) { // Test that we can iterate multiple times using the same iterator - // Seq2 is reusable - each range starts the sequence fresh records := []arrow.Record{ createTestRecord([]string{"row1", "row2"}), } @@ -370,7 +368,6 @@ func TestRowIterator_MultipleIterations(t *testing.T) { seq2 := createTestSeq2(records, nil) next := types.NewRowPull2(context.Background(), seq2) - // First iteration - consume all var rows1 []types.Row for row, err := range next { require.NoError(t, err) @@ -378,7 +375,8 @@ func TestRowIterator_MultipleIterations(t *testing.T) { } assert.Len(t, rows1, 2) - // Second iteration - Seq2 is pull only, should be empty + // Second iteration, Seq2 is a Pull2 so should be exhausted of rows to fetch + // https://pkg.go.dev/iter#Pull2 (Go doc defines this without an explicit type to split the difference) var rows2 []types.Row for row, err := range next { require.NoError(t, err) From a18468f60f6a8c18709b3fa0e2f5f706219b0b44 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 22 Oct 2025 22:54:35 +0100 Subject: [PATCH 12/37] feat: add tests for streaming rows in DataFrame operations including: tests for channel-based processing, filtering, error handling, empty datasets, multiple columns, and large datasets. 
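As a usage sketch (not part of this patch): the channel-based pattern these tests exercise can be outlined as below. The package name and query are illustrative, the import path reflects the module path at this point in the series, and an already-connected sql.SparkSession is assumed to be passed in.

package streamexample

import (
	"context"

	"github.com/apache/spark-connect-go/spark/sql"
	"github.com/apache/spark-connect-go/spark/sql/types"
)

// streamIDs forwards rows from Spark onto a channel. StreamRows pulls lazily,
// so the next Arrow batch is only fetched after the consumer drains the
// previous rows; cancelling ctx unblocks both producer and consumer.
func streamIDs(ctx context.Context, spark sql.SparkSession) (<-chan types.Row, <-chan error) {
	out := make(chan types.Row)
	errs := make(chan error, 1)

	go func() {
		defer close(out)
		defer close(errs)

		df, err := spark.Sql(ctx, "select id from range(100)") // illustrative query
		if err != nil {
			errs <- err
			return
		}
		rows, err := df.StreamRows(ctx)
		if err != nil {
			errs <- err
			return
		}
		for row, err := range rows {
			if err != nil {
				errs <- err
				return
			}
			select {
			case out <- row:
			case <-ctx.Done():
				errs <- ctx.Err()
				return
			}
		}
	}()

	return out, errs
}

In a gRPC proxy scenario, the consumer of out would be the response send loop; back-pressure holds because StreamRows only requests the next batch once the previously yielded rows have been consumed.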
--- internal/tests/integration/dataframe_test.go | 195 ++++++++++++++++++- 1 file changed, 194 insertions(+), 1 deletion(-) diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index d383ca1..df16620 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -1196,7 +1196,6 @@ func TestDataFrame_RangeIter(t *testing.T) { } assert.Equal(t, 10, cnt) - // Check that errors are properly propagated df, err = spark.Sql(ctx, "select if(id = 5, raise_error('handle'), false) from range(10)") assert.NoError(t, err) for _, err := range df.All(ctx) { @@ -1224,3 +1223,197 @@ func TestDataFrame_SchemaTreeString(t *testing.T) { assert.Contains(t, ts, "|-- second: array") assert.Contains(t, ts, "|-- third: map") } + +func TestDataFrame_StreamRowsThroughChannel(t *testing.T) { + // Demonstrates how StreamRows can be used to pipe data through a channel for scenarios like: + // - Proxying Spark data through gRPC streaming or unary RPCs + // - Implementing producer-consumer patterns with backpressure based on Spark results + // - Buffering and rate-limiting data flow between systems + ctx, spark := connect() + df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test_' || cast(id as string) as label from range(100)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + rowChan := make(chan map[string]interface{}, 10) + errChan := make(chan error, 1) + + go func() { + defer close(rowChan) + for row, err := range iter { + if err != nil { + errChan <- err + return + } + + rowData := make(map[string]interface{}) + names := row.FieldNames() + for i, name := range names { + rowData[name] = row.At(i) + } + + select { + case rowChan <- rowData: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } + }() + + // In a gRPC scenario, this would be your response handler + receivedRows := make([]map[string]interface{}, 0) + consumerDone := make(chan struct{}) + + go func() { + defer close(consumerDone) + for rowData := range rowChan { + receivedRows = append(receivedRows, rowData) + + id := rowData["id"].(int64) + doubled := rowData["doubled"].(int64) + assert.Equal(t, id*2, doubled) + } + }() + + <-consumerDone + + select { + case err := <-errChan: + assert.NoError(t, err) + default: + // continue + } + + assert.Equal(t, 100, len(receivedRows)) + + assert.Equal(t, int64(0), receivedRows[0]["id"]) + assert.Equal(t, int64(99), receivedRows[99]["id"]) + assert.Equal(t, "test_0", receivedRows[0]["label"]) + assert.Equal(t, "test_99", receivedRows[99]["label"]) +} + +func TestDataFrame_StreamRows(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(100)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + assert.NotNil(t, iter) + + cnt := 0 + + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 1, row.Len()) + cnt++ + } + assert.Equal(t, 100, cnt) +} + +func TestDataFrame_StreamRowsWithFilter(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(100)") + assert.NoError(t, err) + + df, err = df.Filter(ctx, functions.Col("id").Lt(functions.IntLit(10))) + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 1, row.Len()) + assert.Less(t, row.At(0).(int64), int64(10)) + cnt++ + } + 
assert.Equal(t, 10, cnt) +} + +func TestDataFrame_StreamRowsEmpty(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(0)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + cnt++ + } + assert.Equal(t, 0, cnt) +} + +func TestDataFrame_StreamRowsWithError(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select if(id = 5, raise_error('test error'), id) as id from range(10)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + errorEncountered := false + for _, err := range iter { + if err != nil { + errorEncountered = true + assert.Error(t, err) + break + } + } + assert.True(t, errorEncountered, "Expected to encounter an error during iteration") +} + +func TestDataFrame_StreamRowsMultipleColumns(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test' as name from range(50)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 3, row.Len()) + + id := row.At(0).(int64) + doubled := row.At(1).(int64) + name := row.At(2).(string) + + assert.Equal(t, id*2, doubled) + assert.Equal(t, "test", name) + cnt++ + } + assert.Equal(t, 50, cnt) +} + +func TestDataFrame_StreamRowsLargeDataset(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(10000)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + lastValue := int64(-1) + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + currentValue := row.At(0).(int64) + assert.Greater(t, currentValue, lastValue) + lastValue = currentValue + cnt++ + } + assert.Equal(t, 10000, cnt) +} From 434a579d667e699492bb55ac93e539d512e1986c Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Wed, 22 Oct 2025 23:04:28 +0100 Subject: [PATCH 13/37] fix: update Spark version to 4.0.1 in build workflow --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 249bd33..ef218ab 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,7 +32,7 @@ on: - master env: - SPARK_VERSION: '4.0.0' + SPARK_VERSION: '4.0.1' HADOOP_VERSION: '3' permissions: From 0432bde7911fe8cf96466bdf6803db82af81c20f Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 1 Mar 2026 16:29:19 +0000 Subject: [PATCH 14/37] fix: remove debug print lines from ToTable() --- spark/client/client.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/spark/client/client.go b/spark/client/client.go index cb24924..e5fc371 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -369,15 +369,6 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { c.done = false for { resp, err := c.responseStream.Recv() - if err != nil { - fmt.Printf("DEBUG: Recv error: %v, is EOF: %v\n", err, errors.Is(err, io.EOF)) - } - if err == nil && resp != nil { - fmt.Printf("DEBUG: Received response type: %T\n", resp.ResponseType) - if _, ok := resp.ResponseType.(*proto.ExecutePlanResponse_ResultComplete_); ok { - fmt.Println("DEBUG: Got ResultComplete!") - } - } // EOF is received when the last message has been processed and the 
stream // finished normally. if errors.Is(err, io.EOF) { From 928e9b3beb5269d62424c8922a31700894d6a70b Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 1 Mar 2026 16:29:53 +0000 Subject: [PATCH 15/37] fix: remove c.done race condition in ToRecordSequence --- spark/client/client.go | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/spark/client/client.go b/spark/client/client.go index e5fc371..65da4b8 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -435,18 +435,15 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } -// ToRecordSequence returns a single Seq2 iterator +// ToRecordSequence returns a single Seq2 iterator that directly yields results as they arrive. func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] { - // Return Seq2 iterator that directly yields results as they arrive, upstream callers can convert this as needed - iterator := func(yield func(arrow.Record, error) bool) { - // Explicitly needed when tracking re-attachable execution. - c.done = false + return func(yield func(arrow.Record, error) bool) { + // Track logical completion locally to avoid racing on shared struct state. + done := false for { - // Check for context cancellation before each iteration select { case <-ctx.Done(): - // Yield the context error and stop yield(nil, ctx.Err()) return default: @@ -454,7 +451,6 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro resp, err := c.responseStream.Recv() - // Check for context cancellation after potentially blocking operations select { case <-ctx.Done(): yield(nil, ctx.Err()) @@ -462,27 +458,23 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro default: } - // EOF is received when the last message has been processed (Observed on Databricks instances) if errors.Is(err, io.EOF) { - return // Clean end of stream + break } - // Handle other errors if err != nil { if se := sparkerrors.FromRPCError(err); se != nil { yield(nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)) } else { yield(nil, err) } - return // Stop on error + return } - // Only proceed if we have a valid response if resp == nil { continue } - // Validate session ID if resp.GetSessionId() != c.sessionId { yield(nil, sparkerrors.WithType( &sparkerrors.InvalidServerSideSessionDetailsError{ @@ -492,7 +484,6 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro return } - // Process schema if present if resp.Schema != nil { var schemaErr error c.schema, schemaErr = types.ConvertProtoDataTypeToStructType(resp.Schema) @@ -502,7 +493,6 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro } } - // Process response types switch x := resp.ResponseType.(type) { case *proto.ExecutePlanResponse_SqlCommandResult_: if val := x.SqlCommandResult.GetRelation(); val != nil { @@ -515,16 +505,12 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro yield(nil, err) return } - - // Yield the record and check if consumer wants to continue if !yield(record, nil) { - // Consumer stopped iteration early - // Note: Consumer is responsible for releasing the record return } case *proto.ExecutePlanResponse_ResultComplete_: - c.done = true + done = true return case *proto.ExecutePlanResponse_ExecutionProgress_: @@ -534,9 +520,14 @@ func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arro // 
Explicitly ignore unknown message types } } - } - return iterator + // Check that the result is logically complete. With re-attachable execution + // the server may interrupt the connection, and we need a ResultComplete + // message to confirm the full result was received. + if c.opts.ReattachExecution && !done { + yield(nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError)) + } + } } func NewExecuteResponseStream( From 146e423a22567c90e6369dbd335330021022bbff Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 1 Mar 2026 16:31:14 +0000 Subject: [PATCH 16/37] fix: remove NewRowPull2, fold EOF handling into NewRowSequence --- spark/sql/dataframe.go | 9 +++-- spark/sql/types/rowiterator.go | 48 ++++----------------------- spark/sql/types/rowiterator_test.go | 51 ++++++++++++++++++----------- 3 files changed, 43 insertions(+), 65 deletions(-) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 2dc400c..b827af9 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -200,12 +200,11 @@ type DataFrame interface { // Sort returns a new DataFrame sorted by the specified columns. Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error) Stat() DataFrameStatFunctions - // StreamRows exposes a pull-based iterator over Arrow record batches from Spark types.RowPull2. + // StreamRows returns a lazy iterator over rows from Spark. // No rows are fetched from Spark over gRPC until the previous one has been consumed. // It provides no internal buffering: each Row is produced only when the caller // requests it, ensuring client back-pressure is respected. - // types.RowPull2 is single use (can only be ranged once). - StreamRows(ctx context.Context) (types.RowPull2, error) + StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error) // Subtract subtracts the other DataFrame from the current DataFrame. And only returns // distinct rows. Subtract(ctx context.Context, other DataFrame) DataFrame @@ -942,7 +941,7 @@ func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) { return &table, nil } -func (df *dataFrameImpl) StreamRows(ctx context.Context) (types.RowPull2, error) { +func (df *dataFrameImpl) StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error) { responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan()) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) @@ -950,7 +949,7 @@ func (df *dataFrameImpl) StreamRows(ctx context.Context) (types.RowPull2, error) seq2 := responseClient.ToRecordSequence(ctx) - return types.NewRowPull2(ctx, seq2), nil + return types.NewRowSequence(ctx, seq2), nil } func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame { diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index a9393c7..a86a5f4 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -5,13 +5,10 @@ import ( "errors" "io" "iter" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" ) -type RowPull2 = iter.Seq2[Row, error] - // NewRowSequence flattens record batches to a sequence of rows stream. func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] { return func(yield func(Row, error) bool) { @@ -22,6 +19,13 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error return default: } + + // Treat io.EOF as clean stream termination. 
Some Spark + // implementations (notably Databricks clusters as of 05/2025) + // yield EOF as an error value instead of ending the sequence. + if errors.Is(recErr, io.EOF) { + return + } if recErr != nil { // forward upstream error once, then stop _ = yield(nil, recErr) @@ -48,41 +52,3 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error } } } - -// NewRowPull2 iterates rows to be consumed at the clients leisure -func NewRowPull2(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] { - // Build the push row stream first. - rows := NewRowSequence(ctx, recordSeq) - - // Enforce single-use to prevent re-iteration after stop/close. - var used atomic.Bool - - return func(yield func(Row, error) bool) { - if !used.CompareAndSwap(false, true) { - return - } - - // Convert push -> pull using the iter idiom. - next, stop := iter.Pull2(rows) - defer stop() - - for { - row, err, ok := next() - if !ok { - return - } - - // Treat io.EOF as clean termination (don’t forward). - if errors.Is(err, io.EOF) { - return - } - if err != nil { - _ = yield(nil, err) - return - } - if !yield(row, nil) { - return - } - } - } -} diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 6e2a6a9..0b8b11c 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -3,6 +3,7 @@ package types_test import ( "context" "errors" + "io" "iter" "testing" @@ -69,7 +70,7 @@ func TestRowIterator_BasicIteration(t *testing.T) { seq2 := createTestSeq2(records, nil) - rowIter := types.NewRowPull2(context.Background(), seq2) + rowIter := types.NewRowSequence(context.Background(), seq2) // Collect all rows var rows []types.Row @@ -92,7 +93,7 @@ func TestRowIterator_EmptyResult(t *testing.T) { // Don't yield anything - sequence is immediately over } - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) // Should iterate zero times count := 0 @@ -117,7 +118,7 @@ func TestRowIterator_ErrorPropagation(t *testing.T) { yield(nil, testErr) } - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) var rows []types.Row var gotError error @@ -160,7 +161,7 @@ func TestRowIterator_ContextCancellation(t *testing.T) { } } - next := types.NewRowPull2(ctx, seq2) + next := types.NewRowSequence(ctx, seq2) var rows []types.Row count := 0 @@ -204,7 +205,7 @@ func TestRowIterator_EarlyBreak(t *testing.T) { seq2 := createTestSeq2(records, nil) - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) // Read only one row then break var rows []types.Row @@ -234,7 +235,7 @@ func TestRowIterator_EmptyBatchHandling(t *testing.T) { }() seq2 := createTestSeq2(records, nil) - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) // Should skip empty batch and return row from second batch var rows []types.Row @@ -249,11 +250,9 @@ func TestRowIterator_EmptyBatchHandling(t *testing.T) { func TestRowIterator_DatabricksEOFBehavior(t *testing.T) { // Test Databricks-specific behavior where io.EOF is sent as an error - // instead of using the ok=false flag to signal stream completion - - // Create Seq2 that mimics Databricks behavior + // value rather than just ending the sequence. NewRowSequence treats + // io.EOF as clean termination. 
seq2 := func(yield func(arrow.Record, error) bool) { - // Send some records record1 := createTestRecord([]string{"row1", "row2"}) record1.Retain() if !yield(record1, nil) { @@ -268,11 +267,11 @@ func TestRowIterator_DatabricksEOFBehavior(t *testing.T) { return } - // Databricks sends io.EOF as error - // This should terminate the iteration without being treated as an error + // Databricks sends io.EOF as error — should terminate cleanly + yield(nil, io.EOF) } - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) // Read all rows successfully var rows []types.Row @@ -302,7 +301,7 @@ func TestRowIterator_NilRecordReturnsError(t *testing.T) { yield(nil, nil) } - next := types.NewRowPull2(context.Background(), seq2) + next := types.NewRowSequence(context.Background(), seq2) var rows []types.Row var gotError error @@ -354,7 +353,8 @@ func TestRowSeq2_DirectUsage(t *testing.T) { } func TestRowIterator_MultipleIterations(t *testing.T) { - // Test that we can iterate multiple times using the same iterator + // Test that ranging the same iterator twice works safely when the + // upstream is single-use (like a real gRPC stream). records := []arrow.Record{ createTestRecord([]string{"row1", "row2"}), } @@ -365,8 +365,22 @@ func TestRowIterator_MultipleIterations(t *testing.T) { } }() - seq2 := createTestSeq2(records, nil) - next := types.NewRowPull2(context.Background(), seq2) + // Build a single-use upstream to simulate a gRPC stream. + exhausted := false + seq2 := func(yield func(arrow.Record, error) bool) { + if exhausted { + return + } + exhausted = true + for _, record := range records { + record.Retain() + if !yield(record, nil) { + return + } + } + } + + next := types.NewRowSequence(context.Background(), seq2) var rows1 []types.Row for row, err := range next { @@ -375,8 +389,7 @@ func TestRowIterator_MultipleIterations(t *testing.T) { } assert.Len(t, rows1, 2) - // Second iteration, Seq2 is a Pull2 so should be exhausted of rows to fetch - // https://pkg.go.dev/iter#Pull2 (Go doc defines this without an explicit type to split the difference) + // Second iteration — upstream exhausted, should yield nothing var rows2 []types.Row for row, err := range next { require.NoError(t, err) From aa4b293dd63760c3bb44e5ae08e83a19f4c1cae6 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 1 Mar 2026 16:42:27 +0000 Subject: [PATCH 17/37] fix: extract rowIterFromRecord to simplify NewRowSequence --- spark/sql/types/rowiterator.go | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index a86a5f4..03b0bcd 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -9,6 +9,24 @@ import ( "github.com/apache/arrow-go/v18/arrow" ) +// rowIterFromRecord converts an Arrow record into a row iterator, +// releasing the record when iteration completes or the consumer stops. +func rowIterFromRecord(rec arrow.Record) iter.Seq2[Row, error] { + return func(yield func(Row, error) bool) { + defer rec.Release() + rows, err := ReadArrowRecordToRows(rec) + if err != nil { + _ = yield(nil, err) + return + } + for _, row := range rows { + if !yield(row, nil) { + return + } + } + } +} + // NewRowSequence flattens record batches to a sequence of rows stream. 
func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] { return func(yield func(Row, error) bool) { @@ -27,7 +45,6 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error return } if recErr != nil { - // forward upstream error once, then stop _ = yield(nil, recErr) return } @@ -36,16 +53,8 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error return } - rows, err := func() ([]Row, error) { - defer rec.Release() - return ReadArrowRecordToRows(rec) - }() - if err != nil { - _ = yield(nil, err) - return - } - for _, row := range rows { - if !yield(row, nil) { + for row, err := range rowIterFromRecord(rec) { + if !yield(row, err) || err != nil { return } } From fb2a9aaeef09d8577f5871240f9535f4c95e8fe7 Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Sun, 1 Mar 2026 16:49:35 +0000 Subject: [PATCH 18/37] fix: prefer explicit error yield --- spark/client/client.go | 4 +++- spark/sql/types/rowiterator.go | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/spark/client/client.go b/spark/client/client.go index 65da4b8..7eceee1 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -438,7 +438,9 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { // ToRecordSequence returns a single Seq2 iterator that directly yields results as they arrive. func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] { return func(yield func(arrow.Record, error) bool) { - // Track logical completion locally to avoid racing on shared struct state. + // Represents Spark's reattachable execution. + // Tracks logical completion locally to avoid racing on shared struct state. + // Spliced from ToTable. We may eventually want to DRY up these workflows. 
done := false for { diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index 03b0bcd..61ac790 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -54,7 +54,11 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error } for row, err := range rowIterFromRecord(rec) { - if !yield(row, err) || err != nil { + if err != nil { + _ = yield(nil, err) + return + } + if !yield(row, nil) { return } } From b29e5ef50b6bdc0cc91ea9593e581871ab62ccfa Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Thu, 5 Mar 2026 22:55:33 +0000 Subject: [PATCH 19/37] fix: address feedback --- spark/sql/types/rowiterator.go | 9 +++------ spark/sql/types/rowiterator_test.go | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go index 61ac790..7f9d451 100644 --- a/spark/sql/types/rowiterator.go +++ b/spark/sql/types/rowiterator.go @@ -49,16 +49,13 @@ func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error return } if rec == nil { - _ = yield(nil, errors.New("expected arrow.Record to contain non-nil Rows, got nil")) + _ = yield(nil, errors.New("expected non-nil arrow.Record, got nil")) return } for row, err := range rowIterFromRecord(rec) { - if err != nil { - _ = yield(nil, err) - return - } - if !yield(row, nil) { + cont := yield(row, err) + if err != nil || !cont { return } } diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 0b8b11c..b9a2f46 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -320,7 +320,7 @@ func TestRowIterator_NilRecordReturnsError(t *testing.T) { // Should have received error about nil record assert.Error(t, gotError) - assert.Contains(t, gotError.Error(), "expected arrow.Record to contain non-nil Rows, got nil") + assert.Contains(t, gotError.Error(), "expected non-nil arrow.Record, got nil") } func TestRowSeq2_DirectUsage(t *testing.T) { From 8cc399ea4d5e7be51b01bc51bb79296d530c2af8 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:13:19 +0100 Subject: [PATCH 20/37] refactor/7: rename module path to github.com/caldempsey/spark-connect-go (#1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Declares the module as github.com/caldempsey/spark-connect-go, drops the /v40 suffix, and updates every import in the tree to match. Consumers that want the fork's in-flight fixes no longer need a replace directive in their go.mod — go get against the caldempsey path resolves directly. Rebases against upstream stay mechanical: merge upstream in, sed module paths, resolve in the usual spots. 
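For consumers of the fork, the switch described above amounts to an import-path change; a hedged illustration (any pinned version is whatever `go get github.com/caldempsey/spark-connect-go@master` resolves to):

// Before (upstream module, with the major-version suffix):
//   import "github.com/apache/spark-connect-go/v40/spark/sql"
//
// After (this fork, no /v40 suffix and no replace directive required):
import "github.com/caldempsey/spark-connect-go/spark/sql"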
--- README.md | 4 ++-- cmd/spark-connect-example-raw-grpc-client/main.go | 2 +- cmd/spark-connect-example-spark-session/main.go | 8 ++++---- go.mod | 2 +- internal/tests/integration/dataframe_test.go | 10 +++++----- internal/tests/integration/functions_test.go | 6 +++--- internal/tests/integration/helper.go | 2 +- internal/tests/integration/spark_runner.go | 2 +- internal/tests/integration/sql_test.go | 8 ++++---- quick-start.md | 4 ++-- spark/client/base/base.go | 6 +++--- spark/client/channel/channel.go | 4 ++-- spark/client/channel/channel_test.go | 4 ++-- spark/client/client.go | 14 +++++++------- spark/client/client_test.go | 8 ++++---- spark/client/conf.go | 4 ++-- spark/client/retry.go | 8 ++++---- spark/client/retry_test.go | 8 ++++---- spark/client/testutils/utils.go | 2 +- spark/mocks/mock_executor.go | 8 ++++---- spark/mocks/mocks.go | 2 +- spark/sql/column/column.go | 4 ++-- spark/sql/column/column_test.go | 2 +- spark/sql/column/expressions.go | 6 +++--- spark/sql/column/expressions_test.go | 2 +- spark/sql/dataframe.go | 10 +++++----- spark/sql/dataframe_test.go | 4 ++-- spark/sql/dataframenafunctions.go | 2 +- spark/sql/dataframewriter.go | 4 ++-- spark/sql/dataframewriter_test.go | 6 +++--- spark/sql/functions/buiitins.go | 4 ++-- spark/sql/functions/generated.go | 2 +- spark/sql/group.go | 10 +++++----- spark/sql/group_test.go | 8 ++++---- spark/sql/plan.go | 2 +- spark/sql/sparksession.go | 14 +++++++------- spark/sql/sparksession_integration_test.go | 2 +- spark/sql/sparksession_test.go | 10 +++++----- spark/sql/types/arrow.go | 4 ++-- spark/sql/types/arrow_test.go | 4 ++-- spark/sql/types/builtin.go | 2 +- spark/sql/types/conversion.go | 4 ++-- spark/sql/types/conversion_test.go | 4 ++-- spark/sql/types/rowiterator_test.go | 2 +- spark/sql/utils/consts.go | 2 +- 45 files changed, 115 insertions(+), 115 deletions(-) diff --git a/README.md b/README.md index 51d674f..5a502a9 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Step 3: Run the following commands to setup the Spark Connect client. Building with Spark in case you need to re-generate the source files from the proto sources. 
``` -git clone https://github.com/apache/spark-connect-go.git +git clone https://github.com/caldempsey/spark-connect-go.git git submodule update --init --recursive make gen && make test @@ -34,7 +34,7 @@ make gen && make test Building without Spark ``` -git clone https://github.com/apache/spark-connect-go.git +git clone https://github.com/caldempsey/spark-connect-go.git make && make test ``` diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go index 1f463db..ab3c405 100644 --- a/cmd/spark-connect-example-raw-grpc-client/main.go +++ b/cmd/spark-connect-example-raw-grpc-client/main.go @@ -22,7 +22,7 @@ import ( "log" "time" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index ec720dc..a05758b 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -22,12 +22,12 @@ import ( "fmt" "log" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" ) var ( diff --git a/go.mod b/go.mod index 3478ddb..379f88e 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-module github.com/apache/spark-connect-go +module github.com/caldempsey/spark-connect-go go 1.23.2 diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index df16620..9d4dde1 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -21,15 +21,15 @@ import ( "os" "testing" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/column" + "github.com/caldempsey/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go index 94310d7..18c8cb2 100644 --- a/internal/tests/integration/functions_test.go +++ b/internal/tests/integration/functions_test.go @@ -19,11 +19,11 @@ import ( "context" "testing" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/internal/tests/integration/helper.go b/internal/tests/integration/helper.go index 902d223..df0b630 100644 --- a/internal/tests/integration/helper.go +++ b/internal/tests/integration/helper.go @@ -22,7 +22,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql" ) func connect() (context.Context, sql.SparkSession) { diff --git a/internal/tests/integration/spark_runner.go b/internal/tests/integration/spark_runner.go index b6cb688..5796456 100644 --- a/internal/tests/integration/spark_runner.go +++ b/internal/tests/integration/spark_runner.go @@ -23,7 +23,7 @@ import ( "os/exec" "time" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) func StartSparkConnect() (int64, error) { diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go index a0e0493..395b285 100644 --- a/internal/tests/integration/sql_test.go +++ b/internal/tests/integration/sql_test.go @@ -22,13 +22,13 @@ import ( "os" "testing" - "github.com/apache/spark-connect-go/spark/sql/column" + "github.com/caldempsey/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/quick-start.md b/quick-start.md index 7382107..b51f3ee 100644 --- a/quick-start.md +++ b/quick-start.md @@ -5,7 +5,7 @@ In your Go project `go.mod` file, add `spark-connect-go` library: 
``` require ( - github.com/apache/spark-connect-go master + github.com/caldempsey/spark-connect-go master ) ``` @@ -23,7 +23,7 @@ import ( "fmt" "log" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/caldempsey/spark-connect-go/spark/sql" ) var ( diff --git a/spark/client/base/base.go b/spark/client/base/base.go index ee1ae79..aa554a0 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -19,11 +19,11 @@ import ( "context" "iter" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) type SparkConnectRPCClient generated.SparkConnectServiceClient diff --git a/spark/client/channel/channel.go b/spark/client/channel/channel.go index 6403566..cfca0f3 100644 --- a/spark/client/channel/channel.go +++ b/spark/client/channel/channel.go @@ -29,13 +29,13 @@ import ( "strconv" "strings" - "github.com/apache/spark-connect-go/spark" + "github.com/caldempsey/spark-connect-go/spark" "github.com/google/uuid" "google.golang.org/grpc/credentials/insecure" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" diff --git a/spark/client/channel/channel_test.go b/spark/client/channel/channel_test.go index b0f7bea..e4abfc8 100644 --- a/spark/client/channel/channel_test.go +++ b/spark/client/channel/channel_test.go @@ -23,8 +23,8 @@ import ( "github.com/google/uuid" - "github.com/apache/spark-connect-go/spark/client/channel" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/client/channel" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" ) diff --git a/spark/client/client.go b/spark/client/client.go index 7eceee1..2b0c08e 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -22,24 +22,24 @@ import ( "io" "iter" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" "google.golang.org/grpc" "google.golang.org/grpc/metadata" - "github.com/apache/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/mocks" + "github.com/caldempsey/spark-connect-go/spark/client/base" + "github.com/caldempsey/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/caldempsey/spark-connect-go/spark/client/options" "github.com/google/uuid" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) type sparkConnectClientImpl struct { diff --git a/spark/client/client_test.go b/spark/client/client_test.go index d9f9ade..c9dc54b 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -12,10 +12,10 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" 
"github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/client" + "github.com/caldempsey/spark-connect-go/spark/mocks" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/spark/client/conf.go b/spark/client/conf.go index 11b301e..22a81e2 100644 --- a/spark/client/conf.go +++ b/spark/client/conf.go @@ -18,8 +18,8 @@ package client import ( "context" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client/base" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/client/base" ) // Public interface RuntimeConfig diff --git a/spark/client/retry.go b/spark/client/retry.go index be7d90c..016aa46 100644 --- a/spark/client/retry.go +++ b/spark/client/retry.go @@ -23,13 +23,13 @@ import ( "strings" "time" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/caldempsey/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/caldempsey/spark-connect-go/spark/client/options" "google.golang.org/grpc/metadata" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) diff --git a/spark/client/retry_test.go b/spark/client/retry_test.go index a9526e5..83801db 100644 --- a/spark/client/retry_test.go +++ b/spark/client/retry_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/caldempsey/spark-connect-go/spark/client/options" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/client/testutils" + "github.com/caldempsey/spark-connect-go/spark/mocks" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go index c0b3bb5..e38a33b 100644 --- a/spark/client/testutils/utils.go +++ b/spark/client/testutils/utils.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "google.golang.org/grpc" ) diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index 600e9e0..fc4e9a6 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -19,13 +19,13 @@ import ( "context" "errors" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/caldempsey/spark-connect-go/spark/client/base" "github.com/apache/arrow-go/v18/arrow" - 
"github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) type TestExecutor struct { diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go index 3a313f2..c6232a9 100644 --- a/spark/mocks/mocks.go +++ b/spark/mocks/mocks.go @@ -25,7 +25,7 @@ import ( "github.com/google/uuid" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "google.golang.org/grpc/metadata" ) diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go index 79e2fad..13eb6a8 100644 --- a/spark/sql/column/column.go +++ b/spark/sql/column/column.go @@ -18,9 +18,9 @@ package column import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" ) // Convertible is the interface for all things that can be converted into a protobuf expression. diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go index fa97e80..c62cbde 100644 --- a/spark/sql/column/column_test.go +++ b/spark/sql/column/column_test.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go index 399a611..ca0bfcb 100644 --- a/spark/sql/column/expressions.go +++ b/spark/sql/column/expressions.go @@ -20,11 +20,11 @@ import ( "fmt" "strings" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" ) func newProtoExpression() *proto.Expression { diff --git a/spark/sql/column/expressions_test.go b/spark/sql/column/expressions_test.go index 836d5a8..4fb59e3 100644 --- a/spark/sql/column/expressions_test.go +++ b/spark/sql/column/expressions_test.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index b827af9..4374529 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -22,14 +22,14 @@ import ( "math/rand/v2" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/caldempsey/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/sql/column" + "github.com/caldempsey/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) // ResultCollector receives a stream of result rows diff --git 
a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go index 0bee27b..49fb11f 100644 --- a/spark/sql/dataframe_test.go +++ b/spark/sql/dataframe_test.go @@ -20,8 +20,8 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/functions" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframenafunctions.go b/spark/sql/dataframenafunctions.go index 9845bb1..23f827f 100644 --- a/spark/sql/dataframenafunctions.go +++ b/spark/sql/dataframenafunctions.go @@ -18,7 +18,7 @@ package sql import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) type DataFrameNaFunctions interface { diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index 8c096f8..6f98649 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -21,8 +21,8 @@ import ( "fmt" "strings" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) // DataFrameWriter supports writing data frame to storage. diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index bc85f65..8a6775c 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - "github.com/apache/spark-connect-go/spark/client" + "github.com/caldempsey/spark-connect-go/spark/client" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/mocks" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go index 4dca8bf..8f07d82 100644 --- a/spark/sql/functions/buiitins.go +++ b/spark/sql/functions/buiitins.go @@ -16,8 +16,8 @@ package functions import ( - "github.com/apache/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/column" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) func Expr(expr string) column.Column { diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go index 844fa61..f66f9a7 100644 --- a/spark/sql/functions/generated.go +++ b/spark/sql/functions/generated.go @@ -15,7 +15,7 @@ package functions -import "github.com/apache/spark-connect-go/spark/sql/column" +import "github.com/caldempsey/spark-connect-go/spark/sql/column" // BitwiseNOT - Computes bitwise not. 
// diff --git a/spark/sql/group.go b/spark/sql/group.go index 3943969..125fee8 100644 --- a/spark/sql/group.go +++ b/spark/sql/group.go @@ -19,12 +19,12 @@ package sql import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" - "github.com/apache/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/sql/column" + "github.com/caldempsey/spark-connect-go/spark/sql/functions" ) type GroupedData struct { diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go index fe02b49..f6e6c72 100644 --- a/spark/sql/group_test.go +++ b/spark/sql/group_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/client" + "github.com/caldempsey/spark-connect-go/spark/client/testutils" + "github.com/caldempsey/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/plan.go b/spark/sql/plan.go index 022298f..89709ae 100644 --- a/spark/sql/plan.go +++ b/spark/sql/plan.go @@ -19,7 +19,7 @@ package sql import ( "sync/atomic" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" ) var atomicInt64 atomic.Int64 diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index a84bb61..433a1a7 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -23,19 +23,19 @@ import ( "time" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/caldempsey/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/caldempsey/spark-connect-go/spark/client/options" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/channel" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/client" + "github.com/caldempsey/spark-connect-go/spark/client/channel" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "github.com/google/uuid" "google.golang.org/grpc/metadata" ) diff --git a/spark/sql/sparksession_integration_test.go b/spark/sql/sparksession_integration_test.go index c23d671..9512fd4 100644 --- a/spark/sql/sparksession_integration_test.go +++ b/spark/sql/sparksession_integration_test.go @@ -21,7 +21,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql/types" + 
"github.com/caldempsey/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 11539af..9dedfae 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -29,11 +29,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/client" + "github.com/caldempsey/spark-connect-go/spark/client/testutils" + "github.com/caldempsey/spark-connect-go/spark/mocks" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) func TestSparkSessionTable(t *testing.T) { diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index 6c49bd0..0bfc1ff 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -20,13 +20,13 @@ import ( "bytes" "fmt" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) func ReadArrowTableToRows(table arrow.Table) ([]Row, error) { diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index 92d491b..6d88c66 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -30,8 +30,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + proto "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) func TestShowArrowBatchData(t *testing.T) { diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go index 1f74695..944beca 100644 --- a/spark/sql/types/builtin.go +++ b/spark/sql/types/builtin.go @@ -19,7 +19,7 @@ package types import ( "context" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/caldempsey/spark-connect-go/internal/generated" ) type LiteralType interface { diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go index b2652e2..92a2fc5 100644 --- a/spark/sql/types/conversion.go +++ b/spark/sql/types/conversion.go @@ -19,8 +19,8 @@ package types import ( "errors" - "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sparkerrors" ) func ConvertProtoDataTypeToStructType(input *generated.DataType) (*StructType, error) { diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go index cd62779..109dc92 100644 --- a/spark/sql/types/conversion_test.go +++ b/spark/sql/types/conversion_test.go @@ -19,8 +19,8 @@ package types_test import ( "testing" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + proto 
"github.com/caldempsey/spark-connect-go/internal/generated" + "github.com/caldempsey/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index b9a2f46..bc39e1d 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/caldempsey/spark-connect-go/spark/sql/types" ) // Helper function to create test records diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go index e1312ef..3dd5f61 100644 --- a/spark/sql/utils/consts.go +++ b/spark/sql/utils/consts.go @@ -15,7 +15,7 @@ package utils -import proto "github.com/apache/spark-connect-go/internal/generated" +import proto "github.com/caldempsey/spark-connect-go/internal/generated" type ExplainMode int From 6799238a4e921630b413639931f8bd82140702fb Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:13:33 +0100 Subject: [PATCH 21/37] fix/5: run build workflow on push to main (#6) The fork's work lives on main, but the build workflow only triggers on push to master. Point the push filter at main so merges to the fork's default branch actually get CI signal. The pull_request trigger already accepts any base branch and doesn't need to change. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ef218ab..51a8898 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,7 +29,7 @@ on: pull_request: push: branches: - - master + - main env: SPARK_VERSION: '4.0.1' From 55b4050e88dfff71bf0015b228d92897595d0694 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:24:11 +0100 Subject: [PATCH 22/37] fix/8: fall back to archive.apache.org when dlcdn 404s (#9) * fix/8: fall back to archive.apache.org when dlcdn 404s dlcdn.apache.org only mirrors current releases and rotates older ones off without warning, which blocks every CI run behind a transient 404. archive.apache.org is the canonical Apache mirror and never rotates. Try dlcdn first for speed, fall back to archive on failure. The cache is keyed on Spark + Hadoop version, so this penalty is paid at most once per cache eviction. Closes #8. * fix/8: run gofumpt on files whose imports shifted in the rename golangci-lint's gofumpt check flagged three files where the import block got out of shape: stdlib imports mixed with third-party, missing blank-line group separator. Let gofumpt produce the canonical layout. --- .github/workflows/build.yml | 7 ++++++- spark/sql/types/rowiterator_test.go | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51a8898..9c64f51 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,12 @@ jobs: echo "Apache Spark is not installed" # Access the directory. mkdir -p ~/deps/ - wget -q https://dlcdn.apache.org/spark/spark-${{ env.SPARK_VERSION }}/spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz + # dlcdn.apache.org only keeps current releases on its mirrors and + # occasionally 404s on older ones. archive.apache.org is the + # canonical mirror and never rotates — use it as a fallback. 
+ ARCHIVE=spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz + wget -q https://dlcdn.apache.org/spark/spark-${{ env.SPARK_VERSION }}/$ARCHIVE || \ + wget -q https://archive.apache.org/dist/spark/spark-${{ env.SPARK_VERSION }}/$ARCHIVE tar -xzf spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz -C ~/deps/ # Delete the old file rm spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index bc39e1d..525d0c0 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -7,6 +7,9 @@ import ( "iter" "testing" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" From 8ad8d7d55e2190d5f70119ee9795d649b134cda8 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:38:11 +0100 Subject: [PATCH 23/37] feat/3: expose gRPC transport options on the channel + session builders (#11) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BaseBuilder.WithDialOptions lets callers append arbitrary grpc.DialOption values — per-call message ceilings, keepalive profiles, interceptors — anything Spark Connect doesn't expose as a server conf. SparkSessionBuilder.WithDialOptions forwards the same knob through the default builder path. Also bumps the default per-call send/receive ceiling from gRPC's 4 MiB to 1 GiB to match the server's typical upper bound (spark.connect.grpc.maxInboundMessageSize). Without that bump, a single Arrow batch truncates silently on non-trivial queries with a ResourceExhausted error buried in Collect() or StreamRows(). Caller-supplied DialOptions are appended after the builder's defaults so a tighter MaxCallRecvMsgSize will override the default. Closes #3. --- spark/client/channel/channel.go | 34 ++++++++++++++++++++++++++++ spark/client/channel/channel_test.go | 26 +++++++++++++++++++++ spark/sql/sparksession.go | 16 +++++++++++++ spark/sql/sparksession_test.go | 17 ++++++++++++++ 4 files changed, 93 insertions(+) diff --git a/spark/client/channel/channel.go b/spark/client/channel/channel.go index cfca0f3..328da4c 100644 --- a/spark/client/channel/channel.go +++ b/spark/client/channel/channel.go @@ -77,6 +77,28 @@ type BaseBuilder struct { headers map[string]string sessionId string userAgent string + dialOpts []grpc.DialOption +} + +// defaultMaxMessageSize is the per-RPC send/receive ceiling we apply +// when the caller doesn't set one. The server advertises +// spark.connect.grpc.maxInboundMessageSize at 128 MiB by default and +// operators routinely raise it to 1 GiB for bulk reads. gRPC's own +// default is 4 MiB, which silently truncates a single Arrow batch on +// any non-trivial query and surfaces as opaque ResourceExhausted +// errors deep inside Collect()/StreamRows(). Matching the upper end +// of the server's practical range keeps first-run queries from +// hitting a floor the caller didn't know existed. +const defaultMaxMessageSize = 1 << 30 // 1 GiB + +// WithDialOptions appends gRPC dial options to the channel builder. +// Callers use this to raise the per-call message ceiling further, set +// keepalive parameters for long-lived streams, inject interceptors, +// or swap in a custom dialer. 
Options are applied after the builder's +// defaults, so caller-supplied values win. +func (cb *BaseBuilder) WithDialOptions(opts ...grpc.DialOption) *BaseBuilder { + cb.dialOpts = append(cb.dialOpts, opts...) + return cb } func (cb *BaseBuilder) Host() string { @@ -114,6 +136,15 @@ func (cb *BaseBuilder) Build(ctx context.Context) (*grpc.ClientConn, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithAuthority(cb.host)) + + // Raise the per-call send/receive ceiling off gRPC's 4 MiB default. + // Placed before WithDialOptions so caller-supplied DefaultCallOptions + // can still override if they want a tighter limit. + opts = append(opts, grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(defaultMaxMessageSize), + grpc.MaxCallSendMsgSize(defaultMaxMessageSize), + )) + if cb.token == "" { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { @@ -134,6 +165,9 @@ func (cb *BaseBuilder) Build(ctx context.Context) (*grpc.ClientConn, error) { opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: ts})) } + // Caller overrides come last so they win over the defaults above. + opts = append(opts, cb.dialOpts...) + remote := fmt.Sprintf("%v:%v", cb.host, cb.port) conn, err := grpc.NewClient(remote, opts...) if err != nil { diff --git a/spark/client/channel/channel_test.go b/spark/client/channel/channel_test.go index e4abfc8..0bfc334 100644 --- a/spark/client/channel/channel_test.go +++ b/spark/client/channel/channel_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/uuid" + "google.golang.org/grpc" "github.com/caldempsey/spark-connect-go/spark/client/channel" "github.com/caldempsey/spark-connect-go/spark/sparkerrors" @@ -112,3 +113,28 @@ func TestChannelBulder_UserAgent(t *testing.T) { assert.True(t, strings.Contains(cb.UserAgent(), "spark/")) assert.True(t, strings.Contains(cb.UserAgent(), "os/")) } + +// TestChannelBuilder_WithDialOptions_CompilesAndBuilds is a smoke +// test: the builder accepts caller-supplied grpc.DialOption values, +// stores them, and a subsequent Build produces a live ClientConn +// without panicking. We assert the returned conn is non-nil rather +// than exercising the wire — grpc.NewClient is lazy and doesn't dial +// until the first RPC, so a stricter assertion would spin up a +// server here. +func TestChannelBuilder_WithDialOptions_CompilesAndBuilds(t *testing.T) { + cb, err := channel.NewBuilder("sc://localhost") + assert.NoError(t, err) + + // Representative knobs: raise the per-call ceiling further than + // the builder's own default, and give the connection a keepalive + // profile. Neither changes observable behaviour in this unit + // test, but both have to survive the builder round-trip. 
+ cb.WithDialOptions( + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(2<<30)), + grpc.WithUserAgent("dorm-test"), + ) + + conn, err := cb.Build(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, conn) +} diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index 433a1a7..8f9418e 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -37,6 +37,7 @@ import ( "github.com/caldempsey/spark-connect-go/spark/client/channel" "github.com/caldempsey/spark-connect-go/spark/sparkerrors" "github.com/google/uuid" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -58,6 +59,7 @@ func NewSessionBuilder() *SparkSessionBuilder { type SparkSessionBuilder struct { connectionString string channelBuilder channel.Builder + dialOpts []grpc.DialOption } // Remote sets the connection string for remote connection @@ -71,6 +73,17 @@ func (s *SparkSessionBuilder) WithChannelBuilder(cb channel.Builder) *SparkSessi return s } +// WithDialOptions appends gRPC dial options that will be applied when +// the session builds its underlying channel. Useful for raising the +// per-call message ceiling further, tuning keepalive, or installing +// interceptors — anything Spark Connect doesn't expose as a server +// conf. Ignored when WithChannelBuilder was also called; a custom +// channel owns its own dial options. +func (s *SparkSessionBuilder) WithDialOptions(opts ...grpc.DialOption) *SparkSessionBuilder { + s.dialOpts = append(s.dialOpts, opts...) + return s +} + func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { if s.channelBuilder == nil { cb, err := channel.NewBuilder(s.connectionString) @@ -78,6 +91,9 @@ func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { return nil, sparkerrors.WithType(fmt.Errorf( "failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError) } + if len(s.dialOpts) > 0 { + cb.WithDialOptions(s.dialOpts...) + } s.channelBuilder = cb } conn, err := s.channelBuilder.Build(ctx) diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 9dedfae..1971809 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" proto "github.com/caldempsey/spark-connect-go/internal/generated" "github.com/caldempsey/spark-connect-go/spark/client" @@ -191,3 +192,19 @@ func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []any{"str2"}, vals) } + +// TestSessionBuilder_WithDialOptions_Forwarded asserts that +// caller-supplied grpc.DialOption values are forwarded to the +// channel builder the session creates. We build against a bogus +// endpoint and don't dial — grpc.NewClient is lazy and doesn't open +// a connection until the first RPC, so the interesting surface is +// that WithDialOptions composes into the builder's state without +// error. +func TestSessionBuilder_WithDialOptions_Forwarded(t *testing.T) { + ctx := context.Background() + _, err := NewSessionBuilder(). + Remote("sc://localhost"). + WithDialOptions(grpc.WithUserAgent("dorm-test")). 
+ Build(ctx) + assert.NoError(t, err) +} From f340a1e0e5420129c37556b923c0ce678a1375d4 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:45:32 +0100 Subject: [PATCH 24/37] feat/4: add typed DataFrame[T] (#12) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DataFrameOf[T] wraps a DataFrame with a cached, reflected row plan and decodes each row directly into T on Collect. Users tag fields with `spark:"colname"` or leave them bare to fall back on snake_case of the Go field name; schema drift (struct field not in the result projection) surfaces at first Collect, not per row. Two entry points: SqlTyped[User](ctx, session, "SELECT id, email FROM users") *DataFrameOf[User] TypedDataFrame[User](df) *DataFrameOf[User] Collect returns []T. DataFrame() drops back to the untyped surface for operations the typed layer doesn't cover — GroupBy, joins, window functions. Streaming (iter.Seq2 over T) is intentionally out of scope here — needs dedicated ExecutePlanClient plumbing and a matching test matrix, lands in a follow-up. Users with large result sets drop to DataFrame().ToLocalIterator() or the streaming primitive in the untyped DataFrame. Closes #4. --- spark/sql/dataframe_typed.go | 322 ++++++++++++++++++++++++++++++ spark/sql/dataframe_typed_test.go | 153 ++++++++++++++ 2 files changed, 475 insertions(+) create mode 100644 spark/sql/dataframe_typed.go create mode 100644 spark/sql/dataframe_typed_test.go diff --git a/spark/sql/dataframe_typed.go b/spark/sql/dataframe_typed.go new file mode 100644 index 0000000..22a19b0 --- /dev/null +++ b/spark/sql/dataframe_typed.go @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "fmt" + "reflect" + "strings" + "sync" + "time" + "unicode" + + "github.com/apache/arrow-go/v18/arrow" +) + +// DataFrameOf[T] is a typed view on a regular DataFrame. Users +// parameterise it on a Go struct; Collect decodes rows directly +// into T instead of handing back []any that callers have to +// type-assert field by field. +// +// Column binding uses struct tags in the same shape sqlx / parquet-go +// already use: +// +// type User struct { +// ID string `spark:"id"` +// Email string `spark:"email"` +// Created time.Time `spark:"created_at"` +// } +// +// Fields without a `spark:"..."` tag are mapped by snake_case'd field +// name, so a plain Go struct works without any tags at all. Fields +// tagged `spark:"-"` are skipped. Columns in the DataFrame that +// don't match any field are ignored — typical of projections +// narrower than the struct. 
+// +// Schema drift (a struct field that the result's projection doesn't +// contain) surfaces at the first Collect call as a single error +// rather than per-row panics. +// +// Streaming (iter.Seq2 over T) is intentionally not in v0 — the +// ExecutePlanClient plumbing wants a dedicated PR so it can ship +// alongside a proper test matrix. For now, users who need streaming +// call DataFrame() to drop to the untyped ToRecordSequence path. +type DataFrameOf[T any] struct { + df DataFrame + plan *rowPlan +} + +// SqlTyped runs a SQL query and returns a typed DataFrame over the +// result. Equivalent to SparkSession.Sql followed by a struct-tag- +// driven scanner at every row — except the plan is computed once +// and reused for every Collect call on the returned value. +func SqlTyped[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { + df, err := session.Sql(ctx, query) + if err != nil { + return nil, err + } + return TypedDataFrame[T](df) +} + +// TypedDataFrame wraps an existing DataFrame in the typed surface. +// Useful when the caller already holds a DataFrame produced by an +// operation other than Sql (Read, Table, a chain of transformations). +// Computes and caches the row plan immediately; a malformed struct +// surfaces here rather than per-row inside Collect. +func TypedDataFrame[T any](df DataFrame) (*DataFrameOf[T], error) { + var zero T + rt := reflect.TypeOf(zero) + if rt == nil || rt.Kind() != reflect.Struct { + return nil, fmt.Errorf("DataFrameOf[T]: T must be a struct, got %v", rt) + } + plan, err := buildRowPlan(rt) + if err != nil { + return nil, err + } + return &DataFrameOf[T]{df: df, plan: plan}, nil +} + +// DataFrame returns the underlying untyped DataFrame. Escape hatch +// for operations the typed surface doesn't cover — GroupBy, joins, +// window functions. Chain freely and call TypedDataFrame again on +// the result when the output shape is known. +func (d *DataFrameOf[T]) DataFrame() DataFrame { return d.df } + +// Collect materialises every row into a []T. Holds the whole table +// on the heap for the duration of the call — callers with large +// result sets should project narrower on the SQL side or drop to +// the untyped streaming path via DataFrame(). +func (d *DataFrameOf[T]) Collect(ctx context.Context) ([]T, error) { + rows, err := d.df.Collect(ctx) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, nil + } + cols := rows[0].FieldNames() + bindings, err := d.plan.bind(cols) + if err != nil { + return nil, err + } + out := make([]T, len(rows)) + for i, r := range rows { + if err := decodeRow(d.plan, r.Values(), bindings, &out[i]); err != nil { + return nil, fmt.Errorf("DataFrameOf[T].Collect: row %d: %w", i, err) + } + } + return out, nil +} + +// rowPlan caches the reflected structure of T so Collect doesn't +// reflect on every row. Built once per DataFrameOf[T]. +type rowPlan struct { + goType reflect.Type + fields []plannedField +} + +type plannedField struct { + name string // column name from tag or snake_case'd field name + index []int // reflect.FieldByIndex path + gotyp reflect.Type +} + +// columnBinding maps a result-set column position to the field slot +// in the plan that should receive it. A column that the struct +// doesn't describe has planIndex = -1 and is skipped. 
+type columnBinding struct { + planIndex int +} + +var rowPlanCache sync.Map // reflect.Type -> *rowPlan + +func buildRowPlan(rt reflect.Type) (*rowPlan, error) { + if cached, ok := rowPlanCache.Load(rt); ok { + return cached.(*rowPlan), nil + } + plan := &rowPlan{goType: rt} + if err := walkPlan(rt, nil, plan); err != nil { + return nil, err + } + rowPlanCache.Store(rt, plan) + return plan, nil +} + +func walkPlan(rt reflect.Type, parent []int, plan *rowPlan) error { + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !sf.IsExported() { + continue + } + idx := append(append([]int{}, parent...), i) + tag := sf.Tag.Get("spark") + if tag == "-" { + continue + } + if sf.Anonymous && sf.Type.Kind() == reflect.Struct && tag == "" { + if err := walkPlan(sf.Type, idx, plan); err != nil { + return err + } + continue + } + name := tag + if name == "" { + name = snakeCase(sf.Name) + } + plan.fields = append(plan.fields, plannedField{ + name: name, + index: idx, + gotyp: sf.Type, + }) + } + return nil +} + +// bind aligns the plan's fields with a concrete set of result +// columns. Result columns that don't map to any planned field are +// tagged planIndex = -1 and dropped at decode time. A planned field +// that doesn't appear in the columns is a schema-drift error — the +// SQL changed in a way the struct doesn't describe and we want the +// caller to know at plan time, not per row. +func (p *rowPlan) bind(columns []string) ([]columnBinding, error) { + byName := make(map[string]int, len(p.fields)) + for i, f := range p.fields { + byName[f.name] = i + } + bindings := make([]columnBinding, len(columns)) + seen := make(map[int]bool, len(p.fields)) + for ci, name := range columns { + if idx, ok := byName[name]; ok { + bindings[ci] = columnBinding{planIndex: idx} + seen[idx] = true + } else { + bindings[ci] = columnBinding{planIndex: -1} + } + } + var missing []string + for i, f := range p.fields { + if !seen[i] { + missing = append(missing, f.name) + } + } + if len(missing) > 0 { + return nil, fmt.Errorf("DataFrameOf[T]: struct field(s) not in result schema: %s", + strings.Join(missing, ", ")) + } + return bindings, nil +} + +// decodeRow writes values into *T using the bindings. dest is a +// pointer to T; we take *T (not T) so the caller can write each +// element of a pre-allocated []T slice without paying for a +// reflect.Value per row. +func decodeRow[T any](plan *rowPlan, values []any, bindings []columnBinding, dest *T) error { + dv := reflect.ValueOf(dest).Elem() + for ci := 0; ci < len(values) && ci < len(bindings); ci++ { + b := bindings[ci] + if b.planIndex < 0 { + continue + } + pf := &plan.fields[b.planIndex] + target := fieldByIndex(dv, pf.index) + if err := assignTypedValue(target, values[ci]); err != nil { + return fmt.Errorf("column %d (%s): %w", ci, pf.name, err) + } + } + return nil +} + +func fieldByIndex(v reflect.Value, index []int) reflect.Value { + cur := v + for _, i := range index { + for cur.Kind() == reflect.Ptr { + if cur.IsNil() { + cur.Set(reflect.New(cur.Type().Elem())) + } + cur = cur.Elem() + } + cur = cur.Field(i) + } + return cur +} + +// assignTypedValue writes src into the reflect field, doing the +// small set of conversions Spark's row values need — nil → zero for +// optional fields, arrow.Timestamp → time.Time for TIMESTAMP +// columns, assignable/convertible for primitives. Anything rarer +// surfaces as an explicit error so callers can tighten their struct +// to match. 
+// +// Named with the `Typed` suffix to avoid colliding with the existing +// assignValue helper in other files. +func assignTypedValue(dst reflect.Value, src any) error { + if src == nil { + dst.Set(reflect.Zero(dst.Type())) + return nil + } + dt := dst.Type() + isPtr := dt.Kind() == reflect.Ptr + innerType := dt + if isPtr { + innerType = dt.Elem() + } + if ts, ok := src.(arrow.Timestamp); ok && innerType == reflect.TypeOf(time.Time{}) { + setTypedValue(dst, reflect.ValueOf(ts.ToTime(arrow.Microsecond)), isPtr, innerType) + return nil + } + sv := reflect.ValueOf(src) + if sv.Type().AssignableTo(innerType) { + setTypedValue(dst, sv, isPtr, innerType) + return nil + } + if sv.Type().ConvertibleTo(innerType) { + setTypedValue(dst, sv.Convert(innerType), isPtr, innerType) + return nil + } + return fmt.Errorf("cannot assign %T to %v", src, dt) +} + +func setTypedValue(dst, src reflect.Value, isPtr bool, inner reflect.Type) { + if isPtr { + p := reflect.New(inner) + p.Elem().Set(src) + dst.Set(p) + return + } + dst.Set(src) +} + +// snakeCase converts a Go field name to its snake_case form. Matches +// the convention used by sqlx / gorm / jackc so a plain Go struct +// with no tags lines up with columns that follow standard SQL naming. +func snakeCase(s string) string { + if s == "" { + return s + } + var b strings.Builder + runes := []rune(s) + for i, r := range runes { + if i > 0 && unicode.IsUpper(r) { + prev := runes[i-1] + if unicode.IsLower(prev) || (i+1 < len(runes) && unicode.IsLower(runes[i+1])) { + b.WriteByte('_') + } + } + b.WriteRune(unicode.ToLower(r)) + } + return b.String() +} diff --git a/spark/sql/dataframe_typed_test.go b/spark/sql/dataframe_typed_test.go new file mode 100644 index 0000000..1224573 --- /dev/null +++ b/spark/sql/dataframe_typed_test.go @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type typedUser struct { + ID string `spark:"id"` + Email string `spark:"email"` + Country string // no tag — falls back to snake_case of field name + CreatedAt time.Time `spark:"created_at"` + Secret string `spark:"-"` // skipped + ignored string //nolint:unused // unexported → skipped +} + +func TestBuildRowPlan_TagsAndSnakeCase(t *testing.T) { + plan, err := buildRowPlan(reflect.TypeOf(typedUser{})) + require.NoError(t, err) + + names := make([]string, 0, len(plan.fields)) + for _, f := range plan.fields { + names = append(names, f.name) + } + assert.Equal(t, []string{"id", "email", "country", "created_at"}, names, + "tag takes precedence; untagged fields snake_case; `-` and unexported skipped") +} + +func TestBind_MatchesColumns(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + + // Columns may arrive in a different order than fields and can + // include extras the struct doesn't care about; the binder + // should pick the right indices and ignore the stranger. + cols := []string{"created_at", "extra_col", "email", "id", "country"} + bindings, err := plan.bind(cols) + require.NoError(t, err) + + want := []int{3, -1, 1, 0, 2} + for i, b := range bindings { + assert.Equalf(t, want[i], b.planIndex, + "column %q bound to wrong plan index", cols[i]) + } +} + +func TestBind_SchemaDriftErrorsEarly(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + // email column missing — the struct wants it, the result doesn't + // have it. Bind must surface this, not defer to per-row decode. + _, err := plan.bind([]string{"id", "country", "created_at"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "email", + "schema-drift error should name the missing column") +} + +func TestDecodeRow_ArrowTimestampToTime(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + cols := []string{"id", "email", "country", "created_at"} + bindings, err := plan.bind(cols) + require.NoError(t, err) + + when := time.Date(2026, 4, 19, 12, 0, 0, 0, time.UTC) + // arrow.Timestamp is an int64 of microseconds since epoch. Build + // one directly so the decode path's Microsecond branch fires. 
+ ts := arrow.Timestamp(when.UnixMicro()) + + values := []any{"abc", "alice@example.com", "UK", ts} + + var out typedUser + require.NoError(t, decodeRow(plan, values, bindings, &out)) + + assert.Equal(t, "abc", out.ID) + assert.Equal(t, "alice@example.com", out.Email) + assert.Equal(t, "UK", out.Country) + assert.True(t, out.CreatedAt.Equal(when), "TIMESTAMP micros should round-trip to time.Time") +} + +func TestDecodeRow_AssignableAndConvertible(t *testing.T) { + type row struct { + N int `spark:"n"` // assignable from int + S string `spark:"s"` // assignable from string + F int `spark:"f"` // float64 → int is convertible + } + plan, err := buildRowPlan(reflect.TypeOf(row{})) + require.NoError(t, err) + bindings, err := plan.bind([]string{"n", "s", "f"}) + require.NoError(t, err) + + var out row + require.NoError(t, decodeRow(plan, []any{int(42), "hello", float64(7)}, bindings, &out)) + assert.Equal(t, 42, out.N) + assert.Equal(t, "hello", out.S) + assert.Equal(t, 7, out.F) +} + +func TestDecodeRow_NilIsZero(t *testing.T) { + type row struct { + N int `spark:"n"` + S string `spark:"s"` + } + plan, _ := buildRowPlan(reflect.TypeOf(row{})) + bindings, _ := plan.bind([]string{"n", "s"}) + + var out row + require.NoError(t, decodeRow(plan, []any{nil, nil}, bindings, &out)) + assert.Zero(t, out.N) + assert.Zero(t, out.S) +} + +func TestTypedDataFrame_RejectsNonStruct(t *testing.T) { + // TypedDataFrame is supposed to surface the misuse at construction + // time, not at Collect. A map / slice / primitive should fail + // clearly with a pointer back at the caller's T. + type notAStruct = map[string]string + _, err := TypedDataFrame[notAStruct](nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "T must be a struct") +} + +func TestSnakeCase_CommonShapes(t *testing.T) { + cases := map[string]string{ + "ID": "id", + "Email": "email", + "CreatedAt": "created_at", + "HTTPServer": "http_server", + "JSONPayload": "json_payload", + "A": "a", + "": "", + } + for in, want := range cases { + assert.Equalf(t, want, snakeCase(in), "snakeCase(%q)", in) + } +} From 5a8a35d5b380d3487bd4e908bd52b4440f8c287c Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 17:17:34 +0100 Subject: [PATCH 25/37] docs/13: drop upstream "not for production" notice from fork README (#14) The paragraph described the Apache PMC's stance on the upstream client, not this fork's. Carrying it on a fork that downstream projects ship against was misleading on both provenance and intent. --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index 5a502a9..7f9455a 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,6 @@ This project houses the **experimental** client for [Spark Connect](https://spark.apache.org/docs/latest/spark-connect-overview.html) for [Apache Spark](https://spark.apache.org/) written in [Golang](https://go.dev/). -## Current State of the Project - -Currently, the Spark Connect client for Golang is highly experimental and should -not be used in any production setting. In addition, the PMC of the Apache Spark -project reserves the right to withdraw and abandon the development of this project -if it is not sustainable. - ## Getting started This section explains how to run Spark Connect Go locally. 
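
For orientation, an illustrative consumer-side sketch of the typed surface introduced in
PATCH 24/37 (feat/4), written against the fork's import path as it stands at this point in
the series. It assumes a reachable Spark Connect endpoint at sc://localhost and a users
table with id, email and created_at columns; the endpoint, table and column names are
placeholders, not anything shipped by the patch.

    package main

    import (
        "context"
        "log"
        "time"

        "github.com/caldempsey/spark-connect-go/spark/sql"
    )

    // User binds columns by `spark` tag; an untagged field would fall back
    // to the snake_case form of the Go field name.
    type User struct {
        ID        string    `spark:"id"`
        Email     string    `spark:"email"`
        CreatedAt time.Time `spark:"created_at"`
    }

    func main() {
        ctx := context.Background()
        session, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
        if err != nil {
            log.Fatal(err)
        }
        // SqlTyped reflects over User once to build the row plan;
        // Collect then decodes every result row straight into []User.
        users, err := sql.SqlTyped[User](ctx, session, "SELECT id, email, created_at FROM users")
        if err != nil {
            log.Fatal(err)
        }
        rows, err := users.Collect(ctx)
        if err != nil {
            log.Fatal(err)
        }
        for _, u := range rows {
            log.Printf("%s %s %v", u.ID, u.Email, u.CreatedAt)
        }
    }

Schema drift (a struct field missing from the projection) surfaces on the first Collect
call, so a query that stops selecting created_at fails loudly rather than leaving zero
values behind.
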
From ccfe0de6659d246277a6b79594bef1f1b5836605 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 19:20:16 +0100 Subject: [PATCH 26/37] feat: top-level typed helpers + Dataset[T] alias (#15) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: top-level typed helpers Collect/Stream/First/As/Into + Dataset[T] alias Adds the five top-level typed helpers the bootstrap spec names: - Collect[T](ctx, df) ([]T, error) - Stream[T](ctx, df) iter.Seq2[T, error] - First[T](ctx, df) (*T, error) / ErrNotFound sentinel - As[T](df) (*DataFrameOf[T], error) - Into(ctx, df, dst any) error // non-generic, slice or struct dst Also a `Dataset[T]` type alias for DataFrameOf[T] to match the Scala / Java naming and Apache Spark's Dataset[T] precedent. DataFrameOf[T] stays as the original name and is fully interchangeable via the alias. Stream[T] is the real new capability: the existing DataFrameOf[T] deliberately omits streaming (documented in dataframe_typed.go). The new helper wraps the untyped DataFrame.All iterator and decodes each row in place, yielding constant memory regardless of result size. Into covers the non-generic path where T isn't known at compile time — typical of code-gen consumers or reflection-heavy DSLs. Uses a non-generic decodeRowReflect sibling to decodeRow so already-typed slice slots populate via reflect.Value rather than forcing a T instantiation. Tests cover the guard paths (Into rejects non-pointer / nil pointer / non-slice-non-struct / slice-of-non-struct; As rejects non-struct T; Collect rejects non-struct T; Dataset[T] alias identity at compile time; ErrNotFound sentinel propagates through errors.Is). Full DataFrame-backed coverage for Collect / Stream / First lands with the integration suite (no mockDataFrame exists today; the thin wrappers call already-tested primitives). * chore: bump go.mod to 1.24 for generic type aliases The new Dataset[T] = DataFrameOf[T] alias is a generic type alias, which requires Go 1.24 (GA) or GOEXPERIMENT=aliastypeparams on older toolchains. CI was pinning 1.23.2 via the go-version-file in .github/workflows/build.yml and failing on the typed-helpers compile step. Bump the module's declared Go version to 1.24. The downstream datalakeorm/dorm module already uses go 1.24.9, so nothing downstream regresses; CI workers on 1.24+ pick it up automatically. --------- --- go.mod | 2 +- spark/sql/typed_helpers.go | 203 ++++++++++++++++++++++++++++++++ spark/sql/typed_helpers_test.go | 85 +++++++++++++ 3 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 spark/sql/typed_helpers.go create mode 100644 spark/sql/typed_helpers_test.go diff --git a/go.mod b/go.mod index 379f88e..c42353e 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ module github.com/caldempsey/spark-connect-go -go 1.23.2 +go 1.24 require ( github.com/apache/arrow-go/v18 v18.4.0 diff --git a/spark/sql/typed_helpers.go b/spark/sql/typed_helpers.go new file mode 100644 index 0000000..297b94a --- /dev/null +++ b/spark/sql/typed_helpers.go @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "errors" + "fmt" + "iter" + "reflect" +) + +// Dataset is a type alias for DataFrameOf[T]. Matches the Scala/Java +// Dataset[T] naming; DataFrameOf[T] remains as the original name and +// is fully interchangeable. +type Dataset[T any] = DataFrameOf[T] + +// ErrNotFound is returned by First when the DataFrame produces zero +// rows. +var ErrNotFound = errors.New("spark: no rows returned") + +// Collect materialises every row of df into a []T by wrapping df in +// the typed surface and calling Collect. Equivalent to +// TypedDataFrame[T](df).Collect(ctx) but written as a one-liner for +// callers who already hold a DataFrame. +func Collect[T any](ctx context.Context, df DataFrame) ([]T, error) { + typed, err := TypedDataFrame[T](df) + if err != nil { + return nil, err + } + return typed.Collect(ctx) +} + +// Stream yields typed rows one at a time using the untyped +// DataFrame's streaming primitive underneath. Constant memory +// regardless of result size. Schema binding happens on the first row; +// a subsequent row whose schema diverges from the first surfaces the +// error through the iterator. +// +// Consumers range over the return value with Go 1.23's iter.Seq2: +// +// for row, err := range sql.Stream[User](ctx, df) { +// if err != nil { break } +// // use row +// } +func Stream[T any](ctx context.Context, df DataFrame) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + var zero T + typed, err := TypedDataFrame[T](df) + if err != nil { + yield(zero, err) + return + } + var bindings []columnBinding + for row, rerr := range df.All(ctx) { + if rerr != nil { + yield(zero, rerr) + return + } + if bindings == nil { + b, berr := typed.plan.bind(row.FieldNames()) + if berr != nil { + yield(zero, berr) + return + } + bindings = b + } + var out T + if derr := decodeRow(typed.plan, row.Values(), bindings, &out); derr != nil { + yield(zero, derr) + return + } + if !yield(out, nil) { + return + } + } + } +} + +// First returns the first row of df decoded as T, or ErrNotFound if +// df produced no rows. Runs Collect underneath at v0; the DataFrame +// LIMIT optimisation lands when Dataset[T].Limit stabilises. +func First[T any](ctx context.Context, df DataFrame) (*T, error) { + rows, err := Collect[T](ctx, df) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, ErrNotFound + } + return &rows[0], nil +} + +// As wraps df in the typed surface. Alias for TypedDataFrame[T], +// named to match CLAUDE_CODE_BOOTSTRAP.md and Scala's Encoder-flavoured +// naming. Schema compatibility with T is validated lazily — the first +// call to Collect or Stream surfaces drift, not As itself. +func As[T any](df DataFrame) (*DataFrameOf[T], error) { + return TypedDataFrame[T](df) +} + +// Into scans df into dst where dst is a pointer to either a slice +// (populated with every row) or a single struct (the first row; +// ErrNotFound on empty). Non-generic variant for cases where T is +// not known at compile time — typical of code-generated consumers +// or reflection-heavy DSLs. 
+func Into(ctx context.Context, df DataFrame, dst any) error { + rv := reflect.ValueOf(dst) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("spark.Into: dst must be a non-nil pointer, got %T", dst) + } + elem := rv.Elem() + switch elem.Kind() { + case reflect.Slice: + return intoSlice(ctx, df, elem) + case reflect.Struct: + return intoStruct(ctx, df, elem) + default: + return fmt.Errorf("spark.Into: dst must point to a slice or struct, got %v", elem.Kind()) + } +} + +func intoSlice(ctx context.Context, df DataFrame, sliceValue reflect.Value) error { + elemType := sliceValue.Type().Elem() + if elemType.Kind() != reflect.Struct { + return fmt.Errorf("spark.Into: slice element type must be a struct, got %v", elemType) + } + plan, err := buildRowPlan(elemType) + if err != nil { + return err + } + rows, err := df.Collect(ctx) + if err != nil { + return err + } + if len(rows) == 0 { + sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0)) + return nil + } + bindings, err := plan.bind(rows[0].FieldNames()) + if err != nil { + return err + } + out := reflect.MakeSlice(sliceValue.Type(), len(rows), len(rows)) + for i, r := range rows { + if err := decodeRowReflect(plan, r.Values(), bindings, out.Index(i)); err != nil { + return fmt.Errorf("spark.Into: row %d: %w", i, err) + } + } + sliceValue.Set(out) + return nil +} + +func intoStruct(ctx context.Context, df DataFrame, structValue reflect.Value) error { + plan, err := buildRowPlan(structValue.Type()) + if err != nil { + return err + } + rows, err := df.Collect(ctx) + if err != nil { + return err + } + if len(rows) == 0 { + return ErrNotFound + } + bindings, err := plan.bind(rows[0].FieldNames()) + if err != nil { + return err + } + return decodeRowReflect(plan, rows[0].Values(), bindings, structValue) +} + +// decodeRowReflect is the non-generic variant of decodeRow used by +// Into. Mirrors decodeRow's logic but walks the destination via +// reflect.Value rather than *T, so the caller can populate one slot +// of an already-allocated []T without instantiating T. +func decodeRowReflect(plan *rowPlan, values []any, bindings []columnBinding, dest reflect.Value) error { + for ci := 0; ci < len(values) && ci < len(bindings); ci++ { + b := bindings[ci] + if b.planIndex < 0 { + continue + } + pf := &plan.fields[b.planIndex] + target := fieldByIndex(dest, pf.index) + if err := assignTypedValue(target, values[ci]); err != nil { + return fmt.Errorf("column %d (%s): %w", ci, pf.name, err) + } + } + return nil +} diff --git a/spark/sql/typed_helpers_test.go b/spark/sql/typed_helpers_test.go new file mode 100644 index 0000000..fab173a --- /dev/null +++ b/spark/sql/typed_helpers_test.go @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "context" + "errors" + "strings" + "testing" +) + +// Dataset[T] is the alias; verify at compile time that it's +// interchangeable with DataFrameOf[T]. A failing assertion here is +// a type-level regression caught without running any code. +var _ *DataFrameOf[typedUser] = (*Dataset[typedUser])(nil) + +func TestErrNotFound_Sentinel(t *testing.T) { + // First's empty-result path wraps ErrNotFound; verify callers can + // branch on it via errors.Is. + wrapped := errors.Join(ErrNotFound, errors.New("context")) + if !errors.Is(wrapped, ErrNotFound) { + t.Error("ErrNotFound should match via errors.Is through errors.Join") + } +} + +func TestInto_RejectsNonPointer(t *testing.T) { + err := Into(context.Background(), nil, "not a pointer") + if err == nil || !strings.Contains(err.Error(), "non-nil pointer") { + t.Errorf("want non-nil-pointer error, got %v", err) + } +} + +func TestInto_RejectsNilPointer(t *testing.T) { + var users *[]typedUser + err := Into(context.Background(), nil, users) + if err == nil || !strings.Contains(err.Error(), "non-nil pointer") { + t.Errorf("want non-nil-pointer error, got %v", err) + } +} + +func TestInto_RejectsPointerToNonSliceNonStruct(t *testing.T) { + n := 42 + err := Into(context.Background(), nil, &n) + if err == nil || !strings.Contains(err.Error(), "slice or struct") { + t.Errorf("want slice-or-struct error, got %v", err) + } +} + +func TestInto_RejectsSliceOfNonStruct(t *testing.T) { + // Cover the elemType check in intoSlice before any I/O would fire. + ns := []int{} + err := Into(context.Background(), nil, &ns) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want slice-element-struct error, got %v", err) + } +} + +func TestAs_RejectsNonStructT(t *testing.T) { + // As[T] delegates to TypedDataFrame[T] which enforces T must be a + // struct. Verify the error surfaces before any DataFrame I/O. + _, err := As[int](nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} + +func TestCollect_RejectsNonStructT(t *testing.T) { + _, err := Collect[int](context.Background(), nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} From 00c6971ab54dfff915e323a8a05aa5e6717db43b Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 19:26:48 +0100 Subject: [PATCH 27/37] chore: rename module path to github.com/datalakego/spark-connect-go (#16) The org rename moved the fork from caldempsey/spark-connect-go to datalakego/spark-connect-go. Update the module declaration in go.mod and sweep all import paths across the tree. 43 files touched (Go sources, tests, mocks, Makefile-less module references, CI). Sanity-checked afterwards: grep -rn "caldempsey/spark-connect-go" . returns zero hits. Tests unchanged; they pass green under the new module path: ok github.com/datalakego/spark-connect-go/spark/sql 1.702s Consumers who depend on this module update their go.mod lines to the new path; a redirect from the old GitHub URL handles clone and checkout automatically. 
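For a downstream module the bump is mechanical; a sketch (the `master`
placeholder mirrors quick-start.md — pin a tagged version or pseudo-version
in practice):

    // go.mod
    require (
        github.com/datalakego/spark-connect-go master
    )

    // Go sources
    import "github.com/datalakego/spark-connect-go/spark/sql"

followed by a project-wide find-and-replace of the old path and `go mod tidy`.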
--- README.md | 4 ++-- cmd/spark-connect-example-raw-grpc-client/main.go | 2 +- cmd/spark-connect-example-spark-session/main.go | 8 ++++---- go.mod | 2 +- internal/tests/integration/dataframe_test.go | 10 +++++----- internal/tests/integration/functions_test.go | 6 +++--- internal/tests/integration/helper.go | 2 +- internal/tests/integration/spark_runner.go | 2 +- internal/tests/integration/sql_test.go | 8 ++++---- quick-start.md | 4 ++-- spark/client/base/base.go | 6 +++--- spark/client/channel/channel.go | 4 ++-- spark/client/channel/channel_test.go | 4 ++-- spark/client/client.go | 14 +++++++------- spark/client/client_test.go | 8 ++++---- spark/client/conf.go | 4 ++-- spark/client/retry.go | 8 ++++---- spark/client/retry_test.go | 8 ++++---- spark/client/testutils/utils.go | 2 +- spark/mocks/mock_executor.go | 8 ++++---- spark/mocks/mocks.go | 2 +- spark/sql/column/column.go | 4 ++-- spark/sql/column/column_test.go | 2 +- spark/sql/column/expressions.go | 6 +++--- spark/sql/column/expressions_test.go | 2 +- spark/sql/dataframe.go | 10 +++++----- spark/sql/dataframe_test.go | 4 ++-- spark/sql/dataframenafunctions.go | 2 +- spark/sql/dataframewriter.go | 4 ++-- spark/sql/dataframewriter_test.go | 6 +++--- spark/sql/functions/buiitins.go | 4 ++-- spark/sql/functions/generated.go | 2 +- spark/sql/group.go | 10 +++++----- spark/sql/group_test.go | 8 ++++---- spark/sql/plan.go | 2 +- spark/sql/sparksession.go | 14 +++++++------- spark/sql/sparksession_integration_test.go | 2 +- spark/sql/sparksession_test.go | 10 +++++----- spark/sql/types/arrow.go | 4 ++-- spark/sql/types/arrow_test.go | 4 ++-- spark/sql/types/builtin.go | 2 +- spark/sql/types/conversion.go | 4 ++-- spark/sql/types/conversion_test.go | 4 ++-- spark/sql/types/rowiterator_test.go | 2 +- spark/sql/utils/consts.go | 2 +- 45 files changed, 115 insertions(+), 115 deletions(-) diff --git a/README.md b/README.md index 7f9455a..0d90234 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Step 3: Run the following commands to setup the Spark Connect client. Building with Spark in case you need to re-generate the source files from the proto sources. 
``` -git clone https://github.com/caldempsey/spark-connect-go.git +git clone https://github.com/datalakego/spark-connect-go.git git submodule update --init --recursive make gen && make test @@ -27,7 +27,7 @@ make gen && make test Building without Spark ``` -git clone https://github.com/caldempsey/spark-connect-go.git +git clone https://github.com/datalakego/spark-connect-go.git make && make test ``` diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go index ab3c405..16a18df 100644 --- a/cmd/spark-connect-example-raw-grpc-client/main.go +++ b/cmd/spark-connect-example-raw-grpc-client/main.go @@ -22,7 +22,7 @@ import ( "log" "time" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index a05758b..45ead19 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -22,12 +22,12 @@ import ( "fmt" "log" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + "github.com/datalakego/spark-connect-go/spark/sql/functions" - "github.com/caldempsey/spark-connect-go/spark/sql" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql/utils" ) var ( diff --git a/go.mod b/go.mod index c42353e..5052b71 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-module github.com/caldempsey/spark-connect-go +module github.com/datalakego/spark-connect-go go 1.24 diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index 9d4dde1..1649dc6 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -21,15 +21,15 @@ import ( "os" "testing" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + "github.com/datalakego/spark-connect-go/spark/sql/utils" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - "github.com/caldempsey/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + "github.com/datalakego/spark-connect-go/spark/sql/functions" - "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go index 18c8cb2..48da37b 100644 --- a/internal/tests/integration/functions_test.go +++ b/internal/tests/integration/functions_test.go @@ -19,11 +19,11 @@ import ( "context" "testing" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + "github.com/datalakego/spark-connect-go/spark/sql/functions" - "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/internal/tests/integration/helper.go b/internal/tests/integration/helper.go index df0b630..9ea4dd4 100644 --- a/internal/tests/integration/helper.go +++ b/internal/tests/integration/helper.go @@ -22,7 +22,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql" ) func connect() (context.Context, sql.SparkSession) { diff --git a/internal/tests/integration/spark_runner.go b/internal/tests/integration/spark_runner.go index 5796456..4768b18 100644 --- a/internal/tests/integration/spark_runner.go +++ b/internal/tests/integration/spark_runner.go @@ -23,7 +23,7 @@ import ( "os/exec" "time" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) func StartSparkConnect() (int64, error) { diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go index 395b285..830c317 100644 --- a/internal/tests/integration/sql_test.go +++ b/internal/tests/integration/sql_test.go @@ -22,13 +22,13 @@ import ( "os" "testing" - "github.com/caldempsey/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + "github.com/datalakego/spark-connect-go/spark/sql/functions" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/quick-start.md b/quick-start.md index b51f3ee..94151ff 100644 --- a/quick-start.md +++ b/quick-start.md @@ -5,7 +5,7 @@ In your 
Go project `go.mod` file, add `spark-connect-go` library: ``` require ( - github.com/caldempsey/spark-connect-go master + github.com/datalakego/spark-connect-go master ) ``` @@ -23,7 +23,7 @@ import ( "fmt" "log" - "github.com/caldempsey/spark-connect-go/spark/sql" + "github.com/datalakego/spark-connect-go/spark/sql" ) var ( diff --git a/spark/client/base/base.go b/spark/client/base/base.go index aa554a0..054f1f9 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -19,11 +19,11 @@ import ( "context" "iter" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + "github.com/datalakego/spark-connect-go/spark/sql/utils" "github.com/apache/arrow-go/v18/arrow" - "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) type SparkConnectRPCClient generated.SparkConnectServiceClient diff --git a/spark/client/channel/channel.go b/spark/client/channel/channel.go index 328da4c..fd61306 100644 --- a/spark/client/channel/channel.go +++ b/spark/client/channel/channel.go @@ -29,13 +29,13 @@ import ( "strconv" "strings" - "github.com/caldempsey/spark-connect-go/spark" + "github.com/datalakego/spark-connect-go/spark" "github.com/google/uuid" "google.golang.org/grpc/credentials/insecure" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" diff --git a/spark/client/channel/channel_test.go b/spark/client/channel/channel_test.go index 0bfc334..7a3abf7 100644 --- a/spark/client/channel/channel_test.go +++ b/spark/client/channel/channel_test.go @@ -24,8 +24,8 @@ import ( "github.com/google/uuid" "google.golang.org/grpc" - "github.com/caldempsey/spark-connect-go/spark/client/channel" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/client/channel" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" ) diff --git a/spark/client/client.go b/spark/client/client.go index 2b0c08e..0e2bff0 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -22,24 +22,24 @@ import ( "io" "iter" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + "github.com/datalakego/spark-connect-go/spark/sql/utils" "google.golang.org/grpc" "google.golang.org/grpc/metadata" - "github.com/caldempsey/spark-connect-go/spark/client/base" - "github.com/caldempsey/spark-connect-go/spark/mocks" + "github.com/datalakego/spark-connect-go/spark/client/base" + "github.com/datalakego/spark-connect-go/spark/mocks" - "github.com/caldempsey/spark-connect-go/spark/client/options" + "github.com/datalakego/spark-connect-go/spark/client/options" "github.com/google/uuid" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) type sparkConnectClientImpl struct { diff --git a/spark/client/client_test.go b/spark/client/client_test.go index c9dc54b..2e79650 100644 --- 
a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -12,10 +12,10 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/client" - "github.com/caldempsey/spark-connect-go/spark/mocks" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/client" + "github.com/datalakego/spark-connect-go/spark/mocks" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/spark/client/conf.go b/spark/client/conf.go index 22a81e2..96ad2e0 100644 --- a/spark/client/conf.go +++ b/spark/client/conf.go @@ -18,8 +18,8 @@ package client import ( "context" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/client/base" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/client/base" ) // Public interface RuntimeConfig diff --git a/spark/client/retry.go b/spark/client/retry.go index 016aa46..4bf6161 100644 --- a/spark/client/retry.go +++ b/spark/client/retry.go @@ -23,13 +23,13 @@ import ( "strings" "time" - "github.com/caldempsey/spark-connect-go/spark/client/base" + "github.com/datalakego/spark-connect-go/spark/client/base" - "github.com/caldempsey/spark-connect-go/spark/client/options" + "github.com/datalakego/spark-connect-go/spark/client/options" "google.golang.org/grpc/metadata" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) diff --git a/spark/client/retry_test.go b/spark/client/retry_test.go index 83801db..604419d 100644 --- a/spark/client/retry_test.go +++ b/spark/client/retry_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" - "github.com/caldempsey/spark-connect-go/spark/client/options" + "github.com/datalakego/spark-connect-go/spark/client/options" - "github.com/caldempsey/spark-connect-go/spark/client/testutils" - "github.com/caldempsey/spark-connect-go/spark/mocks" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/client/testutils" + "github.com/datalakego/spark-connect-go/spark/mocks" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go index e38a33b..fb071b5 100644 --- a/spark/client/testutils/utils.go +++ b/spark/client/testutils/utils.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "google.golang.org/grpc" ) diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index fc4e9a6..c67e2ed 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -19,13 +19,13 @@ import ( "context" "errors" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + 
"github.com/datalakego/spark-connect-go/spark/sql/utils" - "github.com/caldempsey/spark-connect-go/spark/client/base" + "github.com/datalakego/spark-connect-go/spark/client/base" "github.com/apache/arrow-go/v18/arrow" - "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) type TestExecutor struct { diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go index c6232a9..3ce7c79 100644 --- a/spark/mocks/mocks.go +++ b/spark/mocks/mocks.go @@ -25,7 +25,7 @@ import ( "github.com/google/uuid" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "google.golang.org/grpc/metadata" ) diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go index 13eb6a8..19c3629 100644 --- a/spark/sql/column/column.go +++ b/spark/sql/column/column.go @@ -18,9 +18,9 @@ package column import ( "context" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" ) // Convertible is the interface for all things that can be converted into a protobuf expression. diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go index c62cbde..11e2203 100644 --- a/spark/sql/column/column_test.go +++ b/spark/sql/column/column_test.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go index ca0bfcb..30e34e5 100644 --- a/spark/sql/column/expressions.go +++ b/spark/sql/column/expressions.go @@ -20,11 +20,11 @@ import ( "fmt" "strings" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" ) func newProtoExpression() *proto.Expression { diff --git a/spark/sql/column/expressions_test.go b/spark/sql/column/expressions_test.go index 4fb59e3..58741f6 100644 --- a/spark/sql/column/expressions_test.go +++ b/spark/sql/column/expressions_test.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 4374529..4b61326 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -22,14 +22,14 @@ import ( "math/rand/v2" "github.com/apache/arrow-go/v18/arrow" - "github.com/caldempsey/spark-connect-go/spark/sql/utils" + "github.com/datalakego/spark-connect-go/spark/sql/utils" - "github.com/caldempsey/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - proto 
"github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) // ResultCollector receives a stream of result rows diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go index 49fb11f..475c270 100644 --- a/spark/sql/dataframe_test.go +++ b/spark/sql/dataframe_test.go @@ -20,8 +20,8 @@ import ( "context" "testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sql/functions" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframenafunctions.go b/spark/sql/dataframenafunctions.go index 23f827f..c288493 100644 --- a/spark/sql/dataframenafunctions.go +++ b/spark/sql/dataframenafunctions.go @@ -18,7 +18,7 @@ package sql import ( "context" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) type DataFrameNaFunctions interface { diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index 6f98649..ee2fb7f 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -21,8 +21,8 @@ import ( "fmt" "strings" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) // DataFrameWriter supports writing data frame to storage. diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index 8a6775c..10baecd 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - "github.com/caldempsey/spark-connect-go/spark/client" + "github.com/datalakego/spark-connect-go/spark/client" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/mocks" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go index 8f07d82..ce9b46d 100644 --- a/spark/sql/functions/buiitins.go +++ b/spark/sql/functions/buiitins.go @@ -16,8 +16,8 @@ package functions import ( - "github.com/caldempsey/spark-connect-go/spark/sql/column" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) func Expr(expr string) column.Column { diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go index f66f9a7..5af8970 100644 --- a/spark/sql/functions/generated.go +++ b/spark/sql/functions/generated.go @@ -15,7 +15,7 @@ package functions -import "github.com/caldempsey/spark-connect-go/spark/sql/column" +import "github.com/datalakego/spark-connect-go/spark/sql/column" // BitwiseNOT - Computes bitwise not. 
// diff --git a/spark/sql/group.go b/spark/sql/group.go index 125fee8..b908e76 100644 --- a/spark/sql/group.go +++ b/spark/sql/group.go @@ -19,12 +19,12 @@ package sql import ( "context" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" - "github.com/caldempsey/spark-connect-go/spark/sql/column" - "github.com/caldempsey/spark-connect-go/spark/sql/functions" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/functions" ) type GroupedData struct { diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go index f6e6c72..109b83e 100644 --- a/spark/sql/group_test.go +++ b/spark/sql/group_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/client" - "github.com/caldempsey/spark-connect-go/spark/client/testutils" - "github.com/caldempsey/spark-connect-go/spark/mocks" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/client" + "github.com/datalakego/spark-connect-go/spark/client/testutils" + "github.com/datalakego/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/plan.go b/spark/sql/plan.go index 89709ae..957da2d 100644 --- a/spark/sql/plan.go +++ b/spark/sql/plan.go @@ -19,7 +19,7 @@ package sql import ( "sync/atomic" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" ) var atomicInt64 atomic.Int64 diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index 8f9418e..57c03c9 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -23,19 +23,19 @@ import ( "time" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/caldempsey/spark-connect-go/spark/client/base" + "github.com/datalakego/spark-connect-go/spark/client/base" - "github.com/caldempsey/spark-connect-go/spark/client/options" + "github.com/datalakego/spark-connect-go/spark/client/options" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/client" - "github.com/caldempsey/spark-connect-go/spark/client/channel" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/client" + "github.com/datalakego/spark-connect-go/spark/client/channel" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/metadata" diff --git a/spark/sql/sparksession_integration_test.go b/spark/sql/sparksession_integration_test.go index 9512fd4..9e8ca70 100644 --- a/spark/sql/sparksession_integration_test.go +++ b/spark/sql/sparksession_integration_test.go @@ -21,7 +21,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" 
"github.com/apache/arrow-go/v18/arrow/memory" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 1971809..9acc682 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -30,11 +30,11 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/client" - "github.com/caldempsey/spark-connect-go/spark/client/testutils" - "github.com/caldempsey/spark-connect-go/spark/mocks" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/client" + "github.com/datalakego/spark-connect-go/spark/client/testutils" + "github.com/datalakego/spark-connect-go/spark/mocks" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) func TestSparkSessionTable(t *testing.T) { diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index 0bfc1ff..92432f8 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -20,13 +20,13 @@ import ( "bytes" "fmt" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) func ReadArrowTableToRows(table arrow.Table) ([]Row, error) { diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index 6d88c66..e19c3cd 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -30,8 +30,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) func TestShowArrowBatchData(t *testing.T) { diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go index 944beca..e03f843 100644 --- a/spark/sql/types/builtin.go +++ b/spark/sql/types/builtin.go @@ -19,7 +19,7 @@ package types import ( "context" - proto "github.com/caldempsey/spark-connect-go/internal/generated" + proto "github.com/datalakego/spark-connect-go/internal/generated" ) type LiteralType interface { diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go index 92a2fc5..b40db72 100644 --- a/spark/sql/types/conversion.go +++ b/spark/sql/types/conversion.go @@ -19,8 +19,8 @@ package types import ( "errors" - "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sparkerrors" + "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sparkerrors" ) func ConvertProtoDataTypeToStructType(input *generated.DataType) (*StructType, error) { diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go index 109dc92..d3a3d00 100644 --- a/spark/sql/types/conversion_test.go +++ b/spark/sql/types/conversion_test.go @@ -19,8 +19,8 @@ package types_test import ( 
"testing" - proto "github.com/caldempsey/spark-connect-go/internal/generated" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + proto "github.com/datalakego/spark-connect-go/internal/generated" + "github.com/datalakego/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index 525d0c0..a81c8ef 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -16,7 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/caldempsey/spark-connect-go/spark/sql/types" + "github.com/datalakego/spark-connect-go/spark/sql/types" ) // Helper function to create test records diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go index 3dd5f61..53a4e56 100644 --- a/spark/sql/utils/consts.go +++ b/spark/sql/utils/consts.go @@ -15,7 +15,7 @@ package utils -import proto "github.com/caldempsey/spark-connect-go/internal/generated" +import proto "github.com/datalakego/spark-connect-go/internal/generated" type ExplainMode int From abc5e8d241a3d572c549e067b053c95d62eb8c9d Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 20:08:31 +0100 Subject: [PATCH 28/37] feat: Dataset[T].Where / Limit / OrderBy / First / Stream methods (#17) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the five chainable methods the Scala / Java Dataset[T] API advertises but DataFrameOf[T] didn't expose. Completes the "typed DataFrame is a first-class surface" story the fork promised. Surface: func (d *DataFrameOf[T]) Where(sql string, args ...any) *DataFrameOf[T] func (d *DataFrameOf[T]) Limit(n int) *DataFrameOf[T] func (d *DataFrameOf[T]) OrderBy(columns ...string) *DataFrameOf[T] func (d *DataFrameOf[T]) First(ctx context.Context) (*T, error) func (d *DataFrameOf[T]) Stream(ctx context.Context) iter.Seq2[T, error] Where / Limit / OrderBy are lazy: each queues a datasetOp on a cloned DataFrameOf so chains don't mutate shared state. Ops are applied in declaration order by resolveDataFrame when Collect / Stream / First materialises. args on Where is accepted for compatibility with dorm.Query's signature but currently forwarded nowhere — the underlying DataFrame.Where takes a bare string. Callers interpolate with fmt.Sprintf or build predicates via the functions package. First returns ErrNotFound on empty (sentinel already in typed_helpers.go from the previous cut). Stream uses DataFrame.All under the hood, honouring any queued ops via resolveDataFrame. Collect rewired through resolveDataFrame so queued ops apply correctly there too. Existing tests still pass (no behavioural change for a DataFrameOf constructed without chained ops). Tests: chainable shape (Where / Limit / OrderBy each queue one op; chain of three queues three), clone isolation (parent.ops unchanged when child adds more), resolve-with-no-ops returns the underlying DataFrame untouched via a sentinel fake. 
--- spark/sql/dataframe_typed.go | 36 ++++++++- spark/sql/dataset_methods.go | 126 ++++++++++++++++++++++++++++++ spark/sql/dataset_methods_test.go | 107 +++++++++++++++++++++++++ 3 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 spark/sql/dataset_methods.go create mode 100644 spark/sql/dataset_methods_test.go diff --git a/spark/sql/dataframe_typed.go b/spark/sql/dataframe_typed.go index 22a19b0..37fac9a 100644 --- a/spark/sql/dataframe_typed.go +++ b/spark/sql/dataframe_typed.go @@ -58,6 +58,36 @@ import ( type DataFrameOf[T any] struct { df DataFrame plan *rowPlan + ops []datasetOp // lazy Where/Limit/OrderBy transforms; applied on materialise +} + +// datasetOp is a pending transform on the underlying DataFrame. Ops +// are queued by the chainable Where/Limit/OrderBy methods and applied +// at Collect/Stream/First time so the fluent builder stays ctx-free. +type datasetOp func(ctx context.Context, df DataFrame) (DataFrame, error) + +// resolveDataFrame applies every queued op in declaration order and +// returns the final DataFrame. Used by Collect/Stream/First to get a +// materialisable handle. +func (d *DataFrameOf[T]) resolveDataFrame(ctx context.Context) (DataFrame, error) { + df := d.df + for _, op := range d.ops { + next, err := op(ctx, df) + if err != nil { + return nil, err + } + df = next + } + return df, nil +} + +// clone returns a shallow copy of d with a freshly allocated ops +// slice so that chained operations don't share state with the parent. +// Chainable methods return the clone. +func (d *DataFrameOf[T]) clone() *DataFrameOf[T] { + cp := *d + cp.ops = append([]datasetOp(nil), d.ops...) + return &cp } // SqlTyped runs a SQL query and returns a typed DataFrame over the @@ -101,7 +131,11 @@ func (d *DataFrameOf[T]) DataFrame() DataFrame { return d.df } // result sets should project narrower on the SQL side or drop to // the untyped streaming path via DataFrame(). func (d *DataFrameOf[T]) Collect(ctx context.Context) ([]T, error) { - rows, err := d.df.Collect(ctx) + df, err := d.resolveDataFrame(ctx) + if err != nil { + return nil, err + } + rows, err := df.Collect(ctx) if err != nil { return nil, err } diff --git a/spark/sql/dataset_methods.go b/spark/sql/dataset_methods.go new file mode 100644 index 0000000..c6613a9 --- /dev/null +++ b/spark/sql/dataset_methods.go @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "iter" + + "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalakego/spark-connect-go/spark/sql/functions" +) + +// Where adds a filter predicate. Lazy: the condition is applied to +// the underlying DataFrame when Collect / Stream / First materialises +// the Dataset. 
Chainable — each call narrows the projection further. +// +// sqlStr is a Spark SQL fragment (e.g. "country = 'UK'" or +// "id IN ('a', 'b', 'c')"). args is accepted for API compatibility +// with dorm.Query's signature but currently ignored — the underlying +// DataFrame.Where takes a bare string; callers interpolate values +// with fmt.Sprintf or build the fragment via the functions package. +func (d *DataFrameOf[T]) Where(sqlStr string, args ...any) *DataFrameOf[T] { + _ = args + cp := d.clone() + cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { + return df.Where(ctx, sqlStr) + }) + return cp +} + +// Limit caps the number of rows materialised. Chainable; repeated +// calls each produce their own Limit relation in the underlying plan, +// and Spark's optimiser collapses them to the minimum. +func (d *DataFrameOf[T]) Limit(n int) *DataFrameOf[T] { + cp := d.clone() + cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { + return df.Limit(ctx, int32(n)), nil + }) + return cp +} + +// OrderBy adds an ascending sort by one or more columns. Callers who +// need descending order, null-ordering modifiers, or expression-based +// sort keys drop to DataFrame() and invoke Sort directly with +// column.Convertible values. +func (d *DataFrameOf[T]) OrderBy(columns ...string) *DataFrameOf[T] { + cp := d.clone() + cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { + cols := make([]column.Convertible, 0, len(columns)) + for _, name := range columns { + cols = append(cols, functions.Col(name)) + } + return df.Sort(ctx, cols...) + }) + return cp +} + +// First returns a pointer to the first row as T. Applies Limit(1) +// under the hood before materialising. Returns ErrNotFound when the +// (possibly filtered) DataFrame has zero rows. +func (d *DataFrameOf[T]) First(ctx context.Context) (*T, error) { + rows, err := d.Limit(1).Collect(ctx) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, ErrNotFound + } + return &rows[0], nil +} + +// Stream yields typed rows one at a time with constant memory, +// honouring any queued Where / Limit / OrderBy. Uses the untyped +// DataFrame.All streaming primitive underneath. Consumers range with +// Go 1.23's iter.Seq2: +// +// for row, err := range ds.Stream(ctx) { +// if err != nil { break } +// // use row +// } +func (d *DataFrameOf[T]) Stream(ctx context.Context) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + var zero T + df, err := d.resolveDataFrame(ctx) + if err != nil { + yield(zero, err) + return + } + var bindings []columnBinding + for row, rerr := range df.All(ctx) { + if rerr != nil { + yield(zero, rerr) + return + } + if bindings == nil { + b, berr := d.plan.bind(row.FieldNames()) + if berr != nil { + yield(zero, berr) + return + } + bindings = b + } + var out T + if derr := decodeRow(d.plan, row.Values(), bindings, &out); derr != nil { + yield(zero, derr) + return + } + if !yield(out, nil) { + return + } + } + } +} diff --git a/spark/sql/dataset_methods_test.go b/spark/sql/dataset_methods_test.go new file mode 100644 index 0000000..f1cfe55 --- /dev/null +++ b/spark/sql/dataset_methods_test.go @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "reflect" + "testing" +) + +// Where / Limit / OrderBy are lazy — they queue an op on the +// DataFrameOf without touching the underlying DataFrame. Verify the +// queue state, not the materialised output (the latter requires a +// Spark Connect endpoint and lives in the integration suite). + +func newTestDataset(t *testing.T) *DataFrameOf[typedUser] { + t.Helper() + plan, err := buildRowPlan(reflect.TypeOf(typedUser{})) + if err != nil { + t.Fatalf("buildRowPlan: %v", err) + } + return &DataFrameOf[typedUser]{plan: plan} +} + +func TestDataset_WhereQueuesOp(t *testing.T) { + ds := newTestDataset(t).Where("country = 'UK'") + if len(ds.ops) != 1 { + t.Fatalf("expected 1 op after Where, got %d", len(ds.ops)) + } +} + +func TestDataset_LimitQueuesOp(t *testing.T) { + ds := newTestDataset(t).Limit(10) + if len(ds.ops) != 1 { + t.Fatalf("expected 1 op after Limit, got %d", len(ds.ops)) + } +} + +func TestDataset_OrderByQueuesOp(t *testing.T) { + ds := newTestDataset(t).OrderBy("created_at", "id") + if len(ds.ops) != 1 { + t.Fatalf("expected 1 op after OrderBy, got %d", len(ds.ops)) + } +} + +func TestDataset_ChainableWhereLimitOrderBy(t *testing.T) { + ds := newTestDataset(t). + Where("country = 'UK'"). + OrderBy("created_at"). + Limit(10) + if len(ds.ops) != 3 { + t.Fatalf("expected 3 ops after chain, got %d", len(ds.ops)) + } +} + +func TestDataset_CloneIsolatesOps(t *testing.T) { + parent := newTestDataset(t).Where("a = 1") + child := parent.Where("b = 2") + // parent should still have exactly 1 op; child 2. + if len(parent.ops) != 1 { + t.Errorf("parent mutated: %d ops, want 1", len(parent.ops)) + } + if len(child.ops) != 2 { + t.Errorf("child ops: %d, want 2", len(child.ops)) + } +} + +func TestDataset_ResolveWithNoOpsReturnsUnderlying(t *testing.T) { + // A dataset with no queued ops should resolve to its own df field + // unchanged. Use a sentinel DataFrame that errors on any method; + // resolve must not touch it. + var called bool + fake := &sentinelDataFrame{mark: &called} + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + ds := &DataFrameOf[typedUser]{df: fake, plan: plan} + got, err := ds.resolveDataFrame(context.Background()) + if err != nil { + t.Fatalf("resolveDataFrame: %v", err) + } + if got != fake { + t.Errorf("expected underlying df returned unchanged") + } + if called { + t.Errorf("no-op resolve should not invoke any DataFrame method") + } +} + +// sentinelDataFrame is a fake DataFrame whose only job is to be +// returned by resolve-with-no-ops. Methods panic or flag a marker — +// used to prove resolveDataFrame isn't touching the underlying. 
+type sentinelDataFrame struct { + DataFrame + mark *bool +} From 55f88a6f54cf47adc37c76d9d236acc2ace8cb5e Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 20:24:28 +0100 Subject: [PATCH 29/37] feat: ErrClusterNotReady + IsClusterNotReady + NewClusterNotReady (#19) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cluster warm-up is a generic Spark Connect concern — any cluster on any platform can be in a Pending state during cold start. This was sitting in datalakego/dorm; move the detection + sentinel to the fork where every Spark Connect consumer benefits, not just dorm. Surface (package sparkerrors): var ErrClusterNotReady = errors.New("sparkerrors: cluster not ready") type ClusterNotReady struct { State string RequestID string Message string Cause error } func (e *ClusterNotReady) Error() string func (e *ClusterNotReady) Unwrap() error func (e *ClusterNotReady) Is(target error) bool // matches ErrClusterNotReady func (e *ClusterNotReady) IsRetryable() bool // always true func IsClusterNotReady(err error) bool // errors.As convenience func NewClusterNotReady(err error) *ClusterNotReady NewClusterNotReady inspects err for the canonical [FailedPrecondition] + "state Pending" pattern and returns a typed ClusterNotReady if it matches, else nil. The string-matching looks fragile but is the exact detection pattern that's held across multiple Databricks runtime versions and is the shape self-managed Spark clusters emit too. Retry-loop wiring (spark/client/retry.go — the fork's existing retry code) lands in a follow-up PR alongside OpenSession + SessionOption (Feature 1c). This PR just lands the types so the retry code has something typed to inspect. Tests cover: canonical Databricks pattern detection (State and RequestID extracted correctly; Cause preserved); rejection on nil / unrelated errors / "[FailedPrecondition]" without state Pending; IsClusterNotReady matches typed, %%w-wrapped typed, and %%w-wrapped sentinel; typed.Is(ErrClusterNotReady) holds for errors.Is callers; Unwrap exposes the original cause. --- spark/sparkerrors/cluster_not_ready.go | 126 ++++++++++++++++++++ spark/sparkerrors/cluster_not_ready_test.go | 110 +++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 spark/sparkerrors/cluster_not_ready.go create mode 100644 spark/sparkerrors/cluster_not_ready_test.go diff --git a/spark/sparkerrors/cluster_not_ready.go b/spark/sparkerrors/cluster_not_ready.go new file mode 100644 index 0000000..44193fe --- /dev/null +++ b/spark/sparkerrors/cluster_not_ready.go @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sparkerrors + +import ( + "errors" + "fmt" + "strings" +) + +// ErrClusterNotReady is returned when a Spark cluster is in a +// warm-up state and the RPC is safe to retry after backoff. The +// canonical case is Databricks clusters returning +// `[FailedPrecondition]` with `state Pending` during cold start, +// but the pattern also fires for self-managed clusters whose +// drivers are still resolving JARs. +// +// Wrap the typed ClusterNotReady with %w so errors.Is lookups keep +// working: +// +// if errors.Is(err, sparkerrors.ErrClusterNotReady) { ... } +// if sparkerrors.IsClusterNotReady(err) { ... } +var ErrClusterNotReady = errors.New("sparkerrors: cluster not ready") + +// ClusterNotReady is the typed form of ErrClusterNotReady. Carries +// enough context (state string, request ID, original cause) for an +// operator to correlate with server-side logs. +type ClusterNotReady struct { + State string // e.g. "Pending", "PENDING" + RequestID string + Message string + Cause error +} + +func (e *ClusterNotReady) Error() string { + return fmt.Sprintf("cluster not ready (state=%s): %s", e.State, e.Message) +} + +func (e *ClusterNotReady) Unwrap() error { return e.Cause } + +// Is reports whether target is ErrClusterNotReady. Lets callers +// branch via errors.Is regardless of whether they hold the sentinel +// or the typed form. +func (e *ClusterNotReady) Is(target error) bool { return target == ErrClusterNotReady } + +// IsRetryable marks the error as safe to retry after backoff. +// Callers that drive their own retry loops can branch on this +// without pulling in the package's retry machinery. +func (e *ClusterNotReady) IsRetryable() bool { return true } + +// IsClusterNotReady is the convenience wrapper around errors.As for +// the ClusterNotReady type. Returns true whether err is the typed +// form or wraps the sentinel; either way a caller's retry loop +// makes the right call. +func IsClusterNotReady(err error) bool { + if err == nil { + return false + } + var typed *ClusterNotReady + if errors.As(err, &typed) { + return true + } + return errors.Is(err, ErrClusterNotReady) +} + +// NewClusterNotReady inspects err for the canonical +// `[FailedPrecondition]` + `state Pending` pattern and returns a +// typed ClusterNotReady if it matches, else nil. The string-matching +// looks fragile but is the detection pattern that's held across +// multiple Databricks runtime versions and is the shape self-managed +// Spark clusters emit too. +// +// Callers typically invoke this in an RPC error path: +// +// resp, err := cli.ExecutePlan(ctx, req) +// if cn := sparkerrors.NewClusterNotReady(err); cn != nil { +// // retry with backoff +// } +func NewClusterNotReady(err error) *ClusterNotReady { + if err == nil { + return nil + } + errStr := err.Error() + + if !strings.Contains(errStr, "[FailedPrecondition]") || + (!strings.Contains(errStr, "state Pending") && + !strings.Contains(errStr, "state PENDING")) { + return nil + } + + state := "PENDING" + if idx := strings.Index(errStr, "[state="); idx != -1 { + endIdx := strings.Index(errStr[idx:], "]") + if endIdx != -1 { + state = errStr[idx+7 : idx+endIdx] + } + } + + var requestID string + if idx := strings.Index(errStr, "(requestId="); idx != -1 { + endIdx := strings.Index(errStr[idx:], ")") + if endIdx != -1 { + requestID = errStr[idx+11 : idx+endIdx] + } + } + + return &ClusterNotReady{ + State: state, + RequestID: requestID, + Message: "The cluster is starting up. 
Please retry your request in a few moments.", + Cause: err, + } +} diff --git a/spark/sparkerrors/cluster_not_ready_test.go b/spark/sparkerrors/cluster_not_ready_test.go new file mode 100644 index 0000000..5e6d18f --- /dev/null +++ b/spark/sparkerrors/cluster_not_ready_test.go @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sparkerrors + +import ( + "errors" + "fmt" + "testing" +) + +func TestNewClusterNotReady_DetectsCanonicalDatabricksPattern(t *testing.T) { + // Shape ported verbatim from production Databricks responses. + raw := errors.New( + "rpc error: code = FailedPrecondition desc = " + + "[FailedPrecondition] cluster with id=0313-xyz is in state Pending " + + "[state=PENDING] (requestId=abc-123)", + ) + got := NewClusterNotReady(raw) + if got == nil { + t.Fatalf("NewClusterNotReady returned nil on canonical Databricks error: %v", raw) + } + if got.State != "PENDING" { + t.Errorf("State = %q, want PENDING", got.State) + } + if got.RequestID != "abc-123" { + t.Errorf("RequestID = %q, want abc-123", got.RequestID) + } + if got.Cause != raw { + t.Errorf("Cause should be the original error unchanged") + } + if !got.IsRetryable() { + t.Errorf("IsRetryable should be true") + } +} + +func TestNewClusterNotReady_RejectsUnrelatedErrors(t *testing.T) { + cases := []error{ + nil, + errors.New("connection refused"), + errors.New("rpc error: code = Internal desc = server error"), + errors.New("[FailedPrecondition] lock held"), // no state Pending marker + } + for _, e := range cases { + if got := NewClusterNotReady(e); got != nil { + t.Errorf("NewClusterNotReady(%v) = %v, want nil", e, got) + } + } +} + +func TestIsClusterNotReady_MatchesTypedAndSentinel(t *testing.T) { + raw := errors.New( + "[FailedPrecondition] state Pending starting up", + ) + typed := NewClusterNotReady(raw) + if !IsClusterNotReady(typed) { + t.Errorf("IsClusterNotReady on typed ClusterNotReady should be true") + } + // Wrap the typed error — IsClusterNotReady should still match + // because errors.As unwinds wrapping. + wrapped := fmt.Errorf("retrying: %w", typed) + if !IsClusterNotReady(wrapped) { + t.Errorf("IsClusterNotReady on wrapped typed error should be true") + } + // A bare sentinel wrap also matches for callers that don't + // produce the typed struct. 
+ sentinelWrap := fmt.Errorf("context: %w", ErrClusterNotReady) + if !IsClusterNotReady(sentinelWrap) { + t.Errorf("IsClusterNotReady on %%w-wrapped sentinel should be true") + } +} + +func TestIsClusterNotReady_FalseOnNilAndUnrelated(t *testing.T) { + if IsClusterNotReady(nil) { + t.Errorf("IsClusterNotReady(nil) should be false") + } + if IsClusterNotReady(errors.New("some other error")) { + t.Errorf("IsClusterNotReady on unrelated error should be false") + } +} + +func TestClusterNotReady_IsMatchesSentinel(t *testing.T) { + // errors.Is should match the sentinel through the typed form's + // own Is method (tested independently of the convenience + // IsClusterNotReady helper). + typed := &ClusterNotReady{State: "PENDING"} + if !errors.Is(typed, ErrClusterNotReady) { + t.Errorf("errors.Is(typed, ErrClusterNotReady) should be true") + } +} + +func TestClusterNotReady_UnwrapExposesCause(t *testing.T) { + cause := errors.New("underlying grpc error") + typed := &ClusterNotReady{Cause: cause} + if got := errors.Unwrap(typed); got != cause { + t.Errorf("Unwrap returned %v, want %v", got, cause) + } +} From 8d64dd581725dba1db27641a8448aadc471d690f Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 20:27:46 +0100 Subject: [PATCH 30/37] feat: SqlAs[T] + TableAs[T] free functions; SqlTyped deprecated (#18) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: SqlAs[T] + TableAs[T] free functions; SqlTyped deprecated Rename SqlTyped -> SqlAs to match the Scala / Java precedent (and the bootstrap spec's nomenclature) and add TableAs as the table- addressed sibling. func SqlAs[T any](ctx, session, query) (*DataFrameOf[T], error) func TableAs[T any](ctx, session, name) (*DataFrameOf[T], error) Both are free functions rather than methods on SparkSession because Go doesn't allow type parameters on interface methods; the session is passed as an explicit second argument. Both wrap the untyped path (session.Sql / session.Table) and pass the result through TypedDataFrame[T] for plan-building, so the row plan is computed exactly once per call. SqlTyped is retained as a deprecated alias that simply delegates to SqlAs, with a //Deprecated comment. No external caller exists today — the fork just landed — but keeping the alias costs nothing and makes the rename zero-friction for any in-flight private usage. Tests: SqlAs forwards the query verbatim and returns a populated Dataset with a cached plan; SqlAs propagates errors from session.Sql unchanged; SqlAs rejects non-struct T with the established "must be a struct" message; TableAs mirrors these for session.Table; SqlTyped forwards identically to SqlAs. * chore: gofumpt sqlas_test.go stubSession alignment --------- --- spark/sql/dataframe_typed.go | 33 +++++++-- spark/sql/sqlas_test.go | 139 +++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 spark/sql/sqlas_test.go diff --git a/spark/sql/dataframe_typed.go b/spark/sql/dataframe_typed.go index 37fac9a..de7e889 100644 --- a/spark/sql/dataframe_typed.go +++ b/spark/sql/dataframe_typed.go @@ -90,11 +90,15 @@ func (d *DataFrameOf[T]) clone() *DataFrameOf[T] { return &cp } -// SqlTyped runs a SQL query and returns a typed DataFrame over the -// result. Equivalent to SparkSession.Sql followed by a struct-tag- -// driven scanner at every row — except the plan is computed once -// and reused for every Collect call on the returned value. 
-func SqlTyped[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { +// SqlAs runs a SQL query and returns a typed Dataset over the result. +// Equivalent to session.Sql followed by a struct-tag-driven scanner +// at every row — except the plan is computed once and reused for +// every Collect / Stream / First call on the returned Dataset. +// +// Free function rather than a method on SparkSession because Go +// doesn't permit type parameters on interface methods. The session +// is supplied as the second arg. +func SqlAs[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { df, err := session.Sql(ctx, query) if err != nil { return nil, err @@ -102,6 +106,25 @@ func SqlTyped[T any](ctx context.Context, session SparkSession, query string) (* return TypedDataFrame[T](df) } +// TableAs returns a typed Dataset over a named catalog table. Same +// shape as SqlAs but addressed by table name instead of ad-hoc SQL. +// Convenience over session.Table + TypedDataFrame[T]. +func TableAs[T any](ctx context.Context, session SparkSession, name string) (*DataFrameOf[T], error) { + df, err := session.Table(name) + if err != nil { + return nil, err + } + return TypedDataFrame[T](df) +} + +// SqlTyped is the legacy name for SqlAs, kept for source +// compatibility with pre-1.0 callers. Prefer SqlAs in new code. +// +// Deprecated: use SqlAs. +func SqlTyped[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { + return SqlAs[T](ctx, session, query) +} + // TypedDataFrame wraps an existing DataFrame in the typed surface. // Useful when the caller already holds a DataFrame produced by an // operation other than Sql (Read, Table, a chain of transformations). diff --git a/spark/sql/sqlas_test.go b/spark/sql/sqlas_test.go new file mode 100644 index 0000000..bd165d8 --- /dev/null +++ b/spark/sql/sqlas_test.go @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "errors" + "strings" + "testing" +) + +// stubSession returns pre-canned DataFrames / errors from Sql and +// Table so the SqlAs / TableAs call-paths can be exercised without a +// live Spark Connect endpoint. Only Sql and Table need to do anything; +// the other SparkSession methods never fire on these paths. 
+type stubSession struct { + SparkSession // embed for default no-op behaviour + sqlFn func(ctx context.Context, query string) (DataFrame, error) + tableFn func(name string) (DataFrame, error) +} + +func (s *stubSession) Sql(ctx context.Context, query string) (DataFrame, error) { + return s.sqlFn(ctx, query) +} +func (s *stubSession) Table(name string) (DataFrame, error) { return s.tableFn(name) } + +func TestSqlAs_ForwardsQueryAndWrapsInDataFrameOf(t *testing.T) { + var seen string + session := &stubSession{ + sqlFn: func(_ context.Context, query string) (DataFrame, error) { + seen = query + return nil, nil // no actual DataFrame; TypedDataFrame still builds the plan + }, + } + const q = "SELECT id, email FROM users" + ds, err := SqlAs[typedUser](context.Background(), session, q) + if err != nil { + t.Fatalf("SqlAs: %v", err) + } + if seen != q { + t.Errorf("session.Sql got %q, want %q", seen, q) + } + if ds == nil { + t.Fatalf("SqlAs returned nil Dataset on success") + } + if ds.plan == nil { + t.Error("SqlAs returned Dataset without a cached plan") + } +} + +func TestSqlAs_PropagatesSessionError(t *testing.T) { + wantErr := errors.New("connection refused") + session := &stubSession{ + sqlFn: func(_ context.Context, _ string) (DataFrame, error) { return nil, wantErr }, + } + _, err := SqlAs[typedUser](context.Background(), session, "SELECT 1") + if !errors.Is(err, wantErr) && err != wantErr { + t.Errorf("err = %v, want %v", err, wantErr) + } +} + +func TestSqlAs_RejectsNonStructT(t *testing.T) { + session := &stubSession{ + sqlFn: func(_ context.Context, _ string) (DataFrame, error) { return nil, nil }, + } + _, err := SqlAs[int](context.Background(), session, "SELECT 1") + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} + +func TestTableAs_ForwardsTableNameAndWrapsInDataFrameOf(t *testing.T) { + var seen string + session := &stubSession{ + tableFn: func(name string) (DataFrame, error) { + seen = name + return nil, nil + }, + } + const tbl = "users" + ds, err := TableAs[typedUser](context.Background(), session, tbl) + if err != nil { + t.Fatalf("TableAs: %v", err) + } + if seen != tbl { + t.Errorf("session.Table got %q, want %q", seen, tbl) + } + if ds == nil || ds.plan == nil { + t.Error("TableAs returned empty Dataset on success") + } +} + +func TestTableAs_PropagatesSessionError(t *testing.T) { + wantErr := errors.New("table not found") + session := &stubSession{ + tableFn: func(_ string) (DataFrame, error) { return nil, wantErr }, + } + _, err := TableAs[typedUser](context.Background(), session, "missing") + if !errors.Is(err, wantErr) && err != wantErr { + t.Errorf("err = %v, want %v", err, wantErr) + } +} + +func TestSqlTyped_DeprecatedAliasEqualsSqlAs(t *testing.T) { + // SqlTyped is retained as a deprecated alias that delegates to + // SqlAs. Verify it forwards identically — same query reaches the + // session, same plan is built. 
+ var seen string + session := &stubSession{ + sqlFn: func(_ context.Context, query string) (DataFrame, error) { + seen = query + return nil, nil + }, + } + const q = "SELECT id FROM users" + ds, err := SqlTyped[typedUser](context.Background(), session, q) + if err != nil { + t.Fatalf("SqlTyped: %v", err) + } + if seen != q { + t.Errorf("SqlTyped forwarded %q, want %q", seen, q) + } + if ds == nil { + t.Fatalf("SqlTyped returned nil Dataset") + } +} From 0c8bf054df0030eb9daafda799b9dc6e69b3ff90 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 22:21:38 +0100 Subject: [PATCH 31/37] feat: database/sql driver over Spark Connect (#20) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: database/sql driver over Spark Connect (spark/sql/driver) Every Go tool that speaks database/sql — goose, sqlc, pgx consumers, ad-hoc test harnesses — can now target a Spark-backed lakehouse without learning the native client API. Strictly additive to the fork; zero changes to existing DataFrame surface. Usage: import ( "database/sql" _ "github.com/datalakego/spark-connect-go/spark/sql/driver" ) db, err := sql.Open("spark", "sc://localhost:15002?format=iceberg") DSN grammar: sc://host:port plain sc://host:port?token= bearer-token auth sc://host:port?format=iceberg|delta format passthrough for dialect-aware consumers (goose-spark reads this) Driver surface (all stdlib database/sql/driver interfaces): driver.Driver Open(dsn) / OpenConnector (modern ctx path) driver.Connector Connect(ctx) / Driver() driver.Conn Prepare / PrepareContext / Close / Begin / BeginTx / ExecContext / QueryContext / Ping driver.Stmt Close / NumInput / Exec / Query / ExecContext / QueryContext driver.Rows Columns / Close / Next (iterates a pre-Collect'd DataFrame) driver.Result RowsAffected=-1 (unknown), LastInsertId errors (no auto-increment in lakehouse) driver.Tx Commit = no-op, Rollback = error (lakehouse commit semantics are per- statement inside the table format; a SQL- layer Tx doesn't buy atomicity here) v0 scope is what goose needs: CREATE / INSERT / SELECT against a Spark Connect endpoint. Specifically: - NumInput() = 0. Parameter binding is not supported; callers that pass args see errArgsUnsupported. Migrations author literal SQL so this limitation is not felt; v1+ wires parameter binding via the Spark Connect parameterised-query proto. - No prepared-statement caching; every Exec/Query re-sends the statement. Right-sized for goose's request rate; v1+ adds a server-side plan cache if latency-sensitive callers appear. - Rows wraps a materialised DataFrame.Collect slice. Suits small reads (the SELECTs goose fires against the version table). Larger scans should bypass database/sql entirely and use the native DataFrame / iter.Seq2 path. Tests cover: DSN parsing (plain, with token, with format, case normalisation, missing scheme, empty, unknown format); sql.Register side-effect lands under the name "spark"; OpenConnector rejects bad DSNs; Stmt rejects args in both Exec + Query paths; Rows advances through a fake result and returns io.EOF on exhaustion; Result returns -1 for RowsAffected and errors for LastInsertId; Tx commits silently and errors on Rollback. ok github.com/datalakego/spark-connect-go/spark/sql/driver 0.606s Strictly additive: zero changes to existing files under spark/sql/. Diff against upstream apache/spark-connect-go master is additions only. 
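A sketch of the v0 call path from a database/sql consumer's side; the
demo_versions table and its columns are illustrative, not something
the driver creates or assumes:

    import (
        "context"
        "database/sql"

        _ "github.com/datalakego/spark-connect-go/spark/sql/driver"
    )

    func listVersions(ctx context.Context) error {
        db, err := sql.Open("spark", "sc://localhost:15002?format=iceberg")
        if err != nil {
            return err
        }
        defer db.Close()

        // v0 is literal SQL only; passing bind args would surface
        // errArgsUnsupported before any server call fires.
        if _, err := db.ExecContext(ctx,
            "INSERT INTO demo_versions (version_id, is_applied) VALUES (3, true)"); err != nil {
            return err
        }

        rows, err := db.QueryContext(ctx,
            "SELECT version_id, is_applied FROM demo_versions")
        if err != nil {
            return err
        }
        defer rows.Close()

        for rows.Next() {
            var id int64
            var applied bool
            if err := rows.Scan(&id, &applied); err != nil {
                return err
            }
            // use id, applied
        }
        return rows.Err()
    }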
* chore: gofumpt alignment on rowFake methods in driver_test.go --------- --- spark/sql/driver/conn.go | 141 +++++++++++++++++ spark/sql/driver/driver.go | 126 +++++++++++++++ spark/sql/driver/driver_test.go | 263 ++++++++++++++++++++++++++++++++ spark/sql/driver/dsn.go | 76 +++++++++ spark/sql/driver/result.go | 41 +++++ spark/sql/driver/rows.go | 75 +++++++++ spark/sql/driver/stmt.go | 116 ++++++++++++++ 7 files changed, 838 insertions(+) create mode 100644 spark/sql/driver/conn.go create mode 100644 spark/sql/driver/driver.go create mode 100644 spark/sql/driver/driver_test.go create mode 100644 spark/sql/driver/dsn.go create mode 100644 spark/sql/driver/result.go create mode 100644 spark/sql/driver/rows.go create mode 100644 spark/sql/driver/stmt.go diff --git a/spark/sql/driver/conn.go b/spark/sql/driver/conn.go new file mode 100644 index 0000000..2f3312a --- /dev/null +++ b/spark/sql/driver/conn.go @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "context" + "database/sql/driver" + "errors" + + sparksql "github.com/datalakego/spark-connect-go/spark/sql" +) + +// conn is the per-logical-connection state database/sql keeps in its +// pool. Each Conn holds one Spark Connect session; Close stops it. +// +// Implements: +// - driver.Conn (legacy Prepare / Close / Begin) +// - driver.ConnPrepareContext +// - driver.ConnBeginTx +// - driver.ExecerContext +// - driver.QueryerContext +// - driver.Pinger (Ping via session round-trip) +type conn struct { + session sparksql.SparkSession + cfg *dsnConfig + closed bool +} + +// Prepare wraps the SQL in a driver.Stmt. Legacy (non-context) +// variant of PrepareContext. No server-side preparation — Spark +// Connect doesn't expose a prepare primitive. Each Exec/Query +// re-sends the statement; v1+ adds a plan cache if latency matters. +// +// Implements database/sql/driver.Conn. +func (c *conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +// PrepareContext is the context-aware Prepare. Implements +// database/sql/driver.ConnPrepareContext. +func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return &stmt{conn: c, query: query}, nil +} + +// Close releases the underlying Spark Connect session. Idempotent. +// Implements database/sql/driver.Conn. +func (c *conn) Close() error { + if c.closed { + return nil + } + c.closed = true + if c.session == nil { + return nil + } + return c.session.Stop() +} + +// Begin returns a no-op transaction. Implements +// database/sql/driver.Conn (legacy; ConnBeginTx below is preferred). 
+func (c *conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +// BeginTx returns a no-op transaction. Lakehouse commit semantics +// (Iceberg snapshot commits, Delta transaction log appends) happen +// at the per-statement level inside Spark, not at a SQL-layer Tx +// boundary. Wrapping statements in a database/sql Tx doesn't buy +// atomicity for lakehouse writes, so this driver's Tx is a no-op +// that lets tools (goose, sqlc) that default to transactional +// execution not break. +// +// Implements database/sql/driver.ConnBeginTx. +func (c *conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return &tx{}, nil +} + +// ExecContext runs a statement that returns no rows. Forwards to +// the Stmt path so the NumInput-based argument check lives in one +// place. +// +// Implements database/sql/driver.ExecerContext. +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return (&stmt{conn: c, query: query}).ExecContext(ctx, args) +} + +// QueryContext runs a SELECT returning rows. Implements +// database/sql/driver.QueryerContext. +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return (&stmt{conn: c, query: query}).QueryContext(ctx, args) +} + +// Ping checks the Spark Connect session is alive. Implements +// database/sql/driver.Pinger. +func (c *conn) Ping(ctx context.Context) error { + if c.closed { + return driver.ErrBadConn + } + _, err := c.session.Sql(ctx, "SELECT 1") + return err +} + +// tx is the no-op transaction described on BeginTx. +type tx struct{} + +// Commit succeeds silently. Implements database/sql/driver.Tx. +func (tx) Commit() error { return nil } + +// Rollback returns an error because lakehouse rollback isn't a +// meaningful concept at the database/sql layer. Callers that expect +// Rollback to un-apply statements see the error and can choose to +// handle it; tools that default to rollback-on-error (goose with +// failing migrations) get a signal that rollback didn't happen. +// Implements database/sql/driver.Tx. +func (tx) Rollback() error { + return errors.New("spark driver: transactional rollback is not supported; lakehouse commit semantics are per-statement") +} diff --git a/spark/sql/driver/driver.go b/spark/sql/driver/driver.go new file mode 100644 index 0000000..6c1f24a --- /dev/null +++ b/spark/sql/driver/driver.go @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package driver implements `database/sql`'s driver interfaces over +// Spark Connect, so every Go tool that speaks `database/sql` — goose, +// sqlc, pgx consumers, ad-hoc test harnesses, script code — can target +// a Spark-backed lakehouse without learning the native client API. +// +// The package registers itself on import as the driver name "spark": +// +// import ( +// "database/sql" +// _ "github.com/datalakego/spark-connect-go/spark/sql/driver" +// ) +// +// db, err := sql.Open("spark", "sc://localhost:15002?format=iceberg") +// +// DSN grammar is the plain Spark Connect URL with optional query +// parameters: +// +// sc://host:port +// sc://host:port?token= +// sc://host:port?format=iceberg +// sc://host:port?format=delta +// sc://host:port?token=&format=delta +// +// The `format` parameter is consumed by downstream tools (such as +// goose-spark) that need to emit format-specific DDL; the driver +// itself doesn't interpret it. +// +// v0 scope: sufficient for goose to create its version table, insert +// applied-migration rows, and select them back. Specifically: +// +// - Exec / Query against a Spark Connect endpoint via session.Sql. +// - No statement preparation (every call is `NumInput() == 0`; +// callers that pass arguments get an error so SQL injection +// isn't possible through this driver, and migrations author their +// own literal SQL anyway). +// - No-op transactions — Begin / Commit succeed silently, Rollback +// warns but does not error. Lakehouse commit semantics live in +// Iceberg / Delta at the per-statement level; wrapping statements +// in a `database/sql` Tx doesn't buy atomicity here. Documented +// in the package comment so a caller who expects real transaction +// semantics isn't surprised. +// +// v1+ adds: parameter binding via the Spark Connect parameterised- +// query proto, prepared-statement caching, catalog introspection for +// row metadata (column names + types) in the Rows wrapper. +package driver + +import ( + "context" + "database/sql" + "database/sql/driver" + + sparksql "github.com/datalakego/spark-connect-go/spark/sql" +) + +// init registers the driver under the name "spark". Consumers that +// only want the typed DataFrame surface don't have to import this +// package; consumers that use database/sql do. +func init() { + sql.Register("spark", &sparkDriver{}) +} + +// sparkDriver is the database/sql.Driver implementation. Satisfies +// both the legacy Driver interface (Open) and the modern +// DriverContext interface (OpenConnector). +type sparkDriver struct{} + +// Open opens a new connection using the given DSN. Implements +// database/sql/driver.Driver. +// +// database/sql calls Open for every new connection the pool wants. +// Keeping this lightweight — parse the DSN, return a Connector, +// defer actual session construction to Connect — matches how the +// standard library's pool expects drivers to behave. +func (d *sparkDriver) Open(dsn string) (driver.Conn, error) { + c, err := d.OpenConnector(dsn) + if err != nil { + return nil, err + } + return c.Connect(context.Background()) +} + +// OpenConnector returns a Connector ready to produce Conns. +// Implements database/sql/driver.DriverContext. +func (d *sparkDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := parseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{driver: d, cfg: cfg}, nil +} + +// connector wraps a parsed DSN and produces Conns on demand. 
+type connector struct { + driver *sparkDriver + cfg *dsnConfig +} + +// Connect opens a Spark Connect session and returns a driver.Conn +// wrapping it. Implements database/sql/driver.Connector. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + session, err := sparksql.NewSessionBuilder().Remote(c.cfg.sparkDSN).Build(ctx) + if err != nil { + return nil, err + } + return &conn{session: session, cfg: c.cfg}, nil +} + +// Driver returns the owning driver. Implements +// database/sql/driver.Connector. +func (c *connector) Driver() driver.Driver { return c.driver } diff --git a/spark/sql/driver/driver_test.go b/spark/sql/driver/driver_test.go new file mode 100644 index 0000000..1b850cf --- /dev/null +++ b/spark/sql/driver/driver_test.go @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "testing" + + "github.com/datalakego/spark-connect-go/spark/sql/types" +) + +// --- DSN parsing --------------------------------------------------- + +func TestParseDSN_AcceptsPlainSc(t *testing.T) { + cfg, err := parseDSN("sc://localhost:15002") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if cfg.sparkDSN != "sc://localhost:15002" { + t.Errorf("sparkDSN = %q", cfg.sparkDSN) + } + if cfg.format != "" { + t.Errorf("format = %q, want empty", cfg.format) + } +} + +func TestParseDSN_ParsesFormatParameter(t *testing.T) { + cfg, err := parseDSN("sc://localhost:15002?format=iceberg") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if cfg.format != "iceberg" { + t.Errorf("format = %q, want iceberg", cfg.format) + } +} + +func TestParseDSN_NormalisesFormatCase(t *testing.T) { + cfg, err := parseDSN("sc://localhost:15002?format=DELTA") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if cfg.format != "delta" { + t.Errorf("format = %q, want delta (lowercased)", cfg.format) + } +} + +func TestParseDSN_RejectsUnknownFormat(t *testing.T) { + _, err := parseDSN("sc://localhost:15002?format=duckdb") + if err == nil || !strings.Contains(err.Error(), "unsupported format") { + t.Errorf("err = %v, want unsupported-format error", err) + } +} + +func TestParseDSN_RejectsMissingSchemePrefix(t *testing.T) { + _, err := parseDSN("localhost:15002") + if err == nil || !strings.Contains(err.Error(), "sc://") { + t.Errorf("err = %v, want sc:// required", err) + } +} + +func TestParseDSN_RejectsEmpty(t *testing.T) { + _, err := parseDSN("") + if err == nil || !strings.Contains(err.Error(), "DSN is required") { + t.Errorf("err = %v, want DSN-required error", err) + } +} + +func TestParseDSN_PreservesTokenInSparkDSN(t *testing.T) { + // The token parameter is meaningful to the downstream session + // builder; the driver 
itself doesn't interpret it but must not + // strip it from the DSN forwarded to NewSessionBuilder().Remote. + cfg, err := parseDSN("sc://host:15002?token=secret&format=iceberg") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if !strings.Contains(cfg.sparkDSN, "token=secret") { + t.Errorf("sparkDSN should carry token unchanged; got %q", cfg.sparkDSN) + } +} + +// --- sql.Register side-effect --------------------------------------- + +func TestDriver_RegistersAsSparkOnImport(t *testing.T) { + // Importing this package (the test binary already does) triggers + // init() which calls sql.Register("spark", ...). Verify the + // driver list contains "spark". + found := false + for _, name := range sql.Drivers() { + if name == "spark" { + found = true + break + } + } + if !found { + t.Errorf("driver list = %v; expected to contain \"spark\"", sql.Drivers()) + } +} + +func TestDriver_OpenConnectorInvalidDSN(t *testing.T) { + _, err := (&sparkDriver{}).OpenConnector("bogus://nope") + if err == nil { + t.Error("OpenConnector should reject non-sc:// DSN") + } +} + +// --- stmt argument rejection --------------------------------------- +// +// v0 doesn't implement parameter binding. Callers that pass args +// should see the errArgsUnsupported sentinel before any server call +// fires. + +func TestStmt_ExecContextRejectsArgs(t *testing.T) { + s := &stmt{conn: &conn{}, query: "SELECT 1"} + _, err := s.ExecContext(context.Background(), []driver.NamedValue{ + {Ordinal: 1, Value: 42}, + }) + if err == nil || !errors.Is(err, errArgsUnsupported) { + t.Errorf("ExecContext with args err = %v, want errArgsUnsupported", err) + } +} + +func TestStmt_QueryContextRejectsArgs(t *testing.T) { + s := &stmt{conn: &conn{}, query: "SELECT 1"} + _, err := s.QueryContext(context.Background(), []driver.NamedValue{ + {Ordinal: 1, Value: "x"}, + }) + if err == nil || !errors.Is(err, errArgsUnsupported) { + t.Errorf("QueryContext with args err = %v, want errArgsUnsupported", err) + } +} + +// --- Rows wrapper -------------------------------------------------- + +// rowFake is a minimal types.Row satisfying every method the +// interface declares. Only FieldNames + Values + Len + At get +// exercised from rows.go; the rest are required by the interface +// shape. 
+type rowFake struct { + names []string + vals []any +} + +func (r rowFake) FieldNames() []string { return r.names } +func (r rowFake) Values() []any { return r.vals } +func (r rowFake) Len() int { return len(r.vals) } +func (r rowFake) At(i int) any { + if i < 0 || i >= len(r.vals) { + return nil + } + return r.vals[i] +} + +func (r rowFake) Value(name string) any { + for i, n := range r.names { + if n == name { + return r.vals[i] + } + } + return nil +} +func (r rowFake) ToJsonString() (string, error) { return "", nil } + +func TestRows_ColumnsReportsFieldNamesFromFirstRow(t *testing.T) { + fakes := []types.Row{ + rowFake{names: []string{"id", "name"}, vals: []any{int64(1), "alice"}}, + rowFake{names: []string{"id", "name"}, vals: []any{int64(2), "bob"}}, + } + r := newRows(fakes) + got := r.Columns() + if len(got) != 2 || got[0] != "id" || got[1] != "name" { + t.Errorf("Columns = %v, want [id name]", got) + } +} + +func TestRows_ColumnsEmptyOnEmptyResult(t *testing.T) { + r := newRows(nil) + if len(r.Columns()) != 0 { + t.Errorf("Columns = %v on empty result, want []", r.Columns()) + } +} + +func TestRows_NextWalksResultAndStopsOnEOF(t *testing.T) { + fakes := []types.Row{ + rowFake{names: []string{"id", "name"}, vals: []any{int64(1), "alice"}}, + rowFake{names: []string{"id", "name"}, vals: []any{int64(2), "bob"}}, + } + r := newRows(fakes) + dest := make([]driver.Value, 2) + + if err := r.Next(dest); err != nil { + t.Fatalf("Next row 1: %v", err) + } + if dest[0].(int64) != 1 || dest[1].(string) != "alice" { + t.Errorf("row 1 = %v, want [1 alice]", dest) + } + + if err := r.Next(dest); err != nil { + t.Fatalf("Next row 2: %v", err) + } + if dest[0].(int64) != 2 || dest[1].(string) != "bob" { + t.Errorf("row 2 = %v, want [2 bob]", dest) + } + + if err := r.Next(dest); err != io.EOF { + t.Errorf("Next past end = %v, want io.EOF", err) + } +} + +// --- no-op transaction -------------------------------------------- + +func TestTx_CommitIsNoop(t *testing.T) { + var transaction tx + if err := transaction.Commit(); err != nil { + t.Errorf("Commit err = %v, want nil", err) + } +} + +func TestTx_RollbackErrors(t *testing.T) { + var transaction tx + err := transaction.Rollback() + if err == nil || !strings.Contains(err.Error(), "rollback is not supported") { + t.Errorf("Rollback err = %v, want rollback-not-supported error", err) + } +} + +// --- result -------------------------------------------------------- + +func TestResult_RowsAffectedReturnsUnknown(t *testing.T) { + r := result{} + n, err := r.RowsAffected() + if err != nil { + t.Errorf("RowsAffected err = %v, want nil", err) + } + if n != -1 { + t.Errorf("RowsAffected = %d, want -1 (unknown)", n) + } +} + +func TestResult_LastInsertIdErrors(t *testing.T) { + r := result{} + _, err := r.LastInsertId() + if err == nil || !strings.Contains(err.Error(), "not supported") { + t.Errorf("LastInsertId err = %v, want not-supported error", err) + } +} diff --git a/spark/sql/driver/dsn.go b/spark/sql/driver/dsn.go new file mode 100644 index 0000000..d80850f --- /dev/null +++ b/spark/sql/driver/dsn.go @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +// dsnConfig is the parsed DSN. Kept minimal — the Spark Connect +// session builder wants the original URL verbatim, so the config +// mostly exists to surface query-parameter knobs (format, token) +// that downstream tools want to read without re-parsing the DSN. +type dsnConfig struct { + // sparkDSN is what gets passed to NewSessionBuilder().Remote(...). + // Includes any `?token=` fragment if present. + sparkDSN string + + // format is the value of the `format` query parameter ("iceberg" + // / "delta" / empty). Driver ignores it; consumers that need + // dialect-aware DDL read it via the exported Connector.Format() + // accessor. + format string +} + +// parseDSN accepts the sc:// Spark Connect URL form with optional +// query parameters. Returns an error for malformed input. +// +// Recognised query parameters: +// +// - token: bearer token forwarded to the Spark Connect server in +// the Authorization header. Preserved in sparkDSN verbatim so +// the session builder picks it up. +// - format: lakehouse table format ("iceberg" | "delta"). Driver- +// layer passthrough; used by consumers like goose-spark to pick +// a dialect-appropriate CREATE TABLE. +// +// Unrecognised parameters are ignored (preserved in the DSN as-is) +// so the driver doesn't fight with future Spark Connect flags. +func parseDSN(dsn string) (*dsnConfig, error) { + if dsn == "" { + return nil, errors.New("spark driver: DSN is required") + } + if !strings.HasPrefix(dsn, "sc://") { + return nil, fmt.Errorf("spark driver: DSN must start with sc://, got %q", dsn) + } + u, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf("spark driver: parse DSN: %w", err) + } + + cfg := &dsnConfig{ + sparkDSN: dsn, + format: strings.ToLower(u.Query().Get("format")), + } + if cfg.format != "" && cfg.format != "iceberg" && cfg.format != "delta" { + return nil, fmt.Errorf( + "spark driver: unsupported format %q; expected iceberg or delta", cfg.format) + } + return cfg, nil +} diff --git a/spark/sql/driver/result.go b/spark/sql/driver/result.go new file mode 100644 index 0000000..b24384e --- /dev/null +++ b/spark/sql/driver/result.go @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package driver + +import "errors" + +// result is the driver.Result returned by every Exec / ExecContext. +// Spark Connect's session.Sql doesn't surface a row-count the way +// Postgres' CommandComplete message does, so we return -1 per the +// driver.Result docs: "RowsAffected returns the number of rows +// affected by an update, insert, or delete. Not every database or +// database driver may support this." +type result struct{} + +// LastInsertId is not supported — lakehouse tables don't have +// auto-increment semantics. Returns an error so callers that reach +// for it see the intent clearly rather than getting a misleading +// zero. +func (result) LastInsertId() (int64, error) { + return 0, errors.New("spark driver: LastInsertId is not supported; lakehouse tables have no auto-increment sequence") +} + +// RowsAffected returns -1 to signal "unknown." Matches the +// convention several other drivers (notably snowflake, clickhouse) +// follow when the underlying engine doesn't emit the metric. +func (result) RowsAffected() (int64, error) { + return -1, nil +} diff --git a/spark/sql/driver/rows.go b/spark/sql/driver/rows.go new file mode 100644 index 0000000..fffdd9f --- /dev/null +++ b/spark/sql/driver/rows.go @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "database/sql/driver" + "io" + + "github.com/datalakego/spark-connect-go/spark/sql/types" +) + +// rows wraps a materialised slice of sparksql types.Row values so +// they satisfy database/sql/driver.Rows. The SELECT path in stmt.go +// collects a DataFrame first and hands the result here — fine for +// the small reads goose makes against the version table; larger +// scans should bypass the database/sql driver and use the native +// DataFrame / iter.Seq2 path instead. +type rows struct { + rows []types.Row + cols []string + index int +} + +// newRows builds a Rows handle from a slice of sparksql rows. +// Columns are read from the first row's FieldNames; if the slice +// is empty, the column list is empty and Next immediately reports +// io.EOF. +func newRows(collected []types.Row) *rows { + var cols []string + if len(collected) > 0 { + cols = collected[0].FieldNames() + } + return &rows{rows: collected, cols: cols} +} + +// Columns returns the column names. Implements +// database/sql/driver.Rows. +func (r *rows) Columns() []string { return r.cols } + +// Close releases any driver-held resources. The slice is already +// fully materialised; this is a no-op. +// Implements database/sql/driver.Rows. +func (r *rows) Close() error { return nil } + +// Next advances to the next row and writes its values into dest. +// Returns io.EOF when the slice is exhausted. Implements +// database/sql/driver.Rows. 
+func (r *rows) Next(dest []driver.Value) error { + if r.index >= len(r.rows) { + return io.EOF + } + values := r.rows[r.index].Values() + for i := range dest { + if i >= len(values) { + dest[i] = nil + continue + } + dest[i] = values[i] + } + r.index++ + return nil +} diff --git a/spark/sql/driver/stmt.go b/spark/sql/driver/stmt.go new file mode 100644 index 0000000..811a4d2 --- /dev/null +++ b/spark/sql/driver/stmt.go @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "context" + "database/sql/driver" + "errors" +) + +// stmt wraps one query string against a conn. The driver doesn't +// cache a prepared plan on the server; every Exec/Query re-sends +// the statement. Works for goose's load (dozens of statements per +// migration run) and keeps the implementation simple; v1+ adds a +// server-side plan cache if latency-sensitive callers appear. +// +// Implements: +// - driver.Stmt (legacy Close / NumInput / Exec / Query) +// - driver.StmtExecContext +// - driver.StmtQueryContext +type stmt struct { + conn *conn + query string +} + +// Close is a no-op; the statement holds no server-side state. +// Implements database/sql/driver.Stmt. +func (*stmt) Close() error { return nil } + +// NumInput reports the number of placeholders this statement +// expects. v0 doesn't support parameter binding, so we return 0 +// and reject non-empty arg slices in ExecContext / QueryContext. +// goose's generated SQL (CreateVersionTable / InsertVersion) and +// hand-authored migration files don't use placeholders, so this +// limitation isn't felt in practice. +// +// Implements database/sql/driver.Stmt. +func (*stmt) NumInput() int { return 0 } + +// Exec is the legacy (non-context) Exec entry. Implements +// database/sql/driver.Stmt. +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + named := make([]driver.NamedValue, len(args)) + for i, v := range args { + named[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return s.ExecContext(context.Background(), named) +} + +// Query is the legacy (non-context) Query entry. Implements +// database/sql/driver.Stmt. +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + named := make([]driver.NamedValue, len(args)) + for i, v := range args { + named[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return s.QueryContext(context.Background(), named) +} + +// ExecContext runs the statement and discards rows. Returns a +// Result with RowsAffected=-1 because Spark Connect doesn't surface +// that metric reliably at the session.Sql layer. database/sql +// allows -1 per the driver.Result docs. +// +// Implements database/sql/driver.StmtExecContext. 
+func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if len(args) > 0 { + return nil, errArgsUnsupported + } + _, err := s.conn.session.Sql(ctx, s.query) + if err != nil { + return nil, err + } + return result{}, nil +} + +// QueryContext runs the statement and returns rows. The underlying +// DataFrame is materialised via Collect and walked by the Rows +// wrapper. For small result sets (the version-table SELECTs goose +// fires) this is right-sized; large result sets should bypass the +// driver and use the native DataFrame / iter.Seq2 path directly. +// +// Implements database/sql/driver.StmtQueryContext. +func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if len(args) > 0 { + return nil, errArgsUnsupported + } + df, err := s.conn.session.Sql(ctx, s.query) + if err != nil { + return nil, err + } + rows, err := df.Collect(ctx) + if err != nil { + return nil, err + } + return newRows(rows), nil +} + +// errArgsUnsupported is the v0 signal that parameter binding isn't +// implemented. Sentinel so callers can errors.Is on it. +var errArgsUnsupported = errors.New( + "spark driver: parameter binding is not supported in v0; interpolate values into the SQL string or upgrade when v1 ships", +) From d36ed13f34e2f35494b169167336d837f3f9b4ca Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Sun, 19 Apr 2026 22:47:51 +0100 Subject: [PATCH 32/37] refactor: trim typed API to edge-wrapper shape only (#21) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The typed surface on the fork drifted past its charter. DataFrameOf[T] had Where / Limit / OrderBy transformation methods, and the package exposed SqlAs / TableAs / SqlTyped / TypedDataFrame / Into on top of As. The design contract (lakeorm/TYPING.md) is that typed wrappers are edge-only — transformations stay on the untyped DataFrame; the type parameter participates only at materialisation. Trim to exactly the promised surface: - Type: DataFrameOf[T any] - New: As[T](df) (*DataFrameOf[T], error) -- constructor - Keep: Collect[T](ctx, df) ([]T, error) -- free fn - Keep: Stream[T](ctx, df) iter.Seq2[T, error] -- free fn - Keep: First[T](ctx, df) (*T, error) -- free fn - Methods: (*DataFrameOf[T]).{DataFrame, Collect, Stream, First} Removed: .Where / .Limit / .OrderBy methods, Dataset[T] alias, SqlAs / TableAs / SqlTyped / TypedDataFrame free functions, Into non-generic scanner, and the internal ops/datasetOp/resolveDataFrame/ clone machinery that existed only to support the removed transformation methods. Callers that were using SqlAs / TableAs move to session.Sql / session.Table + As[T]. Callers that were using .Where / .Limit / .OrderBy move to untyped DataFrame.Where / Limit / Sort + re-typing via As[T] at the materialisation edge. 
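Roughly, the move looks like the sketch below. It assumes the untyped
DataFrame signatures shown in the removed dataset_methods.go (Where
and Sort return an error, Limit does not); the users table and struct
are illustrative:

    import (
        "context"

        sparksql "github.com/datalakego/spark-connect-go/spark/sql"
        "github.com/datalakego/spark-connect-go/spark/sql/functions"
    )

    type user struct {
        ID      string `spark:"id"`
        Country string `spark:"country"`
    }

    func recentUKUsers(ctx context.Context, session sparksql.SparkSession) ([]user, error) {
        df, err := session.Sql(ctx, "SELECT id, country, created_at FROM users")
        if err != nil {
            return nil, err
        }
        // Transformations stay on the untyped DataFrame ...
        if df, err = df.Where(ctx, "country = 'UK'"); err != nil {
            return nil, err
        }
        if df, err = df.Sort(ctx, functions.Col("created_at")); err != nil {
            return nil, err
        }
        df = df.Limit(ctx, 10)
        // ... and the type parameter joins only at the materialisation
        // edge, via As[T] + Collect.
        typed, err := sparksql.As[user](df)
        if err != nil {
            return nil, err
        }
        return typed.Collect(ctx)
    }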
--- spark/sql/dataframe_typed.go | 182 ++++++++++++++---------------- spark/sql/dataframe_typed_test.go | 10 +- spark/sql/dataset_methods.go | 126 --------------------- spark/sql/dataset_methods_test.go | 107 ------------------ spark/sql/sqlas_test.go | 139 ----------------------- spark/sql/typed_helpers.go | 154 +++---------------------- spark/sql/typed_helpers_test.go | 46 ++------ 7 files changed, 110 insertions(+), 654 deletions(-) delete mode 100644 spark/sql/dataset_methods.go delete mode 100644 spark/sql/dataset_methods_test.go delete mode 100644 spark/sql/sqlas_test.go diff --git a/spark/sql/dataframe_typed.go b/spark/sql/dataframe_typed.go index de7e889..cd41f0a 100644 --- a/spark/sql/dataframe_typed.go +++ b/spark/sql/dataframe_typed.go @@ -18,6 +18,7 @@ package sql import ( "context" "fmt" + "iter" "reflect" "strings" "sync" @@ -27,10 +28,16 @@ import ( "github.com/apache/arrow-go/v18/arrow" ) -// DataFrameOf[T] is a typed view on a regular DataFrame. Users -// parameterise it on a Go struct; Collect decodes rows directly -// into T instead of handing back []any that callers have to -// type-assert field by field. +// DataFrameOf[T] is an edge-typing wrapper around a regular DataFrame. +// Its role is narrow: cache the reflected row plan so a caller who +// already knows the result shape can materialise into []T / iter / +// *T without re-validating T on every call. +// +// The wrapper deliberately has no transformation methods (no Where, +// Limit, OrderBy, Select, Join, GroupBy). Transformations change row +// shape and make the type parameter lie. Callers compose +// transformations on the untyped DataFrame, then re-type at the edge +// via As[T] when ready to collect. // // Column binding uses struct tags in the same shape sqlx / parquet-go // already use: @@ -44,93 +51,27 @@ import ( // Fields without a `spark:"..."` tag are mapped by snake_case'd field // name, so a plain Go struct works without any tags at all. Fields // tagged `spark:"-"` are skipped. Columns in the DataFrame that -// don't match any field are ignored — typical of projections -// narrower than the struct. +// don't match any field are ignored — typical of projections narrower +// than the struct. // // Schema drift (a struct field that the result's projection doesn't -// contain) surfaces at the first Collect call as a single error -// rather than per-row panics. -// -// Streaming (iter.Seq2 over T) is intentionally not in v0 — the -// ExecutePlanClient plumbing wants a dedicated PR so it can ship -// alongside a proper test matrix. For now, users who need streaming -// call DataFrame() to drop to the untyped ToRecordSequence path. +// contain) surfaces at the first Collect / Stream / First call as a +// single error rather than per-row panics. type DataFrameOf[T any] struct { df DataFrame plan *rowPlan - ops []datasetOp // lazy Where/Limit/OrderBy transforms; applied on materialise -} - -// datasetOp is a pending transform on the underlying DataFrame. Ops -// are queued by the chainable Where/Limit/OrderBy methods and applied -// at Collect/Stream/First time so the fluent builder stays ctx-free. -type datasetOp func(ctx context.Context, df DataFrame) (DataFrame, error) - -// resolveDataFrame applies every queued op in declaration order and -// returns the final DataFrame. Used by Collect/Stream/First to get a -// materialisable handle. 
-func (d *DataFrameOf[T]) resolveDataFrame(ctx context.Context) (DataFrame, error) { - df := d.df - for _, op := range d.ops { - next, err := op(ctx, df) - if err != nil { - return nil, err - } - df = next - } - return df, nil -} - -// clone returns a shallow copy of d with a freshly allocated ops -// slice so that chained operations don't share state with the parent. -// Chainable methods return the clone. -func (d *DataFrameOf[T]) clone() *DataFrameOf[T] { - cp := *d - cp.ops = append([]datasetOp(nil), d.ops...) - return &cp -} - -// SqlAs runs a SQL query and returns a typed Dataset over the result. -// Equivalent to session.Sql followed by a struct-tag-driven scanner -// at every row — except the plan is computed once and reused for -// every Collect / Stream / First call on the returned Dataset. -// -// Free function rather than a method on SparkSession because Go -// doesn't permit type parameters on interface methods. The session -// is supplied as the second arg. -func SqlAs[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { - df, err := session.Sql(ctx, query) - if err != nil { - return nil, err - } - return TypedDataFrame[T](df) } -// TableAs returns a typed Dataset over a named catalog table. Same -// shape as SqlAs but addressed by table name instead of ad-hoc SQL. -// Convenience over session.Table + TypedDataFrame[T]. -func TableAs[T any](ctx context.Context, session SparkSession, name string) (*DataFrameOf[T], error) { - df, err := session.Table(name) - if err != nil { - return nil, err - } - return TypedDataFrame[T](df) -} - -// SqlTyped is the legacy name for SqlAs, kept for source -// compatibility with pre-1.0 callers. Prefer SqlAs in new code. +// As wraps df in the typed surface. T must be a struct; the row plan +// is computed and cached once so subsequent Collect / Stream / First +// calls don't re-reflect. Schema compatibility with T is validated +// lazily — the first materialisation surfaces drift, not As itself. // -// Deprecated: use SqlAs. -func SqlTyped[T any](ctx context.Context, session SparkSession, query string) (*DataFrameOf[T], error) { - return SqlAs[T](ctx, session, query) -} - -// TypedDataFrame wraps an existing DataFrame in the typed surface. -// Useful when the caller already holds a DataFrame produced by an -// operation other than Sql (Read, Table, a chain of transformations). -// Computes and caches the row plan immediately; a malformed struct -// surfaces here rather than per-row inside Collect. -func TypedDataFrame[T any](df DataFrame) (*DataFrameOf[T], error) { +// This is the only supported constructor for DataFrameOf[T]. Callers +// hold a DataFrame (from session.Sql, session.Table, a chain of +// untyped transformations) and hand it here at the point the result +// shape is known. +func As[T any](df DataFrame) (*DataFrameOf[T], error) { var zero T rt := reflect.TypeOf(zero) if rt == nil || rt.Kind() != reflect.Struct { @@ -144,29 +85,25 @@ func TypedDataFrame[T any](df DataFrame) (*DataFrameOf[T], error) { } // DataFrame returns the underlying untyped DataFrame. Escape hatch -// for operations the typed surface doesn't cover — GroupBy, joins, -// window functions. Chain freely and call TypedDataFrame again on -// the result when the output shape is known. +// for transformations the typed surface deliberately doesn't carry +// — GroupBy, joins, window functions, Where, Limit, OrderBy. Chain +// on the untyped handle, then call As[T] again when the output +// shape is known. 
func (d *DataFrameOf[T]) DataFrame() DataFrame { return d.df } // Collect materialises every row into a []T. Holds the whole table // on the heap for the duration of the call — callers with large -// result sets should project narrower on the SQL side or drop to -// the untyped streaming path via DataFrame(). +// result sets should project narrower on the SQL side or use Stream +// for constant-memory iteration. func (d *DataFrameOf[T]) Collect(ctx context.Context) ([]T, error) { - df, err := d.resolveDataFrame(ctx) - if err != nil { - return nil, err - } - rows, err := df.Collect(ctx) + rows, err := d.df.Collect(ctx) if err != nil { return nil, err } if len(rows) == 0 { return nil, nil } - cols := rows[0].FieldNames() - bindings, err := d.plan.bind(cols) + bindings, err := d.plan.bind(rows[0].FieldNames()) if err != nil { return nil, err } @@ -179,6 +116,58 @@ func (d *DataFrameOf[T]) Collect(ctx context.Context) ([]T, error) { return out, nil } +// Stream yields typed rows one at a time with constant memory. Uses +// the untyped DataFrame.All streaming primitive underneath. Consumers +// range with Go 1.23's iter.Seq2: +// +// for row, err := range ds.Stream(ctx) { +// if err != nil { break } +// // use row +// } +func (d *DataFrameOf[T]) Stream(ctx context.Context) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + var zero T + var bindings []columnBinding + for row, rerr := range d.df.All(ctx) { + if rerr != nil { + yield(zero, rerr) + return + } + if bindings == nil { + b, berr := d.plan.bind(row.FieldNames()) + if berr != nil { + yield(zero, berr) + return + } + bindings = b + } + var out T + if derr := decodeRow(d.plan, row.Values(), bindings, &out); derr != nil { + yield(zero, derr) + return + } + if !yield(out, nil) { + return + } + } + } +} + +// First returns a pointer to the first row as T, or ErrNotFound when +// the DataFrame produces zero rows. Runs Collect underneath at v0; +// a LIMIT 1 pushdown lands alongside the Spark Connect plan-cache +// work in a future cycle. +func (d *DataFrameOf[T]) First(ctx context.Context) (*T, error) { + rows, err := d.Collect(ctx) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, ErrNotFound + } + return &rows[0], nil +} + // rowPlan caches the reflected structure of T so Collect doesn't // reflect on every row. Built once per DataFrameOf[T]. type rowPlan struct { @@ -317,9 +306,6 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value { // columns, assignable/convertible for primitives. Anything rarer // surfaces as an explicit error so callers can tighten their struct // to match. -// -// Named with the `Typed` suffix to avoid colliding with the existing -// assignValue helper in other files. func assignTypedValue(dst reflect.Value, src any) error { if src == nil { dst.Set(reflect.Zero(dst.Type())) diff --git a/spark/sql/dataframe_typed_test.go b/spark/sql/dataframe_typed_test.go index 1224573..0261fa0 100644 --- a/spark/sql/dataframe_typed_test.go +++ b/spark/sql/dataframe_typed_test.go @@ -127,12 +127,12 @@ func TestDecodeRow_NilIsZero(t *testing.T) { assert.Zero(t, out.S) } -func TestTypedDataFrame_RejectsNonStruct(t *testing.T) { - // TypedDataFrame is supposed to surface the misuse at construction - // time, not at Collect. A map / slice / primitive should fail - // clearly with a pointer back at the caller's T. +func TestAs_RejectsNonStruct(t *testing.T) { + // As[T] surfaces the misuse at construction time, not at Collect. 
+ // A map / slice / primitive should fail clearly with a pointer + // back at the caller's T. type notAStruct = map[string]string - _, err := TypedDataFrame[notAStruct](nil) + _, err := As[notAStruct](nil) require.Error(t, err) assert.Contains(t, err.Error(), "T must be a struct") } diff --git a/spark/sql/dataset_methods.go b/spark/sql/dataset_methods.go deleted file mode 100644 index c6613a9..0000000 --- a/spark/sql/dataset_methods.go +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "context" - "iter" - - "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/functions" -) - -// Where adds a filter predicate. Lazy: the condition is applied to -// the underlying DataFrame when Collect / Stream / First materialises -// the Dataset. Chainable — each call narrows the projection further. -// -// sqlStr is a Spark SQL fragment (e.g. "country = 'UK'" or -// "id IN ('a', 'b', 'c')"). args is accepted for API compatibility -// with dorm.Query's signature but currently ignored — the underlying -// DataFrame.Where takes a bare string; callers interpolate values -// with fmt.Sprintf or build the fragment via the functions package. -func (d *DataFrameOf[T]) Where(sqlStr string, args ...any) *DataFrameOf[T] { - _ = args - cp := d.clone() - cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { - return df.Where(ctx, sqlStr) - }) - return cp -} - -// Limit caps the number of rows materialised. Chainable; repeated -// calls each produce their own Limit relation in the underlying plan, -// and Spark's optimiser collapses them to the minimum. -func (d *DataFrameOf[T]) Limit(n int) *DataFrameOf[T] { - cp := d.clone() - cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { - return df.Limit(ctx, int32(n)), nil - }) - return cp -} - -// OrderBy adds an ascending sort by one or more columns. Callers who -// need descending order, null-ordering modifiers, or expression-based -// sort keys drop to DataFrame() and invoke Sort directly with -// column.Convertible values. -func (d *DataFrameOf[T]) OrderBy(columns ...string) *DataFrameOf[T] { - cp := d.clone() - cp.ops = append(cp.ops, func(ctx context.Context, df DataFrame) (DataFrame, error) { - cols := make([]column.Convertible, 0, len(columns)) - for _, name := range columns { - cols = append(cols, functions.Col(name)) - } - return df.Sort(ctx, cols...) - }) - return cp -} - -// First returns a pointer to the first row as T. Applies Limit(1) -// under the hood before materialising. Returns ErrNotFound when the -// (possibly filtered) DataFrame has zero rows. 
-func (d *DataFrameOf[T]) First(ctx context.Context) (*T, error) { - rows, err := d.Limit(1).Collect(ctx) - if err != nil { - return nil, err - } - if len(rows) == 0 { - return nil, ErrNotFound - } - return &rows[0], nil -} - -// Stream yields typed rows one at a time with constant memory, -// honouring any queued Where / Limit / OrderBy. Uses the untyped -// DataFrame.All streaming primitive underneath. Consumers range with -// Go 1.23's iter.Seq2: -// -// for row, err := range ds.Stream(ctx) { -// if err != nil { break } -// // use row -// } -func (d *DataFrameOf[T]) Stream(ctx context.Context) iter.Seq2[T, error] { - return func(yield func(T, error) bool) { - var zero T - df, err := d.resolveDataFrame(ctx) - if err != nil { - yield(zero, err) - return - } - var bindings []columnBinding - for row, rerr := range df.All(ctx) { - if rerr != nil { - yield(zero, rerr) - return - } - if bindings == nil { - b, berr := d.plan.bind(row.FieldNames()) - if berr != nil { - yield(zero, berr) - return - } - bindings = b - } - var out T - if derr := decodeRow(d.plan, row.Values(), bindings, &out); derr != nil { - yield(zero, derr) - return - } - if !yield(out, nil) { - return - } - } - } -} diff --git a/spark/sql/dataset_methods_test.go b/spark/sql/dataset_methods_test.go deleted file mode 100644 index f1cfe55..0000000 --- a/spark/sql/dataset_methods_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "context" - "reflect" - "testing" -) - -// Where / Limit / OrderBy are lazy — they queue an op on the -// DataFrameOf without touching the underlying DataFrame. Verify the -// queue state, not the materialised output (the latter requires a -// Spark Connect endpoint and lives in the integration suite). - -func newTestDataset(t *testing.T) *DataFrameOf[typedUser] { - t.Helper() - plan, err := buildRowPlan(reflect.TypeOf(typedUser{})) - if err != nil { - t.Fatalf("buildRowPlan: %v", err) - } - return &DataFrameOf[typedUser]{plan: plan} -} - -func TestDataset_WhereQueuesOp(t *testing.T) { - ds := newTestDataset(t).Where("country = 'UK'") - if len(ds.ops) != 1 { - t.Fatalf("expected 1 op after Where, got %d", len(ds.ops)) - } -} - -func TestDataset_LimitQueuesOp(t *testing.T) { - ds := newTestDataset(t).Limit(10) - if len(ds.ops) != 1 { - t.Fatalf("expected 1 op after Limit, got %d", len(ds.ops)) - } -} - -func TestDataset_OrderByQueuesOp(t *testing.T) { - ds := newTestDataset(t).OrderBy("created_at", "id") - if len(ds.ops) != 1 { - t.Fatalf("expected 1 op after OrderBy, got %d", len(ds.ops)) - } -} - -func TestDataset_ChainableWhereLimitOrderBy(t *testing.T) { - ds := newTestDataset(t). - Where("country = 'UK'"). - OrderBy("created_at"). 
- Limit(10) - if len(ds.ops) != 3 { - t.Fatalf("expected 3 ops after chain, got %d", len(ds.ops)) - } -} - -func TestDataset_CloneIsolatesOps(t *testing.T) { - parent := newTestDataset(t).Where("a = 1") - child := parent.Where("b = 2") - // parent should still have exactly 1 op; child 2. - if len(parent.ops) != 1 { - t.Errorf("parent mutated: %d ops, want 1", len(parent.ops)) - } - if len(child.ops) != 2 { - t.Errorf("child ops: %d, want 2", len(child.ops)) - } -} - -func TestDataset_ResolveWithNoOpsReturnsUnderlying(t *testing.T) { - // A dataset with no queued ops should resolve to its own df field - // unchanged. Use a sentinel DataFrame that errors on any method; - // resolve must not touch it. - var called bool - fake := &sentinelDataFrame{mark: &called} - plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) - ds := &DataFrameOf[typedUser]{df: fake, plan: plan} - got, err := ds.resolveDataFrame(context.Background()) - if err != nil { - t.Fatalf("resolveDataFrame: %v", err) - } - if got != fake { - t.Errorf("expected underlying df returned unchanged") - } - if called { - t.Errorf("no-op resolve should not invoke any DataFrame method") - } -} - -// sentinelDataFrame is a fake DataFrame whose only job is to be -// returned by resolve-with-no-ops. Methods panic or flag a marker — -// used to prove resolveDataFrame isn't touching the underlying. -type sentinelDataFrame struct { - DataFrame - mark *bool -} diff --git a/spark/sql/sqlas_test.go b/spark/sql/sqlas_test.go deleted file mode 100644 index bd165d8..0000000 --- a/spark/sql/sqlas_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "context" - "errors" - "strings" - "testing" -) - -// stubSession returns pre-canned DataFrames / errors from Sql and -// Table so the SqlAs / TableAs call-paths can be exercised without a -// live Spark Connect endpoint. Only Sql and Table need to do anything; -// the other SparkSession methods never fire on these paths. 
-type stubSession struct { - SparkSession // embed for default no-op behaviour - sqlFn func(ctx context.Context, query string) (DataFrame, error) - tableFn func(name string) (DataFrame, error) -} - -func (s *stubSession) Sql(ctx context.Context, query string) (DataFrame, error) { - return s.sqlFn(ctx, query) -} -func (s *stubSession) Table(name string) (DataFrame, error) { return s.tableFn(name) } - -func TestSqlAs_ForwardsQueryAndWrapsInDataFrameOf(t *testing.T) { - var seen string - session := &stubSession{ - sqlFn: func(_ context.Context, query string) (DataFrame, error) { - seen = query - return nil, nil // no actual DataFrame; TypedDataFrame still builds the plan - }, - } - const q = "SELECT id, email FROM users" - ds, err := SqlAs[typedUser](context.Background(), session, q) - if err != nil { - t.Fatalf("SqlAs: %v", err) - } - if seen != q { - t.Errorf("session.Sql got %q, want %q", seen, q) - } - if ds == nil { - t.Fatalf("SqlAs returned nil Dataset on success") - } - if ds.plan == nil { - t.Error("SqlAs returned Dataset without a cached plan") - } -} - -func TestSqlAs_PropagatesSessionError(t *testing.T) { - wantErr := errors.New("connection refused") - session := &stubSession{ - sqlFn: func(_ context.Context, _ string) (DataFrame, error) { return nil, wantErr }, - } - _, err := SqlAs[typedUser](context.Background(), session, "SELECT 1") - if !errors.Is(err, wantErr) && err != wantErr { - t.Errorf("err = %v, want %v", err, wantErr) - } -} - -func TestSqlAs_RejectsNonStructT(t *testing.T) { - session := &stubSession{ - sqlFn: func(_ context.Context, _ string) (DataFrame, error) { return nil, nil }, - } - _, err := SqlAs[int](context.Background(), session, "SELECT 1") - if err == nil || !strings.Contains(err.Error(), "must be a struct") { - t.Errorf("want struct-required error, got %v", err) - } -} - -func TestTableAs_ForwardsTableNameAndWrapsInDataFrameOf(t *testing.T) { - var seen string - session := &stubSession{ - tableFn: func(name string) (DataFrame, error) { - seen = name - return nil, nil - }, - } - const tbl = "users" - ds, err := TableAs[typedUser](context.Background(), session, tbl) - if err != nil { - t.Fatalf("TableAs: %v", err) - } - if seen != tbl { - t.Errorf("session.Table got %q, want %q", seen, tbl) - } - if ds == nil || ds.plan == nil { - t.Error("TableAs returned empty Dataset on success") - } -} - -func TestTableAs_PropagatesSessionError(t *testing.T) { - wantErr := errors.New("table not found") - session := &stubSession{ - tableFn: func(_ string) (DataFrame, error) { return nil, wantErr }, - } - _, err := TableAs[typedUser](context.Background(), session, "missing") - if !errors.Is(err, wantErr) && err != wantErr { - t.Errorf("err = %v, want %v", err, wantErr) - } -} - -func TestSqlTyped_DeprecatedAliasEqualsSqlAs(t *testing.T) { - // SqlTyped is retained as a deprecated alias that delegates to - // SqlAs. Verify it forwards identically — same query reaches the - // session, same plan is built. 
- var seen string - session := &stubSession{ - sqlFn: func(_ context.Context, query string) (DataFrame, error) { - seen = query - return nil, nil - }, - } - const q = "SELECT id FROM users" - ds, err := SqlTyped[typedUser](context.Background(), session, q) - if err != nil { - t.Fatalf("SqlTyped: %v", err) - } - if seen != q { - t.Errorf("SqlTyped forwarded %q, want %q", seen, q) - } - if ds == nil { - t.Fatalf("SqlTyped returned nil Dataset") - } -} diff --git a/spark/sql/typed_helpers.go b/spark/sql/typed_helpers.go index 297b94a..31d33b1 100644 --- a/spark/sql/typed_helpers.go +++ b/spark/sql/typed_helpers.go @@ -18,26 +18,18 @@ package sql import ( "context" "errors" - "fmt" "iter" - "reflect" ) -// Dataset is a type alias for DataFrameOf[T]. Matches the Scala/Java -// Dataset[T] naming; DataFrameOf[T] remains as the original name and -// is fully interchangeable. -type Dataset[T any] = DataFrameOf[T] - // ErrNotFound is returned by First when the DataFrame produces zero // rows. var ErrNotFound = errors.New("spark: no rows returned") -// Collect materialises every row of df into a []T by wrapping df in -// the typed surface and calling Collect. Equivalent to -// TypedDataFrame[T](df).Collect(ctx) but written as a one-liner for -// callers who already hold a DataFrame. +// Collect materialises every row of df into a []T. Thin wrapper over +// As[T] + (*DataFrameOf[T]).Collect for callers who hold an untyped +// DataFrame and want the result in one call. func Collect[T any](ctx context.Context, df DataFrame) ([]T, error) { - typed, err := TypedDataFrame[T](df) + typed, err := As[T](df) if err != nil { return nil, err } @@ -46,9 +38,9 @@ func Collect[T any](ctx context.Context, df DataFrame) ([]T, error) { // Stream yields typed rows one at a time using the untyped // DataFrame's streaming primitive underneath. Constant memory -// regardless of result size. Schema binding happens on the first row; -// a subsequent row whose schema diverges from the first surfaces the -// error through the iterator. +// regardless of result size. Schema binding happens on the first +// row; a subsequent row whose schema diverges from the first +// surfaces the error through the iterator. // // Consumers range over the return value with Go 1.23's iter.Seq2: // @@ -59,31 +51,13 @@ func Collect[T any](ctx context.Context, df DataFrame) ([]T, error) { func Stream[T any](ctx context.Context, df DataFrame) iter.Seq2[T, error] { return func(yield func(T, error) bool) { var zero T - typed, err := TypedDataFrame[T](df) + typed, err := As[T](df) if err != nil { yield(zero, err) return } - var bindings []columnBinding - for row, rerr := range df.All(ctx) { - if rerr != nil { - yield(zero, rerr) - return - } - if bindings == nil { - b, berr := typed.plan.bind(row.FieldNames()) - if berr != nil { - yield(zero, berr) - return - } - bindings = b - } - var out T - if derr := decodeRow(typed.plan, row.Values(), bindings, &out); derr != nil { - yield(zero, derr) - return - } - if !yield(out, nil) { + for row, rerr := range typed.Stream(ctx) { + if !yield(row, rerr) { return } } @@ -91,113 +65,11 @@ func Stream[T any](ctx context.Context, df DataFrame) iter.Seq2[T, error] { } // First returns the first row of df decoded as T, or ErrNotFound if -// df produced no rows. Runs Collect underneath at v0; the DataFrame -// LIMIT optimisation lands when Dataset[T].Limit stabilises. +// df produced no rows. 
func First[T any](ctx context.Context, df DataFrame) (*T, error) { - rows, err := Collect[T](ctx, df) + typed, err := As[T](df) if err != nil { return nil, err } - if len(rows) == 0 { - return nil, ErrNotFound - } - return &rows[0], nil -} - -// As wraps df in the typed surface. Alias for TypedDataFrame[T], -// named to match CLAUDE_CODE_BOOTSTRAP.md and Scala's Encoder-flavoured -// naming. Schema compatibility with T is validated lazily — the first -// call to Collect or Stream surfaces drift, not As itself. -func As[T any](df DataFrame) (*DataFrameOf[T], error) { - return TypedDataFrame[T](df) -} - -// Into scans df into dst where dst is a pointer to either a slice -// (populated with every row) or a single struct (the first row; -// ErrNotFound on empty). Non-generic variant for cases where T is -// not known at compile time — typical of code-generated consumers -// or reflection-heavy DSLs. -func Into(ctx context.Context, df DataFrame, dst any) error { - rv := reflect.ValueOf(dst) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("spark.Into: dst must be a non-nil pointer, got %T", dst) - } - elem := rv.Elem() - switch elem.Kind() { - case reflect.Slice: - return intoSlice(ctx, df, elem) - case reflect.Struct: - return intoStruct(ctx, df, elem) - default: - return fmt.Errorf("spark.Into: dst must point to a slice or struct, got %v", elem.Kind()) - } -} - -func intoSlice(ctx context.Context, df DataFrame, sliceValue reflect.Value) error { - elemType := sliceValue.Type().Elem() - if elemType.Kind() != reflect.Struct { - return fmt.Errorf("spark.Into: slice element type must be a struct, got %v", elemType) - } - plan, err := buildRowPlan(elemType) - if err != nil { - return err - } - rows, err := df.Collect(ctx) - if err != nil { - return err - } - if len(rows) == 0 { - sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0)) - return nil - } - bindings, err := plan.bind(rows[0].FieldNames()) - if err != nil { - return err - } - out := reflect.MakeSlice(sliceValue.Type(), len(rows), len(rows)) - for i, r := range rows { - if err := decodeRowReflect(plan, r.Values(), bindings, out.Index(i)); err != nil { - return fmt.Errorf("spark.Into: row %d: %w", i, err) - } - } - sliceValue.Set(out) - return nil -} - -func intoStruct(ctx context.Context, df DataFrame, structValue reflect.Value) error { - plan, err := buildRowPlan(structValue.Type()) - if err != nil { - return err - } - rows, err := df.Collect(ctx) - if err != nil { - return err - } - if len(rows) == 0 { - return ErrNotFound - } - bindings, err := plan.bind(rows[0].FieldNames()) - if err != nil { - return err - } - return decodeRowReflect(plan, rows[0].Values(), bindings, structValue) -} - -// decodeRowReflect is the non-generic variant of decodeRow used by -// Into. Mirrors decodeRow's logic but walks the destination via -// reflect.Value rather than *T, so the caller can populate one slot -// of an already-allocated []T without instantiating T. 
-func decodeRowReflect(plan *rowPlan, values []any, bindings []columnBinding, dest reflect.Value) error { - for ci := 0; ci < len(values) && ci < len(bindings); ci++ { - b := bindings[ci] - if b.planIndex < 0 { - continue - } - pf := &plan.fields[b.planIndex] - target := fieldByIndex(dest, pf.index) - if err := assignTypedValue(target, values[ci]); err != nil { - return fmt.Errorf("column %d (%s): %w", ci, pf.name, err) - } - } - return nil + return typed.First(ctx) } diff --git a/spark/sql/typed_helpers_test.go b/spark/sql/typed_helpers_test.go index fab173a..d10b823 100644 --- a/spark/sql/typed_helpers_test.go +++ b/spark/sql/typed_helpers_test.go @@ -22,11 +22,6 @@ import ( "testing" ) -// Dataset[T] is the alias; verify at compile time that it's -// interchangeable with DataFrameOf[T]. A failing assertion here is -// a type-level regression caught without running any code. -var _ *DataFrameOf[typedUser] = (*Dataset[typedUser])(nil) - func TestErrNotFound_Sentinel(t *testing.T) { // First's empty-result path wraps ErrNotFound; verify callers can // branch on it via errors.Is. @@ -36,40 +31,8 @@ func TestErrNotFound_Sentinel(t *testing.T) { } } -func TestInto_RejectsNonPointer(t *testing.T) { - err := Into(context.Background(), nil, "not a pointer") - if err == nil || !strings.Contains(err.Error(), "non-nil pointer") { - t.Errorf("want non-nil-pointer error, got %v", err) - } -} - -func TestInto_RejectsNilPointer(t *testing.T) { - var users *[]typedUser - err := Into(context.Background(), nil, users) - if err == nil || !strings.Contains(err.Error(), "non-nil pointer") { - t.Errorf("want non-nil-pointer error, got %v", err) - } -} - -func TestInto_RejectsPointerToNonSliceNonStruct(t *testing.T) { - n := 42 - err := Into(context.Background(), nil, &n) - if err == nil || !strings.Contains(err.Error(), "slice or struct") { - t.Errorf("want slice-or-struct error, got %v", err) - } -} - -func TestInto_RejectsSliceOfNonStruct(t *testing.T) { - // Cover the elemType check in intoSlice before any I/O would fire. - ns := []int{} - err := Into(context.Background(), nil, &ns) - if err == nil || !strings.Contains(err.Error(), "must be a struct") { - t.Errorf("want slice-element-struct error, got %v", err) - } -} - func TestAs_RejectsNonStructT(t *testing.T) { - // As[T] delegates to TypedDataFrame[T] which enforces T must be a + // As[T] is the sole constructor for DataFrameOf[T]; T must be a // struct. Verify the error surfaces before any DataFrame I/O. _, err := As[int](nil) if err == nil || !strings.Contains(err.Error(), "must be a struct") { @@ -83,3 +46,10 @@ func TestCollect_RejectsNonStructT(t *testing.T) { t.Errorf("want struct-required error, got %v", err) } } + +func TestFirst_RejectsNonStructT(t *testing.T) { + _, err := First[int](context.Background(), nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} From 5c4c4a3a388fb158584730a7f1ad4946facf2686 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Mon, 20 Apr 2026 00:29:02 +0100 Subject: [PATCH 33/37] =?UTF-8?q?refactor:=20rename=20org=20github.com/dat?= =?UTF-8?q?alakego=20=E2=86=92=20github.com/datalake-go=20(#23)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The GitHub org is now datalake-go. Sweep module path + every import across the fork to match so `go get` resolves canonically without relying on GitHub's org-level redirect. 
- go.mod: module github.com/datalakego/spark-connect-go → module github.com/datalake-go/spark-connect-go - Every .go import: github.com/datalakego/spark-connect-go/... → github.com/datalake-go/spark-connect-go/... No API surface change; upstream-additive invariant against apache/spark-connect-go is unaffected. --- README.md | 4 ++-- cmd/spark-connect-example-raw-grpc-client/main.go | 2 +- cmd/spark-connect-example-spark-session/main.go | 8 ++++---- go.mod | 2 +- internal/tests/integration/dataframe_test.go | 10 +++++----- internal/tests/integration/functions_test.go | 6 +++--- internal/tests/integration/helper.go | 2 +- internal/tests/integration/spark_runner.go | 2 +- internal/tests/integration/sql_test.go | 8 ++++---- quick-start.md | 4 ++-- spark/client/base/base.go | 6 +++--- spark/client/channel/channel.go | 4 ++-- spark/client/channel/channel_test.go | 4 ++-- spark/client/client.go | 14 +++++++------- spark/client/client_test.go | 8 ++++---- spark/client/conf.go | 4 ++-- spark/client/retry.go | 8 ++++---- spark/client/retry_test.go | 8 ++++---- spark/client/testutils/utils.go | 2 +- spark/mocks/mock_executor.go | 8 ++++---- spark/mocks/mocks.go | 2 +- spark/sql/column/column.go | 4 ++-- spark/sql/column/column_test.go | 2 +- spark/sql/column/expressions.go | 6 +++--- spark/sql/column/expressions_test.go | 2 +- spark/sql/dataframe.go | 10 +++++----- spark/sql/dataframe_test.go | 4 ++-- spark/sql/dataframenafunctions.go | 2 +- spark/sql/dataframewriter.go | 4 ++-- spark/sql/dataframewriter_test.go | 6 +++--- spark/sql/driver/conn.go | 2 +- spark/sql/driver/driver.go | 4 ++-- spark/sql/driver/driver_test.go | 2 +- spark/sql/driver/rows.go | 2 +- spark/sql/functions/buiitins.go | 4 ++-- spark/sql/functions/generated.go | 2 +- spark/sql/group.go | 10 +++++----- spark/sql/group_test.go | 8 ++++---- spark/sql/plan.go | 2 +- spark/sql/sparksession.go | 14 +++++++------- spark/sql/sparksession_integration_test.go | 2 +- spark/sql/sparksession_test.go | 10 +++++----- spark/sql/types/arrow.go | 4 ++-- spark/sql/types/arrow_test.go | 4 ++-- spark/sql/types/builtin.go | 2 +- spark/sql/types/conversion.go | 4 ++-- spark/sql/types/conversion_test.go | 4 ++-- spark/sql/types/rowiterator_test.go | 2 +- spark/sql/utils/consts.go | 2 +- 49 files changed, 120 insertions(+), 120 deletions(-) diff --git a/README.md b/README.md index 0d90234..e8fe04a 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Step 3: Run the following commands to setup the Spark Connect client. Building with Spark in case you need to re-generate the source files from the proto sources. 
``` -git clone https://github.com/datalakego/spark-connect-go.git +git clone https://github.com/datalake-go/spark-connect-go.git git submodule update --init --recursive make gen && make test @@ -27,7 +27,7 @@ make gen && make test Building without Spark ``` -git clone https://github.com/datalakego/spark-connect-go.git +git clone https://github.com/datalake-go/spark-connect-go.git make && make test ``` diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go index 16a18df..e22f051 100644 --- a/cmd/spark-connect-example-raw-grpc-client/main.go +++ b/cmd/spark-connect-example-raw-grpc-client/main.go @@ -22,7 +22,7 @@ import ( "log" "time" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index 45ead19..aef6640 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -22,12 +22,12 @@ import ( "fmt" "log" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/datalakego/spark-connect-go/spark/sql" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" ) var ( diff --git a/go.mod b/go.mod index 5052b71..112cb00 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-module github.com/datalakego/spark-connect-go +module github.com/datalake-go/spark-connect-go go 1.24 diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index 1649dc6..2fb26bc 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -21,15 +21,15 @@ import ( "os" "testing" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go index 48da37b..287f45b 100644 --- a/internal/tests/integration/functions_test.go +++ b/internal/tests/integration/functions_test.go @@ -19,11 +19,11 @@ import ( "context" "testing" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/internal/tests/integration/helper.go b/internal/tests/integration/helper.go index 9ea4dd4..28086de 100644 --- a/internal/tests/integration/helper.go +++ b/internal/tests/integration/helper.go @@ -22,7 +22,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" ) func connect() (context.Context, sql.SparkSession) { diff --git a/internal/tests/integration/spark_runner.go b/internal/tests/integration/spark_runner.go index 4768b18..6398a56 100644 --- a/internal/tests/integration/spark_runner.go +++ b/internal/tests/integration/spark_runner.go @@ -23,7 +23,7 @@ import ( "os/exec" "time" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func StartSparkConnect() (int64, error) { diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go index 830c317..11ed17b 100644 --- a/internal/tests/integration/sql_test.go +++ b/internal/tests/integration/sql_test.go @@ -22,13 +22,13 @@ import ( "os" "testing" - "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/quick-start.md b/quick-start.md index 94151ff..221c82a 100644 --- a/quick-start.md +++ b/quick-start.md @@ -5,7 
+5,7 @@ In your Go project `go.mod` file, add `spark-connect-go` library: ``` require ( - github.com/datalakego/spark-connect-go master + github.com/datalake-go/spark-connect-go master ) ``` @@ -23,7 +23,7 @@ import ( "fmt" "log" - "github.com/datalakego/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" ) var ( diff --git a/spark/client/base/base.go b/spark/client/base/base.go index 054f1f9..ef9e77d 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -19,11 +19,11 @@ import ( "context" "iter" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" "github.com/apache/arrow-go/v18/arrow" - "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type SparkConnectRPCClient generated.SparkConnectServiceClient diff --git a/spark/client/channel/channel.go b/spark/client/channel/channel.go index fd61306..6abe1b8 100644 --- a/spark/client/channel/channel.go +++ b/spark/client/channel/channel.go @@ -29,13 +29,13 @@ import ( "strconv" "strings" - "github.com/datalakego/spark-connect-go/spark" + "github.com/datalake-go/spark-connect-go/spark" "github.com/google/uuid" "google.golang.org/grpc/credentials/insecure" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" diff --git a/spark/client/channel/channel_test.go b/spark/client/channel/channel_test.go index 7a3abf7..b33084d 100644 --- a/spark/client/channel/channel_test.go +++ b/spark/client/channel/channel_test.go @@ -24,8 +24,8 @@ import ( "github.com/google/uuid" "google.golang.org/grpc" - "github.com/datalakego/spark-connect-go/spark/client/channel" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/client/channel" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" ) diff --git a/spark/client/client.go b/spark/client/client.go index 0e2bff0..d6b577d 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -22,24 +22,24 @@ import ( "io" "iter" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" "google.golang.org/grpc" "google.golang.org/grpc/metadata" - "github.com/datalakego/spark-connect-go/spark/client/base" - "github.com/datalakego/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/mocks" - "github.com/datalakego/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" "github.com/google/uuid" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) type sparkConnectClientImpl struct { diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 2e79650..8b6f4c6 
100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -12,10 +12,10 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/client" - "github.com/datalakego/spark-connect-go/spark/mocks" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/spark/client/conf.go b/spark/client/conf.go index 96ad2e0..5dc0d2d 100644 --- a/spark/client/conf.go +++ b/spark/client/conf.go @@ -18,8 +18,8 @@ package client import ( "context" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/client/base" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client/base" ) // Public interface RuntimeConfig diff --git a/spark/client/retry.go b/spark/client/retry.go index 4bf6161..c802fa4 100644 --- a/spark/client/retry.go +++ b/spark/client/retry.go @@ -23,13 +23,13 @@ import ( "strings" "time" - "github.com/datalakego/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" - "github.com/datalakego/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" "google.golang.org/grpc/metadata" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) diff --git a/spark/client/retry_test.go b/spark/client/retry_test.go index 604419d..cf3f717 100644 --- a/spark/client/retry_test.go +++ b/spark/client/retry_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" - "github.com/datalakego/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" - "github.com/datalakego/spark-connect-go/spark/client/testutils" - "github.com/datalakego/spark-connect-go/spark/mocks" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go index fb071b5..a3a62a7 100644 --- a/spark/client/testutils/utils.go +++ b/spark/client/testutils/utils.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "google.golang.org/grpc" ) diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index c67e2ed..2a16765 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -19,13 +19,13 @@ import ( "context" "errors" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + 
"github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/datalakego/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" "github.com/apache/arrow-go/v18/arrow" - "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type TestExecutor struct { diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go index 3ce7c79..7569e0f 100644 --- a/spark/mocks/mocks.go +++ b/spark/mocks/mocks.go @@ -25,7 +25,7 @@ import ( "github.com/google/uuid" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "google.golang.org/grpc/metadata" ) diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go index 19c3629..7ed2080 100644 --- a/spark/sql/column/column.go +++ b/spark/sql/column/column.go @@ -18,9 +18,9 @@ package column import ( "context" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) // Convertible is the interface for all things that can be converted into a protobuf expression. diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go index 11e2203..7f1134e 100644 --- a/spark/sql/column/column_test.go +++ b/spark/sql/column/column_test.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go index 30e34e5..fc8780e 100644 --- a/spark/sql/column/expressions.go +++ b/spark/sql/column/expressions.go @@ -20,11 +20,11 @@ import ( "fmt" "strings" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) func newProtoExpression() *proto.Expression { diff --git a/spark/sql/column/expressions_test.go b/spark/sql/column/expressions_test.go index 58741f6..c1c2fb1 100644 --- a/spark/sql/column/expressions_test.go +++ b/spark/sql/column/expressions_test.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 4b61326..82f19db 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -22,14 +22,14 @@ import ( "math/rand/v2" "github.com/apache/arrow-go/v18/arrow" - "github.com/datalakego/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/datalakego/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto 
"github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) // ResultCollector receives a stream of result rows diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go index 475c270..3859629 100644 --- a/spark/sql/dataframe_test.go +++ b/spark/sql/dataframe_test.go @@ -20,8 +20,8 @@ import ( "context" "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframenafunctions.go b/spark/sql/dataframenafunctions.go index c288493..552788e 100644 --- a/spark/sql/dataframenafunctions.go +++ b/spark/sql/dataframenafunctions.go @@ -18,7 +18,7 @@ package sql import ( "context" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type DataFrameNaFunctions interface { diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index ee2fb7f..ca05b64 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -21,8 +21,8 @@ import ( "fmt" "strings" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) // DataFrameWriter supports writing data frame to storage. diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index 10baecd..3a60fe8 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - "github.com/datalakego/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/mocks" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/driver/conn.go b/spark/sql/driver/conn.go index 2f3312a..65b7f9d 100644 --- a/spark/sql/driver/conn.go +++ b/spark/sql/driver/conn.go @@ -20,7 +20,7 @@ import ( "database/sql/driver" "errors" - sparksql "github.com/datalakego/spark-connect-go/spark/sql" + sparksql "github.com/datalake-go/spark-connect-go/spark/sql" ) // conn is the per-logical-connection state database/sql keeps in its diff --git a/spark/sql/driver/driver.go b/spark/sql/driver/driver.go index 6c1f24a..75eb5ca 100644 --- a/spark/sql/driver/driver.go +++ b/spark/sql/driver/driver.go @@ -22,7 +22,7 @@ // // import ( // "database/sql" -// _ "github.com/datalakego/spark-connect-go/spark/sql/driver" +// _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" // ) // // db, err := sql.Open("spark", "sc://localhost:15002?format=iceberg") @@ -65,7 +65,7 @@ import ( "database/sql" "database/sql/driver" - sparksql "github.com/datalakego/spark-connect-go/spark/sql" + sparksql "github.com/datalake-go/spark-connect-go/spark/sql" ) // init registers the driver under the name "spark". 
Consumers that diff --git a/spark/sql/driver/driver_test.go b/spark/sql/driver/driver_test.go index 1b850cf..b11c9c0 100644 --- a/spark/sql/driver/driver_test.go +++ b/spark/sql/driver/driver_test.go @@ -24,7 +24,7 @@ import ( "strings" "testing" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) // --- DSN parsing --------------------------------------------------- diff --git a/spark/sql/driver/rows.go b/spark/sql/driver/rows.go index fffdd9f..7a39b54 100644 --- a/spark/sql/driver/rows.go +++ b/spark/sql/driver/rows.go @@ -19,7 +19,7 @@ import ( "database/sql/driver" "io" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) // rows wraps a materialised slice of sparksql types.Row values so diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go index ce9b46d..8d30662 100644 --- a/spark/sql/functions/buiitins.go +++ b/spark/sql/functions/buiitins.go @@ -16,8 +16,8 @@ package functions import ( - "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) func Expr(expr string) column.Column { diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go index 5af8970..071d33b 100644 --- a/spark/sql/functions/generated.go +++ b/spark/sql/functions/generated.go @@ -15,7 +15,7 @@ package functions -import "github.com/datalakego/spark-connect-go/spark/sql/column" +import "github.com/datalake-go/spark-connect-go/spark/sql/column" // BitwiseNOT - Computes bitwise not. // diff --git a/spark/sql/group.go b/spark/sql/group.go index b908e76..87c306f 100644 --- a/spark/sql/group.go +++ b/spark/sql/group.go @@ -19,12 +19,12 @@ package sql import ( "context" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" - "github.com/datalakego/spark-connect-go/spark/sql/column" - "github.com/datalakego/spark-connect-go/spark/sql/functions" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" ) type GroupedData struct { diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go index 109b83e..562ed45 100644 --- a/spark/sql/group_test.go +++ b/spark/sql/group_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/client" - "github.com/datalakego/spark-connect-go/spark/client/testutils" - "github.com/datalakego/spark-connect-go/spark/mocks" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/plan.go b/spark/sql/plan.go index 957da2d..73d5310 100644 --- a/spark/sql/plan.go +++ b/spark/sql/plan.go @@ -19,7 +19,7 @@ package sql import ( "sync/atomic" - proto 
"github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) var atomicInt64 atomic.Int64 diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index 57c03c9..071396d 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -23,19 +23,19 @@ import ( "time" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/datalakego/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" - "github.com/datalakego/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/client" - "github.com/datalakego/spark-connect-go/spark/client/channel" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/channel" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/metadata" diff --git a/spark/sql/sparksession_integration_test.go b/spark/sql/sparksession_integration_test.go index 9e8ca70..d6cecdc 100644 --- a/spark/sql/sparksession_integration_test.go +++ b/spark/sql/sparksession_integration_test.go @@ -21,7 +21,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 9acc682..b6d1d3c 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -30,11 +30,11 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/client" - "github.com/datalakego/spark-connect-go/spark/client/testutils" - "github.com/datalakego/spark-connect-go/spark/mocks" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func TestSparkSessionTable(t *testing.T) { diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index 92432f8..ef4c434 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -20,13 +20,13 @@ import ( "bytes" "fmt" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func ReadArrowTableToRows(table arrow.Table) 
([]Row, error) { diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index e19c3cd..9fdc40f 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -30,8 +30,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sql/types" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) func TestShowArrowBatchData(t *testing.T) { diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go index e03f843..24fc2b2 100644 --- a/spark/sql/types/builtin.go +++ b/spark/sql/types/builtin.go @@ -19,7 +19,7 @@ package types import ( "context" - proto "github.com/datalakego/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) type LiteralType interface { diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go index b40db72..8f5d7f1 100644 --- a/spark/sql/types/conversion.go +++ b/spark/sql/types/conversion.go @@ -19,8 +19,8 @@ package types import ( "errors" - "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func ConvertProtoDataTypeToStructType(input *generated.DataType) (*StructType, error) { diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go index d3a3d00..9302c56 100644 --- a/spark/sql/types/conversion_test.go +++ b/spark/sql/types/conversion_test.go @@ -19,8 +19,8 @@ package types_test import ( "testing" - proto "github.com/datalakego/spark-connect-go/internal/generated" - "github.com/datalakego/spark-connect-go/spark/sql/types" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index a81c8ef..d1aab17 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -16,7 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/datalakego/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) // Helper function to create test records diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go index 53a4e56..83bdad2 100644 --- a/spark/sql/utils/consts.go +++ b/spark/sql/utils/consts.go @@ -15,7 +15,7 @@ package utils -import proto "github.com/datalakego/spark-connect-go/internal/generated" +import proto "github.com/datalake-go/spark-connect-go/internal/generated" type ExplainMode int From 6c4fc2581321584f6871dc7a207521af6ee48065 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Mon, 20 Apr 2026 00:32:31 +0100 Subject: [PATCH 34/37] feat: parameter binding + drop format DSN param from database/sql driver (#22) Two changes, both keeping the driver a boring reusable transport layer: 1. Parameter binding. database/sql-aware tools (goose, sqlc, pgx consumers) need $N placeholder substitution on Exec/Query. A new render.go walks the query, swaps $N for the corresponding NamedValue, and quotes the literal as Spark SQL. 
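From the caller's side this is ordinary database/sql usage. A hedged sketch, assuming a reachable Spark Connect server and an existing table; endpoint, catalog, and table names are illustrative, not part of this patch:

```
// Hedged sketch, not part of this patch: endpoint and table names are
// assumptions. Only the "spark" driver name, the sc:// DSN form, and
// $N placeholders come from the driver itself.
package main

import (
	"context"
	"database/sql"
	"log"

	_ "github.com/datalake-go/spark-connect-go/spark/sql/driver" // registers "spark"
)

func main() {
	db, err := sql.Open("spark", "sc://localhost:15002")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx := context.Background()
	// $1 and $2 are rendered client-side into Spark SQL literals
	// before the statement reaches the server.
	if _, err := db.ExecContext(ctx,
		"INSERT INTO demo.users VALUES ($1, $2)", int64(1), "o'brien"); err != nil {
		log.Fatal(err)
	}

	var name string
	if err := db.QueryRowContext(ctx,
		"SELECT name FROM demo.users WHERE id = $1", int64(1)).Scan(&name); err != nil {
		log.Fatal(err)
	}
	log.Println(name)
}
```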
NumInput now returns -1 so the sql package passes args through rather than pre-rejecting them. Spark Connect's protocol-level param-binding proto isn't round-trip-stable across every supported Spark version, so client-side rendering is the compatible path. 2. Drop the `?format=iceberg|delta` DSN parameter. The driver is the transport layer; it should stay out of lakehouse table-format concerns. Migration SQL and application SQL pick Iceberg vs Delta vs Parquet via Spark's native `USING ` DDL clause, exactly as upstream Spark documents. Consumers that want format-dependent bookkeeping (like goose's version table) select a dialect at that layer, not via the DSN. render supports the stdlib-sql default types (nil, bool, int64, float64, string, []byte, time.Time). String single-quotes are doubled; `$N` inside a quoted literal stays literal; unknown types surface an error naming the query token rather than silently interpolating garbage. --- spark/sql/driver/driver_test.go | 74 ++++------------- spark/sql/driver/dsn.go | 54 ++++--------- spark/sql/driver/render.go | 135 ++++++++++++++++++++++++++++++++ spark/sql/driver/render_test.go | 119 ++++++++++++++++++++++++++++ spark/sql/driver/stmt.go | 43 +++++----- 5 files changed, 306 insertions(+), 119 deletions(-) create mode 100644 spark/sql/driver/render.go create mode 100644 spark/sql/driver/render_test.go diff --git a/spark/sql/driver/driver_test.go b/spark/sql/driver/driver_test.go index b11c9c0..c6e8e4b 100644 --- a/spark/sql/driver/driver_test.go +++ b/spark/sql/driver/driver_test.go @@ -16,10 +16,8 @@ package driver import ( - "context" "database/sql" "database/sql/driver" - "errors" "io" "strings" "testing" @@ -37,36 +35,6 @@ func TestParseDSN_AcceptsPlainSc(t *testing.T) { if cfg.sparkDSN != "sc://localhost:15002" { t.Errorf("sparkDSN = %q", cfg.sparkDSN) } - if cfg.format != "" { - t.Errorf("format = %q, want empty", cfg.format) - } -} - -func TestParseDSN_ParsesFormatParameter(t *testing.T) { - cfg, err := parseDSN("sc://localhost:15002?format=iceberg") - if err != nil { - t.Fatalf("parseDSN: %v", err) - } - if cfg.format != "iceberg" { - t.Errorf("format = %q, want iceberg", cfg.format) - } -} - -func TestParseDSN_NormalisesFormatCase(t *testing.T) { - cfg, err := parseDSN("sc://localhost:15002?format=DELTA") - if err != nil { - t.Fatalf("parseDSN: %v", err) - } - if cfg.format != "delta" { - t.Errorf("format = %q, want delta (lowercased)", cfg.format) - } -} - -func TestParseDSN_RejectsUnknownFormat(t *testing.T) { - _, err := parseDSN("sc://localhost:15002?format=duckdb") - if err == nil || !strings.Contains(err.Error(), "unsupported format") { - t.Errorf("err = %v, want unsupported-format error", err) - } } func TestParseDSN_RejectsMissingSchemePrefix(t *testing.T) { @@ -83,17 +51,20 @@ func TestParseDSN_RejectsEmpty(t *testing.T) { } } -func TestParseDSN_PreservesTokenInSparkDSN(t *testing.T) { - // The token parameter is meaningful to the downstream session - // builder; the driver itself doesn't interpret it but must not - // strip it from the DSN forwarded to NewSessionBuilder().Remote. - cfg, err := parseDSN("sc://host:15002?token=secret&format=iceberg") +func TestParseDSN_PreservesQueryParamsInSparkDSN(t *testing.T) { + // The driver does not interpret query parameters — not `token`, + // not anything else. Parameters must round-trip unchanged into + // sparkDSN so the upstream SparkSessionBuilder.Remote sees them. 
+ cfg, err := parseDSN("sc://host:15002?token=secret&user=alice") if err != nil { t.Fatalf("parseDSN: %v", err) } if !strings.Contains(cfg.sparkDSN, "token=secret") { t.Errorf("sparkDSN should carry token unchanged; got %q", cfg.sparkDSN) } + if !strings.Contains(cfg.sparkDSN, "user=alice") { + t.Errorf("sparkDSN should carry user unchanged; got %q", cfg.sparkDSN) + } } // --- sql.Register side-effect --------------------------------------- @@ -121,29 +92,14 @@ func TestDriver_OpenConnectorInvalidDSN(t *testing.T) { } } -// --- stmt argument rejection --------------------------------------- -// -// v0 doesn't implement parameter binding. Callers that pass args -// should see the errArgsUnsupported sentinel before any server call -// fires. - -func TestStmt_ExecContextRejectsArgs(t *testing.T) { - s := &stmt{conn: &conn{}, query: "SELECT 1"} - _, err := s.ExecContext(context.Background(), []driver.NamedValue{ - {Ordinal: 1, Value: 42}, - }) - if err == nil || !errors.Is(err, errArgsUnsupported) { - t.Errorf("ExecContext with args err = %v, want errArgsUnsupported", err) - } -} +// Parameter binding coverage lives in render_test.go. Here we only +// pin that NumInput returns -1 so database/sql passes args through +// untouched instead of rejecting them up front. -func TestStmt_QueryContextRejectsArgs(t *testing.T) { - s := &stmt{conn: &conn{}, query: "SELECT 1"} - _, err := s.QueryContext(context.Background(), []driver.NamedValue{ - {Ordinal: 1, Value: "x"}, - }) - if err == nil || !errors.Is(err, errArgsUnsupported) { - t.Errorf("QueryContext with args err = %v, want errArgsUnsupported", err) +func TestStmt_NumInputIsNegativeOne(t *testing.T) { + s := &stmt{} + if got := s.NumInput(); got != -1 { + t.Errorf("NumInput = %d, want -1 (database/sql should not pre-count args)", got) } } diff --git a/spark/sql/driver/dsn.go b/spark/sql/driver/dsn.go index d80850f..8ff4e73 100644 --- a/spark/sql/driver/dsn.go +++ b/spark/sql/driver/dsn.go @@ -18,40 +18,29 @@ package driver import ( "errors" "fmt" - "net/url" "strings" ) -// dsnConfig is the parsed DSN. Kept minimal — the Spark Connect -// session builder wants the original URL verbatim, so the config -// mostly exists to surface query-parameter knobs (format, token) -// that downstream tools want to read without re-parsing the DSN. +// dsnConfig is the parsed DSN. The driver is a boring transport +// layer — it stays out of table-format concerns (Iceberg vs Delta +// vs parquet) and keeps the original URL intact for the Spark +// Connect session builder to interpret. type dsnConfig struct { // sparkDSN is what gets passed to NewSessionBuilder().Remote(...). - // Includes any `?token=` fragment if present. + // Includes any `?token=` or other query fragment verbatim so the + // upstream builder sees the URL exactly as the caller wrote it. sparkDSN string - - // format is the value of the `format` query parameter ("iceberg" - // / "delta" / empty). Driver ignores it; consumers that need - // dialect-aware DDL read it via the exported Connector.Format() - // accessor. - format string } -// parseDSN accepts the sc:// Spark Connect URL form with optional -// query parameters. Returns an error for malformed input. -// -// Recognised query parameters: +// parseDSN accepts the `sc://host:port[?token=...]` Spark Connect +// URL form. The driver does not interpret query parameters beyond +// validating the scheme — any `token`, `user`, or future Spark +// Connect flag rides through in `sparkDSN` untouched. 
// -// - token: bearer token forwarded to the Spark Connect server in -// the Authorization header. Preserved in sparkDSN verbatim so -// the session builder picks it up. -// - format: lakehouse table format ("iceberg" | "delta"). Driver- -// layer passthrough; used by consumers like goose-spark to pick -// a dialect-appropriate CREATE TABLE. -// -// Unrecognised parameters are ignored (preserved in the DSN as-is) -// so the driver doesn't fight with future Spark Connect flags. +// Table format selection (Iceberg, Delta, parquet) is explicitly +// NOT a driver concern. Migrations and application SQL pick format +// via Spark's native `USING ` DDL clause; the transport +// layer stays boring and reusable. func parseDSN(dsn string) (*dsnConfig, error) { if dsn == "" { return nil, errors.New("spark driver: DSN is required") @@ -59,18 +48,5 @@ func parseDSN(dsn string) (*dsnConfig, error) { if !strings.HasPrefix(dsn, "sc://") { return nil, fmt.Errorf("spark driver: DSN must start with sc://, got %q", dsn) } - u, err := url.Parse(dsn) - if err != nil { - return nil, fmt.Errorf("spark driver: parse DSN: %w", err) - } - - cfg := &dsnConfig{ - sparkDSN: dsn, - format: strings.ToLower(u.Query().Get("format")), - } - if cfg.format != "" && cfg.format != "iceberg" && cfg.format != "delta" { - return nil, fmt.Errorf( - "spark driver: unsupported format %q; expected iceberg or delta", cfg.format) - } - return cfg, nil + return &dsnConfig{sparkDSN: dsn}, nil } diff --git a/spark/sql/driver/render.go b/spark/sql/driver/render.go new file mode 100644 index 0000000..b06e65c --- /dev/null +++ b/spark/sql/driver/render.go @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "database/sql/driver" + "fmt" + "strconv" + "strings" + "time" +) + +// render interpolates driver.NamedValue arguments into a query +// string at `$N` placeholders ($1, $2, ...). Values are quoted as +// Spark SQL literals on the way through. +// +// Spark Connect's parameterised-query proto exists but doesn't +// round-trip reliably across every Spark version the driver +// supports. Client-side rendering keeps the driver compatible with +// the full 3.4+ range and mirrors what the native Spark driver in +// datalakego/lakeorm already does for the same reason. +// +// Accepts `$N` placeholders only. `?` is reserved for a future cycle +// if a caller asks for it — most database/sql-adjacent codegen +// (sqlc, goose's own dialects, pgx patterns) emits `$N`, so the +// narrower set suffices and keeps the renderer simple. +// +// Supported types are the stdlib-sql defaults: nil, bool, int64, +// float64, string, []byte, time.Time. 
Anything rarer surfaces as an +// error so the caller tightens their query before seeing a silently +// mis-interpolated value on the wire. +func render(query string, args []driver.NamedValue) (string, error) { + if len(args) == 0 { + return query, nil + } + byOrdinal := make(map[int]driver.Value, len(args)) + for _, a := range args { + byOrdinal[a.Ordinal] = a.Value + } + + var b strings.Builder + b.Grow(len(query) + 32) + for i := 0; i < len(query); { + c := query[i] + + // Pass single-quoted string literals through untouched, so a + // `'$1'` in the caller's query stays literal rather than + // getting parameter-substituted. + if c == '\'' { + b.WriteByte(c) + i++ + for i < len(query) { + if query[i] == '\'' { + // Doubled '' is an escaped quote; keep scanning. + if i+1 < len(query) && query[i+1] == '\'' { + b.WriteByte('\'') + b.WriteByte('\'') + i += 2 + continue + } + b.WriteByte('\'') + i++ + break + } + b.WriteByte(query[i]) + i++ + } + continue + } + + if c == '$' && i+1 < len(query) && query[i+1] >= '0' && query[i+1] <= '9' { + j := i + 1 + for j < len(query) && query[j] >= '0' && query[j] <= '9' { + j++ + } + ord, _ := strconv.Atoi(query[i+1 : j]) + v, ok := byOrdinal[ord] + if !ok { + return "", fmt.Errorf("spark driver: no argument for placeholder $%d", ord) + } + lit, err := sqlLiteral(v) + if err != nil { + return "", fmt.Errorf("spark driver: placeholder $%d: %w", ord, err) + } + b.WriteString(lit) + i = j + continue + } + + b.WriteByte(c) + i++ + } + return b.String(), nil +} + +// sqlLiteral renders a single Go value as a Spark SQL literal token. +// Strings are single-quoted with embedded quotes doubled. Time values +// render as TIMESTAMP literals in UTC microsecond precision — matches +// Spark's default timestamp parsing. +func sqlLiteral(v driver.Value) (string, error) { + switch x := v.(type) { + case nil: + return "NULL", nil + case bool: + if x { + return "TRUE", nil + } + return "FALSE", nil + case int64: + return strconv.FormatInt(x, 10), nil + case float64: + return strconv.FormatFloat(x, 'g', -1, 64), nil + case string: + return "'" + strings.ReplaceAll(x, "'", "''") + "'", nil + case []byte: + return "'" + strings.ReplaceAll(string(x), "'", "''") + "'", nil + case time.Time: + return "TIMESTAMP '" + x.UTC().Format("2006-01-02 15:04:05.000000") + "'", nil + default: + return "", fmt.Errorf("unsupported arg type %T", v) + } +} diff --git a/spark/sql/driver/render_test.go b/spark/sql/driver/render_test.go new file mode 100644 index 0000000..358ffe6 --- /dev/null +++ b/spark/sql/driver/render_test.go @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package driver + +import ( + "database/sql/driver" + "strings" + "testing" + "time" +) + +func namedArgs(vs ...any) []driver.NamedValue { + out := make([]driver.NamedValue, len(vs)) + for i, v := range vs { + out[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return out +} + +func TestRender_NoArgsPassesThrough(t *testing.T) { + got, err := render("SELECT 1", nil) + if err != nil { + t.Fatalf("render: %v", err) + } + if got != "SELECT 1" { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestRender_Int64AndBool(t *testing.T) { + // goose's Insert always passes (version int64, is_applied bool). + // Verify the exact shape it emits renders to valid Spark SQL. + q := `INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)` + got, err := render(q, namedArgs(int64(20260419000001), true)) + if err != nil { + t.Fatalf("render: %v", err) + } + want := `INSERT INTO goose_db_version (version_id, is_applied) VALUES (20260419000001, TRUE)` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestRender_StringEscapesEmbeddedQuotes(t *testing.T) { + got, err := render(`SELECT * FROM t WHERE name = $1`, namedArgs("O'Brien")) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "'O''Brien'") { + t.Errorf("got %q, want doubled-quote escape", got) + } +} + +func TestRender_LeavesLiteralInsideStringAlone(t *testing.T) { + // A `$1` that lives inside a quoted literal must stay literal. + // Otherwise something like `WHERE s = '$1'` would try to + // substitute into the user-intended string. + got, err := render(`SELECT '$1 then $2' WHERE id = $1`, namedArgs(int64(42))) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "'$1 then $2'") { + t.Errorf("got %q, string-literal $-tokens should stay untouched", got) + } + if !strings.Contains(got, "id = 42") { + t.Errorf("got %q, outer $1 should have been replaced with 42", got) + } +} + +func TestRender_NilBecomesSQLNull(t *testing.T) { + got, err := render(`SELECT $1`, namedArgs(nil)) + if err != nil { + t.Fatalf("render: %v", err) + } + if got != "SELECT NULL" { + t.Errorf("got %q, want SELECT NULL", got) + } +} + +func TestRender_TimeRendersAsTimestampLiteral(t *testing.T) { + ts := time.Date(2026, 4, 19, 12, 34, 56, 0, time.UTC) + got, err := render(`SELECT $1`, namedArgs(ts)) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "TIMESTAMP '2026-04-19 12:34:56") { + t.Errorf("got %q, want TIMESTAMP literal", got) + } +} + +func TestRender_UnsupportedTypeErrors(t *testing.T) { + type custom struct{ N int } + _, err := render(`SELECT $1`, namedArgs(custom{N: 5})) + if err == nil || !strings.Contains(err.Error(), "unsupported arg type") { + t.Errorf("err = %v, want unsupported-arg-type", err) + } +} + +func TestRender_MissingOrdinalErrors(t *testing.T) { + // $3 referenced but only two args supplied — surfaces at render + // time rather than on the wire, so the error points at the query. + _, err := render(`INSERT INTO t VALUES ($1, $2, $3)`, namedArgs(int64(1), int64(2))) + if err == nil || !strings.Contains(err.Error(), "$3") { + t.Errorf("err = %v, want missing-ordinal error naming $3", err) + } +} diff --git a/spark/sql/driver/stmt.go b/spark/sql/driver/stmt.go index 811a4d2..e1c5dc2 100644 --- a/spark/sql/driver/stmt.go +++ b/spark/sql/driver/stmt.go @@ -18,7 +18,6 @@ package driver import ( "context" "database/sql/driver" - "errors" ) // stmt wraps one query string against a conn. 
The driver doesn't @@ -40,15 +39,14 @@ type stmt struct { // Implements database/sql/driver.Stmt. func (*stmt) Close() error { return nil } -// NumInput reports the number of placeholders this statement -// expects. v0 doesn't support parameter binding, so we return 0 -// and reject non-empty arg slices in ExecContext / QueryContext. -// goose's generated SQL (CreateVersionTable / InsertVersion) and -// hand-authored migration files don't use placeholders, so this -// limitation isn't felt in practice. +// NumInput reports -1 — "driver doesn't know the number of +// placeholders." The renderer parses `$N` tokens inline at Exec / +// Query time, so the sql package's up-front arg-count check doesn't +// apply. Returning -1 tells database/sql to hand any arg count +// through unchanged. // // Implements database/sql/driver.Stmt. -func (*stmt) NumInput() int { return 0 } +func (*stmt) NumInput() int { return -1 } // Exec is the legacy (non-context) Exec entry. Implements // database/sql/driver.Stmt. @@ -75,15 +73,20 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { // that metric reliably at the session.Sql layer. database/sql // allows -1 per the driver.Result docs. // +// Arguments are rendered into the query at $N placeholders before +// the session call fires; Spark Connect's protocol-level parameter +// binding doesn't round-trip reliably across every supported Spark +// version, so client-side rendering is the compatible path. +// // Implements database/sql/driver.StmtExecContext. func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - if len(args) > 0 { - return nil, errArgsUnsupported - } - _, err := s.conn.session.Sql(ctx, s.query) + q, err := render(s.query, args) if err != nil { return nil, err } + if _, err := s.conn.session.Sql(ctx, q); err != nil { + return nil, err + } return result{}, nil } @@ -93,12 +96,16 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive // fires) this is right-sized; large result sets should bypass the // driver and use the native DataFrame / iter.Seq2 path directly. // +// Arguments are rendered into the query at $N placeholders before +// the session call fires — see ExecContext for rationale. +// // Implements database/sql/driver.StmtQueryContext. func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - if len(args) > 0 { - return nil, errArgsUnsupported + q, err := render(s.query, args) + if err != nil { + return nil, err } - df, err := s.conn.session.Sql(ctx, s.query) + df, err := s.conn.session.Sql(ctx, q) if err != nil { return nil, err } @@ -108,9 +115,3 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } return newRows(rows), nil } - -// errArgsUnsupported is the v0 signal that parameter binding isn't -// implemented. Sentinel so callers can errors.Is on it. -var errArgsUnsupported = errors.New( - "spark driver: parameter binding is not supported in v0; interpolate values into the SQL string or upgrade when v1 ships", -) From b7342d7264a8e77209bb054bd22d2fe5ff953f7b Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Mon, 20 Apr 2026 02:13:44 +0100 Subject: [PATCH 35/37] fix: dedupe rowiterator_test.go imports after fork reset The cherry-pick of #9's gofumpt import shuffle landed on top of callum/SPARK-52780's rowiterator_test.go, whose import block already carried arrow/array and arrow/memory. 
The auto-merge produced duplicate entries that broke the build. Collapse to one import block. Pure cleanup artifact of the cherry-pick replay; no functional change. --- spark/sql/types/rowiterator_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go index d1aab17..931dca6 100644 --- a/spark/sql/types/rowiterator_test.go +++ b/spark/sql/types/rowiterator_test.go @@ -7,9 +7,6 @@ import ( "iter" "testing" - "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" From 6812fad3c9322e304fc02d6cfb8443b08226a132 Mon Sep 17 00:00:00 2001 From: caldempsey <8885269+caldempsey@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:05:53 +0100 Subject: [PATCH 36/37] docs: refresh README for datalake-go fork Replaces the upstream Apache README with fork-specific framing: what this repo adds over apache/spark-connect-go (database/sql driver, generic typed helpers, exposed gRPC dial options, IsClusterNotReady), import conventions, and the relationship to the rest of the datalake-go ecosystem. database/sql parameter binding is documented as client-side rendering at $N placeholders, reflecting how render.go actually works (native parameter proto isn't reliable across the 3.4+ Spark range). --- README.md | 125 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 89 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index e8fe04a..393a1ae 100644 --- a/README.md +++ b/README.md @@ -1,69 +1,122 @@ -# Apache Spark Connect Client for Golang +# spark-connect-go -This project houses the **experimental** client for [Spark -Connect](https://spark.apache.org/docs/latest/spark-connect-overview.html) for -[Apache Spark](https://spark.apache.org/) written in [Golang](https://go.dev/). +> The [datalake-go](https://github.com/datalake-go) fork of Apache's Spark Connect Go client. Adds a `database/sql` driver, generic typed helpers (`Collect[T]`, `Stream[T]`, `First[T]`), and exposed gRPC dial options on top of the upstream gRPC client. Tracks `apache/spark-connect-go`; deltas are intended to upstream. -## Getting started +Spark Connect is Spark's [language-neutral gRPC protocol](https://spark.apache.org/docs/latest/spark-connect-overview.html). The upstream Go client is the official reference implementation; this fork carries the extras the rest of the [datalake-go](https://github.com/datalake-go) ecosystem needs while those patches work their way upstream. -This section explains how to run Spark Connect Go locally. +## Why the fork -Step 1: Install Golang: https://go.dev/doc/install. +The upstream client is correct and minimal. Building an ORM and a composed runtime on top of it needs a handful of things that aren't in the upstream surface yet: -Step 2: Ensure you have installed `buf CLI` installed, [more info here](https://buf.build/docs/installation/) +- **`database/sql` driver.** A DSN-shaped entrypoint so existing Go tooling (goose, sqlc, pgx-consumer code) can speak Spark Connect without a wrapper layer. Driver name is `spark`; it speaks `sc://host:port` DSNs and supports `$N` positional parameters (rendered client-side into Spark SQL literals — the native parameter proto isn't reliable across every supported Spark version). +- **Generic typed helpers.** `Collect[T]`, `Stream[T]`, `First[T]` — Go generics over the untyped `DataFrame`. 
The upstream surface is untyped by design; these are thin wrappers over `As[T]` + `DataFrameOf[T]`. +- **Exposed gRPC dial options.** `SparkSessionBuilder.WithDialOptions(opts ...grpc.DialOption)` — auth interceptors, TLS config, observability handlers wire in without subclassing. +- **`IsClusterNotReady` error surface.** A typed error (`sparkerrors.ErrClusterNotReady` / `IsClusterNotReady(err)`) the caller can check instead of string-matching. Databricks serverless clusters take 30–90s to warm, and retry logic up the stack needs a reliable signal. -Step 3: Run the following commands to setup the Spark Connect client. +Every delta is tracked as a PR against this fork and queued for upstream. When a delta lands in `apache/spark-connect-go`, we drop it from the fork. The long-term goal is zero deltas. -Building with Spark in case you need to re-generate the source files from the proto sources. +## What this repo is +A drop-in replacement for `github.com/apache/spark-connect-go` at the same package names. Swap the import path and existing code compiles; the session API, DataFrame surface, and protobuf stubs are unchanged. + +```go +import ( + sparksql "github.com/datalake-go/spark-connect-go/spark/sql" + _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" // registers "spark" for database/sql +) ``` -git clone https://github.com/datalake-go/spark-connect-go.git -git submodule update --init --recursive -make gen && make test +(The `sparksql` alias avoids collision with stdlib `database/sql` — the actual package name is `sql`.) -``` +You can use this fork without lake-orm or lakehouse. The composed runtime ([lakehouse](https://github.com/datalake-go/lakehouse)) and the ORM ([lake-orm](https://github.com/datalake-go/lake-orm)) depend on this fork's deltas; nothing in this fork requires them. -Building without Spark +## Quick start -``` -git clone https://github.com/datalake-go/spark-connect-go.git -make && make test +```go +session, err := sparksql.NewSessionBuilder(). + Remote("sc://spark.internal:15002"). + Build(ctx) +if err != nil { /* ... */ } +defer session.Stop() + +df, _ := session.Sql(ctx, "SELECT id, email FROM users WHERE tier = 'gold'") +_ = df.Show(ctx, 20, false) ``` -Step 4: Setup the Spark Driver on localhost. +### Typed reads -1. [Download Spark distribution](https://spark.apache.org/downloads.html) (4.0.0+), unzip the package. +```go +type User struct { + ID string `spark:"id"` + Email string `spark:"email"` +} -2. Start the Spark Connect server with the following command (make sure to use a package version that matches your Spark distribution): +df, _ := session.Sql(ctx, "SELECT id, email FROM users WHERE tier = 'gold'") +users, _ := sparksql.Collect[User](ctx, df) +// or one row: +alice, err := sparksql.First[User](ctx, df) +if errors.Is(err, sparksql.ErrNotFound) { /* 404 */ } + +// or streaming: +for row, rerr := range sparksql.Stream[User](ctx, df) { + if rerr != nil { break } + // use row +} ``` -sbin/start-connect-server.sh -``` -Step 5: Run the example Go application. +### `database/sql` driver + +```go +import ( + "database/sql" + _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" +) +db, _ := sql.Open("spark", "sc://spark.internal:15002") +rows, _ := db.QueryContext(ctx, "SELECT id FROM users WHERE tier = $1", "gold") ``` -go run cmd/spark-connect-example-spark-session/main.go + +`$N` placeholders are rendered client-side into Spark SQL literals with type-aware quoting (strings, numbers, bools, `[]byte`, `time.Time`). 
`?` placeholders aren't supported — most `database/sql`-adjacent codegen (sqlc, goose dialects, pgx patterns) emits `$N`, so the narrower grammar keeps the renderer simple. + +## Features (deltas over upstream) + +- **`database/sql` driver.** `sql.Open("spark", "sc://...")` — any `database/sql` consumer (goose, sqlc-generated code, ad-hoc scripts) plugs in. Registered under name `spark` in `spark/sql/driver`. +- **Typed helpers over `DataFrame`.** `Collect[T]`, `Stream[T]`, `First[T]`, `As[T]`, `DataFrameOf[T]`. Struct tag is `spark:""`. Decode rows into struct types at the materialization edge. +- **`SparkSessionBuilder.WithDialOptions`.** gRPC dial options exposed on the builder. Wire auth interceptors, TLS, observability without wrapping the builder. +- **`sparkerrors.IsClusterNotReady(err)`.** Classified error for cluster cold-start states. lake-orm uses it upstack for retry decisions. +- **Upstream parity.** Tracks `apache/master`; upstream merges flow through periodically with the fork's commits rebased on top. + +## Install + +```bash +go get github.com/datalake-go/spark-connect-go ``` -## Runnning Spark Connect Go Application in a Spark Cluster +Requires a Spark Connect server (Spark 3.4+). See [lake-k8s](https://github.com/datalake-go/lake-k8s) for a pre-baked Spark 4.0 + Iceberg + Delta image and a `docker compose up` laptop stack. -To run the Spark Connect Go application in a Spark Cluster, you need to build the Go application and submit it to the Spark Cluster. You can find a more detailed example runner and wrapper script in the `java` directory. +## Building from source -See the guide here: [Sample Spark-Submit Wrapper](java/README.md). +```bash +git clone https://github.com/datalake-go/spark-connect-go.git +cd spark-connect-go +make && make test +``` -## How to write Spark Connect Go Application in your own project +Regenerating protobuf stubs from the Spark submodule: -See [Quick Start Guide](quick-start.md) +```bash +git submodule update --init --recursive +make gen && make test +``` -## High Level Design +## Related -The overall goal of the design is to find a good balance of principle of the least surprise for -develoeprs that are familiar with the APIs of Apache Spark and idiomatic Go usage. The high-level -structure of the packages follows roughly the PySpark giudance but with Go idioms. +- [lake-orm](https://github.com/datalake-go/lake-orm) — ORM that uses this fork's typed helpers and `database/sql` driver. +- [lakehouse](https://github.com/datalake-go/lakehouse) — composed runtime that wires this session alongside the ORM, migrations, and dashboard. +- [lake-k8s](https://github.com/datalake-go/lake-k8s) — pre-baked Spark Connect server + laptop-mode compose stack. +- [apache/spark-connect-go](https://github.com/apache/spark-connect-go) — the upstream project this fork tracks. ## Contributing -Please review the [Contribution to Spark guide](https://spark.apache.org/contributing.html) -for information on how to get started contributing to the project. +Feature work that could land upstream should be proposed against `apache/spark-connect-go` first. Fork-only changes (anything that wouldn't be accepted upstream) stay on this tree. See [CONTRIBUTING.md](CONTRIBUTING.md). 
From b1b5edd7a643ee8766c0b2c07e7c4215d2d4738d Mon Sep 17 00:00:00 2001 From: Callum Dempsey Leach Date: Mon, 20 Apr 2026 20:16:52 +0100 Subject: [PATCH 37/37] docs(readme): refocus on the maintained-fork framing, restructure usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The README opened with lake-orm / lakehouse positioning ("the fork those projects need") instead of standing on its own as a maintained Spark Connect Go client. Reframe as "a maintained fork of apache/spark-connect-go with these deltas" so the repo reads usefully on its own. Drop the Related/lake-orm cross-links and the "you can use this without lake-orm" caveat that only made sense alongside the old framing. Restructure usage into three sections: Using DataFrames — untyped surface, composition primitive Using Typed DataFrames — As[T] → *DataFrameOf[T], edge-typing rationale Streaming Results — Stream[T] over iter.Seq2 (Go 1.23) Streaming Results highlights what Go gives us over the Python / Scala clients: a real pull-based iterator decoding rows as the gRPC stream resolves them, constant memory, no buffer-everything, no callback API. Update the typed surface docs to the current generics shape: As[T] returns *DataFrameOf[T] which caches the row plan; Collect / First / Stream are methods on that, with top-level helpers as one-shot wrappers. First[T] returns *T and ErrNotFound on empty results. Document the no-transformations design (re-type at the edge via typed.DataFrame()). --- README.md | 135 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 82 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 393a1ae..6716843 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,33 @@ # spark-connect-go -> The [datalake-go](https://github.com/datalake-go) fork of Apache's Spark Connect Go client. Adds a `database/sql` driver, generic typed helpers (`Collect[T]`, `Stream[T]`, `First[T]`), and exposed gRPC dial options on top of the upstream gRPC client. Tracks `apache/spark-connect-go`; deltas are intended to upstream. +> A maintained fork of [`apache/spark-connect-go`](https://github.com/apache/spark-connect-go) with a `database/sql` driver, edge-typed DataFrames, exposed gRPC dial options, and a typed `ClusterNotReady` error. Tracks upstream; deltas are queued to upstream. -Spark Connect is Spark's [language-neutral gRPC protocol](https://spark.apache.org/docs/latest/spark-connect-overview.html). The upstream Go client is the official reference implementation; this fork carries the extras the rest of the [datalake-go](https://github.com/datalake-go) ecosystem needs while those patches work their way upstream. +Spark Connect is Spark's [language-neutral gRPC protocol](https://spark.apache.org/docs/latest/spark-connect-overview.html). The upstream Go client is the official reference implementation. This fork carries the deltas needed for production usage while those patches work their way upstream — drop in by swapping the import path; the session API, DataFrame surface, and protobuf stubs are unchanged. -## Why the fork +## What's added -The upstream client is correct and minimal. Building an ORM and a composed runtime on top of it needs a handful of things that aren't in the upstream surface yet: +- **`database/sql` driver.** `sql.Open("spark", "sc://host:port")` works with goose, sqlc-generated code, pgx-style consumers — anything that speaks `database/sql`. Registered under the name `spark` in `spark/sql/driver`. 
`$N` positional placeholders are rendered client-side into Spark SQL literals (the native parameter proto isn't reliable across every supported Spark version). +- **Edge-typed DataFrames.** `As[T](df) → *DataFrameOf[T]` caches a reflected row plan once; `Collect`, `Stream`, `First` materialise into struct types at the point you know the result shape. Top-level `Collect[T] / Stream[T] / First[T]` helpers do the `As[T]` plus the call in one shot. +- **`SparkSessionBuilder.WithDialOptions`.** gRPC dial options exposed on the builder — auth interceptors, TLS, observability handlers wire in without subclassing. +- **`sparkerrors.IsClusterNotReady(err)`.** Typed error for cluster cold-start states. Databricks serverless clusters take 30-90s to warm; retry logic upstack needs a reliable signal instead of string-matching on error messages. -- **`database/sql` driver.** A DSN-shaped entrypoint so existing Go tooling (goose, sqlc, pgx-consumer code) can speak Spark Connect without a wrapper layer. Driver name is `spark`; it speaks `sc://host:port` DSNs and supports `$N` positional parameters (rendered client-side into Spark SQL literals — the native parameter proto isn't reliable across every supported Spark version). -- **Generic typed helpers.** `Collect[T]`, `Stream[T]`, `First[T]` — Go generics over the untyped `DataFrame`. The upstream surface is untyped by design; these are thin wrappers over `As[T]` + `DataFrameOf[T]`. -- **Exposed gRPC dial options.** `SparkSessionBuilder.WithDialOptions(opts ...grpc.DialOption)` — auth interceptors, TLS config, observability handlers wire in without subclassing. -- **`IsClusterNotReady` error surface.** A typed error (`sparkerrors.ErrClusterNotReady` / `IsClusterNotReady(err)`) the caller can check instead of string-matching. Databricks serverless clusters take 30–90s to warm, and retry logic up the stack needs a reliable signal. +Every delta is tracked as a PR queued for `apache/spark-connect-go`. When a delta lands upstream we drop it from the fork. Long-term goal is zero deltas. -Every delta is tracked as a PR against this fork and queued for upstream. When a delta lands in `apache/spark-connect-go`, we drop it from the fork. The long-term goal is zero deltas. +## Install + +```bash +go get github.com/datalake-go/spark-connect-go +``` -## What this repo is +Requires a Spark Connect server (Spark 3.4+). -A drop-in replacement for `github.com/apache/spark-connect-go` at the same package names. Swap the import path and existing code compiles; the session API, DataFrame surface, and protobuf stubs are unchanged. +## Quick start ```go import ( sparksql "github.com/datalake-go/spark-connect-go/spark/sql" - _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" // registers "spark" for database/sql ) -``` - -(The `sparksql` alias avoids collision with stdlib `database/sql` — the actual package name is `sql`.) - -You can use this fork without lake-orm or lakehouse. The composed runtime ([lakehouse](https://github.com/datalake-go/lakehouse)) and the ORM ([lake-orm](https://github.com/datalake-go/lake-orm)) depend on this fork's deltas; nothing in this fork requires them. -## Quick start - -```go session, err := sparksql.NewSessionBuilder(). Remote("sc://spark.internal:15002"). Build(ctx) @@ -43,28 +38,74 @@ df, _ := session.Sql(ctx, "SELECT id, email FROM users WHERE tier = 'gold'") _ = df.Show(ctx, 20, false) ``` -### Typed reads +The `sparksql` alias avoids collision with stdlib `database/sql` — the actual package name is `sql`. 
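+
+Dial options for auth or TLS hang off the same builder through `WithDialOptions` — a minimal sketch, assuming the method chains like the other builder calls; the TLS config and the `google.golang.org/grpc` / `crypto/tls` imports are illustrative, not prescribed by this repo:
+
+```go
+session, err := sparksql.NewSessionBuilder().
+    Remote("sc://spark.internal:15002").
+    WithDialOptions(grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{MinVersion: tls.VersionTLS12}))).
+    Build(ctx)
+```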
+ +### Using DataFrames + +The untyped `DataFrame` is the building block — same surface as upstream. Transformations (`Where`, `Limit`, `OrderBy`, `Select`, `Join`, `GroupBy`) compose lazily and execute on the Spark side; materialisers (`Show`, `Collect`, `First`, `Count`) round-trip and return `[]types.Row`. + +```go +df, _ := session.Sql(ctx, "SELECT id, email, created_at FROM users") + +filtered, _ := df.Where(ctx, "tier = 'gold'") +top, _ := filtered.OrderBy(ctx, "created_at DESC").Limit(ctx, 100) + +rows, _ := top.Collect(ctx) +for _, r := range rows { + // r is types.Row — positional access by index or by name +} +``` + +Use this when the result shape is dynamic, or as the composition surface that you eventually re-type at the edge. + +### Using Typed DataFrames + +`As[T](df) → *DataFrameOf[T]` is the typed surface. It binds a result shape to a struct, caches the reflected row plan once, and materialises into `[]T` / `*T` without re-validating on every call. ```go type User struct { - ID string `spark:"id"` - Email string `spark:"email"` + ID string `spark:"id"` + Email string `spark:"email"` + Created time.Time `spark:"created_at"` } -df, _ := session.Sql(ctx, "SELECT id, email FROM users WHERE tier = 'gold'") +df, _ := session.Sql(ctx, "SELECT id, email, created_at FROM users WHERE tier = 'gold'") +typed, _ := sparksql.As[User](df) + +users, _ := typed.Collect(ctx) +alice, err := typed.First(ctx) +if errors.Is(err, sparksql.ErrNotFound) { /* zero rows */ } +``` + +If you only need the result once, `Collect[T] / First[T] / Stream[T]` are top-level helpers that fold `As[T]` into the call: + +```go users, _ := sparksql.Collect[User](ctx, df) +``` -// or one row: -alice, err := sparksql.First[User](ctx, df) -if errors.Is(err, sparksql.ErrNotFound) { /* 404 */ } +Untagged fields map by snake_case'd field name, so plain Go structs work without tags. `spark:"-"` skips a field. `*DataFrameOf[T]` deliberately has no transformation methods — `Where` / `Limit` / `Select` / `Join` change the row shape and would make `T` lie. Compose on the untyped `DataFrame`, then re-type at the edge: -// or streaming: -for row, rerr := range sparksql.Stream[User](ctx, df) { - if rerr != nil { break } - // use row +```go +typed, _ := sparksql.As[User](df) +narrower, _ := typed.DataFrame().Select(ctx, "id", "email") // back to untyped +ids, _ := sparksql.Collect[struct{ ID string `spark:"id"` }](ctx, narrower) +``` + +### Streaming Results + +`Stream[T]` returns a Go 1.23 [`iter.Seq2[T, error]`](https://pkg.go.dev/iter#Seq2). One of the things Go gives us over the Python / Scala clients is a real pull-based iterator — rows decode one at a time as the gRPC stream resolves them, with constant memory regardless of result size. No need to buffer the whole result, no callback API: just `range`. + +```go +for row, err := range sparksql.Stream[User](ctx, df) { + if err != nil { break } + // use row — decoded from the next Arrow batch as it lands } ``` +Schema binding happens on the first row; if a later row's schema diverges from the first, the error surfaces through the iterator (no per-row panics). + +Use `Stream[T]` when result sets are large, when you want to short-circuit early without dragging the rest of the rows over the wire, or when you're piping into another `iter.Seq2` consumer. 
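+
+Short-circuiting is an ordinary early `break` (illustrative — `User` as defined above):
+
+```go
+var sample []User
+for row, err := range sparksql.Stream[User](ctx, df) {
+    if err != nil { break }
+    sample = append(sample, row)
+    if len(sample) == 100 {
+        break // pull-based: remaining rows are never decoded
+    }
+}
+```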
+ ### `database/sql` driver ```go @@ -73,28 +114,23 @@ import ( _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" ) -db, _ := sql.Open("spark", "sc://spark.internal:15002") +db, _ := sql.Open("spark", "sc://spark.internal:15002") rows, _ := db.QueryContext(ctx, "SELECT id FROM users WHERE tier = $1", "gold") ``` -`$N` placeholders are rendered client-side into Spark SQL literals with type-aware quoting (strings, numbers, bools, `[]byte`, `time.Time`). `?` placeholders aren't supported — most `database/sql`-adjacent codegen (sqlc, goose dialects, pgx patterns) emits `$N`, so the narrower grammar keeps the renderer simple. - -## Features (deltas over upstream) +`$N` placeholders render with type-aware quoting (strings, numbers, bools, `[]byte`, `time.Time`). `?` placeholders aren't supported — most `database/sql`-adjacent codegen (sqlc, goose dialects, pgx patterns) emits `$N`, so the narrower grammar keeps the renderer simple. -- **`database/sql` driver.** `sql.Open("spark", "sc://...")` — any `database/sql` consumer (goose, sqlc-generated code, ad-hoc scripts) plugs in. Registered under name `spark` in `spark/sql/driver`. -- **Typed helpers over `DataFrame`.** `Collect[T]`, `Stream[T]`, `First[T]`, `As[T]`, `DataFrameOf[T]`. Struct tag is `spark:""`. Decode rows into struct types at the materialization edge. -- **`SparkSessionBuilder.WithDialOptions`.** gRPC dial options exposed on the builder. Wire auth interceptors, TLS, observability without wrapping the builder. -- **`sparkerrors.IsClusterNotReady(err)`.** Classified error for cluster cold-start states. lake-orm uses it upstack for retry decisions. -- **Upstream parity.** Tracks `apache/master`; upstream merges flow through periodically with the fork's commits rebased on top. +### Cluster cold-start -## Install +```go +import "github.com/datalake-go/spark-connect-go/spark/sparkerrors" -```bash -go get github.com/datalake-go/spark-connect-go +df, err := session.Sql(ctx, query) +if sparkerrors.IsClusterNotReady(err) { + // retry with backoff — Databricks serverless usually warms in 30-90s +} ``` -Requires a Spark Connect server (Spark 3.4+). See [lake-k8s](https://github.com/datalake-go/lake-k8s) for a pre-baked Spark 4.0 + Iceberg + Delta image and a `docker compose up` laptop stack. - ## Building from source ```bash @@ -110,13 +146,6 @@ git submodule update --init --recursive make gen && make test ``` -## Related - -- [lake-orm](https://github.com/datalake-go/lake-orm) — ORM that uses this fork's typed helpers and `database/sql` driver. -- [lakehouse](https://github.com/datalake-go/lakehouse) — composed runtime that wires this session alongside the ORM, migrations, and dashboard. -- [lake-k8s](https://github.com/datalake-go/lake-k8s) — pre-baked Spark Connect server + laptop-mode compose stack. -- [apache/spark-connect-go](https://github.com/apache/spark-connect-go) — the upstream project this fork tracks. - ## Contributing -Feature work that could land upstream should be proposed against `apache/spark-connect-go` first. Fork-only changes (anything that wouldn't be accepted upstream) stay on this tree. See [CONTRIBUTING.md](CONTRIBUTING.md). +Feature work that could land upstream should be proposed against [`apache/spark-connect-go`](https://github.com/apache/spark-connect-go) first. Fork-only changes (anything that wouldn't be accepted upstream) stay on this tree. See [CONTRIBUTING.md](CONTRIBUTING.md).