Skip to content

Commit a2ba725

Browse files
committed
fix: rename
1 parent f2a039b commit a2ba725

8 files changed

Lines changed: 58 additions & 58 deletions

File tree

docs/generator-api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def perform_init(self, init_input):
8585
# Primary worker: populate work queue
8686
work_items = [chunk.serialize() for chunk in self.create_chunks()]
8787
self.enqueue_work(work_items)
88-
return GlobalInitResult(self.init_identifier)
88+
return InitResult(self.init_identifier)
8989

9090
def process(self):
9191
# All workers: pull from queue until empty

docs/lifecycle.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Understanding when lifecycle methods are called is critical for resource managem
1010
│ ↓ │
1111
│ output_schema (property accessed) │
1212
│ ↓ │
13-
│ perform_init(init_batch) → GlobalInitResult
13+
│ perform_init(init_batch) → InitResult
1414
│ ↓ │
1515
│ setup() ← Acquire resources here (DB connections, files) │
1616
│ ↓ │

docs/protocol.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Client Worker
1414
│◀──── OutputSpec (output schema) ──────│
1515
│ │
1616
│──── GlobalStateInitInput ────────────▶│
17-
│◀──── GlobalInitResult ────────────────│ perform_init()
17+
│◀──── InitResult ──────────────────────│ perform_init()
1818
│ │
1919
│◀──── Output Batch 1 ──────────────────│ process() yields
2020
│◀──── Output Batch 2 ──────────────────│
@@ -35,7 +35,7 @@ Client Worker
3535
│◀──── OutputSpec (output schema) ──────│
3636
│ │
3737
│──── GlobalStateInitInput ────────────▶│
38-
│◀──── GlobalInitResult ────────────────│ perform_init()
38+
│◀──── InitResult ──────────────────────│ perform_init()
3939
│ │
4040
│──── Input Batch 1 ───────────────────▶│
4141
│◀──── Output Batch 1 (NEED_MORE_INPUT)─│ transform() / process()

tests/test_protocol_classes.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Unit tests for VGI protocol classes.
22
3-
Tests cover Invocation, Arguments, GlobalInitResult, and table_function classes.
3+
Tests cover Invocation, Arguments, InitResult, and table_function classes.
44
"""
55

66
from __future__ import annotations
@@ -13,7 +13,7 @@
1313
Arg,
1414
Arguments,
1515
ArgumentValidationError,
16-
GlobalInitResult,
16+
InitResult,
1717
Invocation,
1818
InvocationType,
1919
)
@@ -310,20 +310,20 @@ def test_with_global_init_identifier(self) -> None:
310310
global_init_identifier=None,
311311
)
312312

313-
init_result = GlobalInitResult(global_init_identifier=b"init-data")
313+
init_result = InitResult(global_init_identifier=b"init-data")
314314
updated = original.with_global_init_identifier(init_result)
315315

316316
assert updated.function_name == original.function_name
317317
assert updated.global_init_identifier == init_result
318318
assert original.global_init_identifier is None # Original unchanged
319319

320320

321-
class TestGlobalInitResult:
322-
"""Tests for GlobalInitResult serialization."""
321+
class TestInitResult:
322+
"""Tests for InitResult serialization."""
323323

324324
def test_basic_round_trip(self) -> None:
325-
"""GlobalInitResult should serialize and deserialize correctly."""
326-
original = GlobalInitResult(global_init_identifier=b"test-init-id")
325+
"""InitResult should serialize and deserialize correctly."""
326+
original = InitResult(global_init_identifier=b"test-init-id")
327327

328328
serialized = original.serialize()
329329
assert isinstance(serialized, bytes)
@@ -332,20 +332,20 @@ def test_basic_round_trip(self) -> None:
332332

333333
reader = ipc.open_stream(serialized)
334334
batch = reader.read_next_batch()
335-
deserialized = GlobalInitResult.deserialize(batch)
335+
deserialized = InitResult.deserialize(batch)
336336

337337
assert deserialized.global_init_identifier == b"test-init-id"
338338

339339
def test_null_identifier(self) -> None:
340-
"""GlobalInitResult with null identifier should round-trip correctly."""
341-
original = GlobalInitResult(global_init_identifier=None)
340+
"""InitResult with null identifier should round-trip correctly."""
341+
original = InitResult(global_init_identifier=None)
342342

343343
serialized = original.serialize()
344344
from pyarrow import ipc
345345

346346
reader = ipc.open_stream(serialized)
347347
batch = reader.read_next_batch()
348-
deserialized = GlobalInitResult.deserialize(batch)
348+
deserialized = InitResult.deserialize(batch)
349349

350350
assert deserialized.global_init_identifier is None
351351

@@ -357,15 +357,15 @@ def test_has_identifier_true(self) -> None:
357357
[pa.field("global_init_identifier", pa.binary(), nullable=True)]
358358
),
359359
)
360-
assert GlobalInitResult.has_identifier(batch) is True
360+
assert InitResult.has_identifier(batch) is True
361361

362362
def test_has_identifier_false(self) -> None:
363363
"""has_identifier should return False when field doesn't exist."""
364364
batch = pa.RecordBatch.from_pylist(
365365
[{"other_field": "value"}],
366366
schema=make_schema([pa.field("other_field", pa.string())]),
367367
)
368-
assert GlobalInitResult.has_identifier(batch) is False
368+
assert InitResult.has_identifier(batch) is False
369369

