Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 231 additions & 26 deletions dpsynth/text/bulk_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from collections.abc import Sequence
import dataclasses
import enum
import functools
import re
import time
from typing import Protocol

from absl import logging
Expand Down Expand Up @@ -63,10 +65,11 @@ def annotate(

Args:
texts: Input texts to annotate.
schema: Pydantic model class defining the features to extract. The model's
field names, ``Literal`` type annotations, and field descriptions guide
the LLM. This same class is used as the ``response_schema`` for
constrained decoding in supported backends.
schema: Pydantic model class defining the features to extract. Fields may
use ``Literal`` type annotations for constrained decoding (the model is
forced to pick from the allowed values) or plain types such as ``str``
for open-ended annotation where the model can produce any value. Field
names and descriptions guide the LLM.
system_prompt: System-level instructions for the LLM describing how to
annotate the texts.

Expand All @@ -90,7 +93,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
...


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class GenAIBackend:
"""TextGenerationBackend using the google.genai API.

Expand All @@ -100,19 +103,26 @@ class GenAIBackend:
Attributes:
model: Model name string (e.g., ``'gemini-2.5-flash-lite'``). Accepts any
``ModelName`` enum value or arbitrary string for unlisted models.
api_key: API key for authentication. If None, uses Application Default
Credentials (ADC).
api_key: API key for authentication.
poll_interval_seconds: How often to poll for batch job completion.
chunk_size: Number of texts per batch job.
max_concurrent_jobs: Maximum number of active parallel batch jobs.
"""

model: str = ModelName.GEMINI_2_5_FLASH_LITE
api_key: str | None = None
poll_interval_seconds: int = 60
chunk_size: int = 100
max_concurrent_jobs: int = 2

def _make_client(self) -> genai.Client:
"""Creates a genai client."""
kwargs: dict[str, object] = {
'http_options': types.HttpOptions(api_version='v1alpha'),
}
if self.api_key is not None:
# NOTE: the client is built lazily and cached on first access. Only api_key
# feeds into its construction, so changing api_key after the first access
# (e.g. by bypassing the frozen dataclass) is not reflected in the cached
# instance.
@functools.cached_property
def client(self) -> genai.Client:
"""Creates and caches a genai.Client."""
kwargs = {'http_options': types.HttpOptions(api_version='v1alpha')}
if self.api_key:
kwargs['api_key'] = self.api_key
return genai.Client(**kwargs)

Expand All @@ -122,31 +132,40 @@ def annotate(
schema: type[pydantic.BaseModel],
system_prompt: str,
) -> pd.DataFrame:
"""Extract structured features via constrained decoding.
"""Extract structured features via google.genai API (sequential).

Always passes the ``schema`` as the ``response_schema`` to
``generate_content``. When the schema contains ``Literal`` fields the
model is constrained to the allowed values; schemas with plain types
(e.g. ``str``) still benefit from the structural guidance but allow the
model to produce any value.

Args:
texts: Input texts to annotate.
schema: Pydantic model used as the ``response_schema`` for constrained
decoding.
decoding when it contains ``Literal`` fields. Schemas with plain types
(e.g. ``str``) trigger free-form JSON generation guided by the system
prompt and field descriptions.
system_prompt: System-level instructions for the LLM.

Returns:
DataFrame with exactly ``len(texts)`` rows. Failed rows have ``None``.
"""
client = self._make_client()
client = self.client
field_names = list(schema.model_fields.keys())
null_row = {f: None for f in field_names}
rows: list[dict[str, str | None]] = []
config = types.GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type='application/json',
response_schema=schema,
)
for i, text in enumerate(texts):
try:
response = client.models.generate_content(
model=self.model,
contents=text,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type='application/json',
response_schema=schema,
),
config=config,
)
if response.text:
cleaned = _strip_markdown_fences(response.text)
Expand All @@ -155,11 +174,195 @@ def annotate(
else:
logging.warning('Empty annotation response for text %d.', i)
rows.append(null_row)
except Exception: # pylint: disable=broad-except
logging.warning('Annotation failed for text %d.', i)
except Exception as e: # pylint: disable=broad-except
logging.warning(
'Annotation failed for text %d. Error: %s', i, e, exc_info=True
)
rows.append(null_row)
return pd.DataFrame(rows)

def batch_annotate(
    self,
    texts: Sequence[str],
    schema: type[pydantic.BaseModel],
    system_prompt: str,
    chunk_size: int | None = None,
    max_concurrent_jobs: int | None = None,
) -> pd.DataFrame:
  """Extract structured features via the GenAI Batch API.

  Splits ``texts`` into chunks of ``chunk_size``, submits each chunk as one
  batch job of inlined requests (keeping at most ``max_concurrent_jobs`` jobs
  in flight), polls every ``poll_interval_seconds`` until every job reports
  ``done``, and then parses the inlined responses in submission order.

  Args:
    texts: Input texts to annotate.
    schema: Pydantic model used as the ``response_schema``.
    system_prompt: System-level instructions for the LLM.
    chunk_size: Number of texts per batch job. Defaults to
      ``self.chunk_size`` when None.
    max_concurrent_jobs: Maximum number of active parallel batch jobs.
      Defaults to ``self.max_concurrent_jobs`` when None.

  Returns:
    DataFrame with exactly ``len(texts)`` rows. Rows whose individual
    response errored, was empty, or failed schema validation have ``None``
    for every schema field.

  Raises:
    ValueError: If ``chunk_size`` or ``max_concurrent_jobs`` is not positive,
      or if a job returns a different number of responses than it was given
      inputs.
    RuntimeError: If any batch job ends in a state other than
      ``JOB_STATE_SUCCEEDED``.
  """
  if chunk_size is None:
    chunk_size = self.chunk_size
  if max_concurrent_jobs is None:
    max_concurrent_jobs = self.max_concurrent_jobs

  # Validate before touching ``self.client`` so that bad arguments fail fast
  # without constructing a client first.
  if chunk_size <= 0:
    raise ValueError('chunk_size must be positive.')
  if max_concurrent_jobs <= 0:
    raise ValueError('max_concurrent_jobs must be positive.')

  client = self.client
  # The same config is shared by every inlined request.
  config = types.GenerateContentConfig(
      system_instruction=system_prompt,
      response_mime_type='application/json',
      response_schema=schema,
  )

  offsets = range(0, len(texts), chunk_size)
  num_chunks = len(offsets)
  jobs = []  # Every submitted job, in chunk order (drives result ordering).
  active_jobs = []  # Subset of ``jobs`` that has not reported ``done`` yet.
  chunk_idx = 0

  logging.info(
      'Batch annotate: starting processing of %d chunks with concurrency'
      ' limit %d...',
      num_chunks,
      max_concurrent_jobs,
  )

  while chunk_idx < num_chunks or active_jobs:
    # Submit new jobs up to the concurrency limit.
    while len(active_jobs) < max_concurrent_jobs and chunk_idx < num_chunks:
      offset = offsets[chunk_idx]
      chunk_texts = texts[offset : offset + chunk_size]
      logging.info(
          'Batch annotate: submitting inline chunk %d/%d (size=%d)...',
          chunk_idx + 1,
          num_chunks,
          len(chunk_texts),
      )
      batch_job = client.batches.create(
          model=self.model,
          src=[
              types.InlinedRequest(contents=text, config=config)
              for text in chunk_texts
          ],
      )
      logging.info(
          'Batch annotate: job %s created for chunk %d/%d',
          batch_job.name,
          chunk_idx + 1,
          num_chunks,
      )
      job_info = {
          'chunk_idx': chunk_idx,
          'chunk_texts': chunk_texts,
          'job_name': batch_job.name,
          'job': batch_job,
      }
      jobs.append(job_info)
      active_jobs.append(job_info)
      chunk_idx += 1

    # At this point ``active_jobs`` is never empty: either the loop was
    # entered with active jobs, or at least one job was just submitted
    # (``max_concurrent_jobs`` is >= 1).
    logging.info(
        'Batch annotate: %d active jobs. Polling in %ds...',
        len(active_jobs),
        self.poll_interval_seconds,
    )
    time.sleep(self.poll_interval_seconds)

    still_active = []
    for job_info in active_jobs:
      try:
        job_info['job'] = client.batches.get(name=job_info['job_name'])
      except Exception as e:  # pylint: disable=broad-except
        # A failed poll keeps the previous job snapshot, so the job is simply
        # polled again on the next cycle.
        # NOTE(review): a job whose polls fail permanently is never marked
        # done and keeps this loop running; consider bounding the retries.
        logging.warning(
            'Failed to poll job %s: %s',
            job_info['job_name'],
            e,
            exc_info=True,
        )

      if job_info['job'].done:
        logging.info(
            'Batch annotate: job %s completed with state=%s',
            job_info['job_name'],
            job_info['job'].state,
        )
      else:
        still_active.append(job_info)
    active_jobs = still_active

  logging.info('Batch annotate: all jobs completed. Parsing responses...')

  # Parse in submission order so rows line up with ``texts``.
  all_rows = []
  for job_info in jobs:
    all_rows.extend(self._parse_batch_job_rows(job_info, schema))

  return pd.DataFrame(all_rows)

@staticmethod
def _parse_batch_job_rows(
    job_info: dict,
    schema: type[pydantic.BaseModel],
) -> list[dict]:
  """Converts one finished batch job into one row dict per input text.

  Args:
    job_info: Bookkeeping dict created by ``batch_annotate`` with the keys
      ``'job'``, ``'job_name'`` and ``'chunk_texts'``.
    schema: Pydantic model each response text is validated against.

  Returns:
    One dict per inlined response; failed responses map every schema field
    to ``None``.

  Raises:
    RuntimeError: If the job did not end in ``JOB_STATE_SUCCEEDED``.
    ValueError: If the number of responses differs from the number of inputs.
  """
  batch_job = job_info['job']
  chunk_texts = job_info['chunk_texts']
  job_name = job_info['job_name']
  null_row = {f: None for f in schema.model_fields.keys()}

  if batch_job.state != types.JobState.JOB_STATE_SUCCEEDED:
    error_msg = f'Batch job {job_name} ended with state={batch_job.state}.'
    if batch_job.error:
      error_msg += f' Error: {batch_job.error}'
    raise RuntimeError(error_msg)

  # Both ``dest`` and ``dest.inlined_responses`` may be missing/None.
  inlined_responses = (
      batch_job.dest.inlined_responses if batch_job.dest else []
  ) or []

  chunk_rows = []
  for i, inlined_resp in enumerate(inlined_responses):
    try:
      if inlined_resp.error:
        logging.warning(
            'Batch result %d in job %s had error: %s',
            i,
            job_name,
            inlined_resp.error,
        )
        chunk_rows.append(null_row)
        continue

      response = inlined_resp.response
      if response and response.text:
        cleaned = _strip_markdown_fences(response.text)
        parsed = schema.model_validate_json(cleaned)
        chunk_rows.append(parsed.model_dump())
      else:
        logging.warning(
            'Empty batch response in job %s for text %d.', job_name, i
        )
        chunk_rows.append(null_row)
    except Exception as e:  # pylint: disable=broad-except
      # Best effort: one bad response must not discard the rest of the job.
      logging.warning(
          'Failed to parse batch result %d in job %s: %s',
          i,
          job_name,
          e,
          exc_info=True,
      )
      chunk_rows.append(null_row)

  # Ensure index alignment for this chunk: a short or long response list
  # would silently shift every following row against its input text.
  if len(chunk_rows) != len(chunk_texts):
    raise ValueError(
        f'Batch annotate: job {job_name} got {len(chunk_rows)} results for'
        f' {len(chunk_texts)} inputs.'
    )

  return chunk_rows

def generate(self, prompts: Sequence[str]) -> list[str]:
"""Generate free-form text via google.genai.

Expand All @@ -169,7 +372,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
Returns:
List of exactly ``len(prompts)`` strings. Empty string on failure.
"""
client = self._make_client()
client = self.client
results: list[str] = []
for i, prompt in enumerate(prompts):
try:
Expand All @@ -178,8 +381,10 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
contents=prompt,
)
results.append(response.text or '')
except Exception: # pylint: disable=broad-except
logging.warning('Generation failed for prompt %d.', i)
except Exception as e: # pylint: disable=broad-except
logging.warning(
'Generation failed for prompt %d. Error: %s', i, e, exc_info=True
)
results.append('')
return results

Expand Down
Loading
Loading