370370
def test_deserialize_empty_batch_raises(self) -> None:
371371
"""Deserializing empty batch should raise ValueError."""
@@ -377,7 +377,7 @@ def test_deserialize_empty_batch_raises(self) -> None:
377377
)
378378

379379
with pytest.raises(ValueError, match="empty RecordBatch"):
380-
GlobalInitResult.deserialize(empty_batch)
380+
InitResult.deserialize(empty_batch)
381381

382382
def test_deserialize_multi_row_batch_raises(self) -> None:
383383
"""Deserializing multi-row batch should raise ValueError."""
@@ -392,11 +392,11 @@ def test_deserialize_multi_row_batch_raises(self) -> None:
392392
)
393393

394394
with pytest.raises(ValueError, match="single-row"):
395-
GlobalInitResult.deserialize(multi_row_batch)
395+
InitResult.deserialize(multi_row_batch)
396396

397397
def test_schema(self) -> None:
398398
"""schema() should return correct Arrow schema."""
399-
result = GlobalInitResult(global_init_identifier=b"test")
399+
result = InitResult(global_init_identifier=b"test")
400400
schema = result.schema()
401401

402402
assert len(schema) == 1

vgi/client/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from vgi.function import (
7373
Arguments,
7474
FunctionInitInput,
75-
GlobalInitResult,
75+
InitResult,
7676
Invocation,
7777
InvocationType,
7878
)
@@ -375,7 +375,7 @@ def _initialize_stream_common(
375375
function_type: InvocationType,
376376
bind_result_callback: Callable[[pa.RecordBatch], None] | None,
377377
projection_ids: list[int] | None,
378-
) -> tuple[_BindResult, GlobalInitResult, Invocation]:
378+
) -> tuple[_BindResult, InitResult, Invocation]:
379379
"""Perform the common initialization handshake with the primary worker.
380380
381381
Executes the VGI protocol initialization sequence:
@@ -385,7 +385,7 @@ def _initialize_stream_common(
385385
4. Validates protocol version compatibility
386386
5. Applies CPU/max_workers limits to max_processes
387387
6. Sends init data (FunctionInitInput or TableFunctionInitInput)
388-
7. Reads GlobalInitResult (shared state identifier for parallel workers)
388+
7. Reads InitResult (shared state identifier for parallel workers)
389389
8. Creates an Invocation with global_init_identifier for additional workers
390390
391391
Args:
@@ -406,7 +406,7 @@ def _initialize_stream_common(
406406
Returns:
407407
A tuple of (bind_result, global_init_result, request_with_init):
408408
- bind_result: Parsed _BindResult with output_schema, max_processes
409-
- global_init_result: GlobalInitResult containing shared state ID
409+
- global_init_result: InitResult containing shared state ID
410410
- request_with_init: Invocation with global_init_identifier set,
411411
suitable for initializing additional parallel workers
412412
@@ -496,7 +496,7 @@ def _initialize_stream_common(
496496
except IPCError as e:
497497
raise ClientError(str(e)) from e
498498

499-
global_init_result = GlobalInitResult.deserialize(init_result_batch)
499+
global_init_result = InitResult.deserialize(init_result_batch)
500500
log.debug(
501501
"init_result_received",
502502
has_identifier=global_init_result.global_init_identifier is not None,
@@ -835,7 +835,7 @@ def _initialize_additional_worker(
835835
and stdout_buffered handles. The data_writer field will be set
836836
if input_schema is not None.
837837
request_with_init: Invocation containing the global_init_identifier
838-
from the primary worker's GlobalInitResult. This ensures all
838+
from the primary worker's InitResult. This ensures all
839839
workers share the same initialization state.
840840
input_schema: Schema for the input data stream. If provided, a
841841
RecordBatchStreamWriter is created and assigned to worker.data_writer.

vgi/examples/table.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pyarrow as pa
2020

2121
from vgi.arguments import Arg
22-
from vgi.function import GlobalInitResult
22+
from vgi.function import InitResult
2323
from vgi.log import Level, Message
2424
from vgi.metadata import FunctionExample
2525
from vgi.table_function import (
@@ -473,7 +473,7 @@ def cardinality(self) -> CardinalityInfo:
473473
"""
474474
return CardinalityInfo(estimate=self.count, max=self.count)
475475

476-
def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult:
476+
def perform_init(self, init_input: pa.RecordBatch) -> InitResult:
477477
"""Populate the work queue with range chunks."""
478478
# Parse init data and store in init_storage
479479
self.init_data = TableFunctionInitInput.deserialize(init_input)
@@ -489,7 +489,7 @@ def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult:
489489
if work_items:
490490
self.enqueue_work(work_items)
491491

492-
return GlobalInitResult(self.init_identifier)
492+
return InitResult(self.init_identifier)
493493

494494
def process(self) -> OutputGenerator:
495495
"""Generate values by pulling chunks from the work queue."""

vgi/function.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"Arguments",
5353
"Function",
5454
"InvocationType",
55-
"GlobalInitResult",
55+
"InitResult",
5656
"Level",
5757
"Message",
5858
"OutputSpec",
@@ -253,11 +253,11 @@ def deserialize(cls, data: bytes) -> Self:
253253

254254

255255
@dataclass(frozen=True, slots=True)
256-
class GlobalInitResult:
256+
class InitResult:
257257
"""Result from the global initialization phase of a function.
258258
259259
When a function supports parallel execution (max_processes > 1), the first
260-
worker runs perform_init() which returns a GlobalInitResult. This result
260+
worker runs perform_init() which returns an InitResult. This result
261261
contains an identifier that is passed to all subsequent parallel workers
262262
via retrieve_init(), allowing them to share state or coordinate processing.
263263
@@ -285,7 +285,7 @@ def has_identifier(cls, data: pa.RecordBatch) -> bool:
285285
return cls._IDENTIFIER_FIELD_NAME in data.schema.names
286286

287287
def schema(self) -> pa.Schema:
288-
"""Return Arrow schema used when serializing GlobalInitResult.
288+
"""Return Arrow schema used when serializing InitResult.
289289
290290
Returns:
291291
Arrow schema with fields for each serialized attribute.
@@ -298,10 +298,10 @@ def schema(self) -> pa.Schema:
298298
)
299299

300300
def serialize(self) -> bytes:
301-
"""Serialize GlobalInitResult to an Arrow RecordBatch.
301+
"""Serialize InitResult to an Arrow RecordBatch.
302302
303303
Returns:
304-
RecordBatch containing serialized GlobalInitResult fields.
304+
RecordBatch containing serialized InitResult fields.
305305
306306
"""
307307
batch = pa.RecordBatch.from_pylist(
@@ -315,20 +315,20 @@ def serialize(self) -> bytes:
315315
return vgi.ipc_utils.serialize_record_batch(batch)
316316

317317
@classmethod
318-
def deserialize(cls, data: pa.RecordBatch) -> "GlobalInitResult":
319-
"""Deserialize GlobalInitResult from an Arrow RecordBatch.
318+
def deserialize(cls, data: pa.RecordBatch) -> "InitResult":
319+
"""Deserialize InitResult from an Arrow RecordBatch.
320320
321321
Args:
322-
data: RecordBatch containing serialized GlobalInitResult fields.
322+
data: RecordBatch containing serialized InitResult fields.
323323
324324
Returns:
325-
Deserialized GlobalInitResult instance.
325+
Deserialized InitResult instance.
326326
327327
"""
328328
first_row = vgi.ipc_utils.validate_single_row_batch(
329-
data, "GlobalInitResult", required_fields=[cls._IDENTIFIER_FIELD_NAME]
329+
data, "InitResult", required_fields=[cls._IDENTIFIER_FIELD_NAME]
330330
)
331-
return GlobalInitResult(
331+
return InitResult(
332332
global_init_identifier=first_row[cls._IDENTIFIER_FIELD_NAME],
333333
)
334334

@@ -383,13 +383,13 @@ class Invocation:
383383
# The unique identifier for the call, typically this may be a uuid.
384384
invocation_id: bytes | None
385385

386-
global_init_identifier: GlobalInitResult | None = None
386+
global_init_identifier: InitResult | None = None
387387
arguments: Arguments = Arguments()
388388
client_features: frozenset[str] = frozenset()
389389
attach_id: bytes | None = None
390390

391391
def with_global_init_identifier(
392-
self, global_init_identifier: GlobalInitResult
392+
self, global_init_identifier: InitResult
393393
) -> "Invocation":
394394
"""Return a new Invocation with the given global_init_identifier."""
395395
return replace(self, global_init_identifier=global_init_identifier)
@@ -423,7 +423,7 @@ def serialize(self) -> bytes:
423423
"function_type": self.function_type.value,
424424
"invocation_id": self.invocation_id,
425425
"correlation_id": self.correlation_id,
426-
GlobalInitResult._IDENTIFIER_FIELD_NAME: (
426+
InitResult._IDENTIFIER_FIELD_NAME: (
427427
self.global_init_identifier.global_init_identifier
428428
if self.global_init_identifier
429429
else None
@@ -441,7 +441,7 @@ def serialize(self) -> bytes:
441441
pa.field("invocation_id", pa.binary(), nullable=True),
442442
pa.field("correlation_id", pa.string(), nullable=False),
443443
pa.field(
444-
GlobalInitResult._IDENTIFIER_FIELD_NAME,
444+
InitResult._IDENTIFIER_FIELD_NAME,
445445
pa.binary(),
446446
nullable=True,
447447
),
@@ -486,13 +486,13 @@ def deserialize(data: pa.RecordBatch) -> "Invocation":
486486
# Parse function_type from string value
487487
function_type = InvocationType(first_row["function_type"])
488488

489-
# Parse global_init_identifier - only create GlobalInitResult if field exists
489+
# Parse global_init_identifier - only create InitResult if field exists
490490
# and has a non-None value
491491
global_init_identifier = None
492-
if GlobalInitResult._IDENTIFIER_FIELD_NAME in data.schema.names:
493-
identifier_value = first_row[GlobalInitResult._IDENTIFIER_FIELD_NAME]
492+
if InitResult._IDENTIFIER_FIELD_NAME in data.schema.names:
493+
identifier_value = first_row[InitResult._IDENTIFIER_FIELD_NAME]
494494
if identifier_value is not None:
495-
global_init_identifier = GlobalInitResult(identifier_value)
495+
global_init_identifier = InitResult(identifier_value)
496496

497497
# Parse client_features - default to empty set for backward compatibility
498498
client_features: frozenset[str] = frozenset()
@@ -1211,7 +1211,7 @@ def enqueue_work(self, work_items: list[bytes]) -> int:
12111211
ValueError: If init_identifier has not been set.
12121212
12131213
Example:
1214-
def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult:
1214+
def perform_init(self, init_input: pa.RecordBatch) -> InitResult:
12151215
result = super().perform_init(init_input)
12161216
# Create work items (e.g., ranges to process)
12171217
work_items = [struct.pack(">QQ", start, end) for start, end in ranges]
@@ -1314,20 +1314,20 @@ def _validate_input_schema(self, batch: pa.RecordBatch) -> None:
13141314
context=f"input to {type(self).__name__}",
13151315
)
13161316

1317-
def perform_init(self, init_input: pa.RecordBatch) -> GlobalInitResult:
1317+
def perform_init(self, init_input: pa.RecordBatch) -> InitResult:
13181318
"""Perform a new init call and store it in the storage."""
13191319
self.init_data = self.InitDataCls.deserialize(init_input)
13201320
assert self.init_data is not None
13211321
self.init_identifier = self.init_storage.create(self.init_data.serialize())
1322-
return GlobalInitResult(self.init_identifier)
1322+
return InitResult(self.init_identifier)
13231323

1324-
def retrieve_init(self, init_input: GlobalInitResult) -> None:
1324+
def retrieve_init(self, init_input: InitResult) -> None:
13251325
"""Retrieve and store init data from the storage."""
13261326
if init_input.global_init_identifier is None:
13271327
raise ValueError(
13281328
"global_init_identifier is required but was None. "
1329-
"This indicates the GlobalInitResult was not properly initialized. "
1330-
"Ensure perform_init() returns a GlobalInitResult with a valid "
1329+
"This indicates the InitResult was not properly initialized. "
1330+
"Ensure perform_init() returns an InitResult with a valid "
13311331
"identifier."
13321332
)
13331333
self.init_identifier = init_input.global_init_identifier

0 commit comments

Comments
 (0)