Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit b2446ab

Browse files
committed
fix code
1 parent 1f00e3c commit b2446ab

1 file changed

Lines changed: 14 additions & 28 deletions

File tree

  • bigframes/bigquery/_operations

bigframes/bigquery/_operations/ai.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
import functools
1817
import json
19-
from typing import Any, List, Literal, Mapping, Sequence, Tuple
18+
from typing import Any, List, Literal, Mapping, Tuple
2019

2120
from bigframes import clients, dtypes, series
2221
from bigframes.operations import ai_ops
@@ -101,13 +100,8 @@ def ai_generate_bool(
101100
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
102101
"""
103102

104-
if request_type not in ("dedicated", "shared", "unspecified"):
105-
raise ValueError(f"Unsupported request type: {request_type}")
106-
107103
prompt_context, series_list = _separate_context_and_series(prompt)
108-
109-
if not series_list:
110-
raise ValueError("Please provide at least one Series in the prompt")
104+
assert len(series_list) > 0
111105

112106
operator = ai_ops.AIGenerateBool(
113107
prompt_context=tuple(prompt_context),
@@ -120,9 +114,8 @@ def ai_generate_bool(
120114
return series_list[0]._apply_nary_op(operator, series_list[1:])
121115

122116

123-
@functools.singledispatch
124117
def _separate_context_and_series(
125-
prompt: Any,
118+
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
126119
) -> Tuple[List[str | None], List[series.Series]]:
127120
"""
128121
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
@@ -131,24 +124,14 @@ def _separate_context_and_series(
131124
Input: ("str1", series1, "str2", "str3", series2)
132125
Output: ["str1", None, "str2", "str3", None], [series1, series2]
133126
"""
134-
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
135-
127+
if not isinstance(prompt, (list, tuple, series.Series)):
128+
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
136129

137-
@_separate_context_and_series.register
138-
def _(
139-
prompt: series.Series,
140-
) -> Tuple[List[str | None], List[series.Series]]:
141-
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
142-
# Multi-model support
143-
return [None], [prompt.blob.read_url()]
144-
return [None], [prompt]
145-
146-
147-
@_separate_context_and_series.register(list)
148-
@_separate_context_and_series.register(tuple)
149-
def _(
150-
prompt: Sequence[str | series.Series],
151-
) -> Tuple[List[str | None], List[series.Series]]:
130+
if isinstance(prompt, series.Series):
131+
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
132+
# Multi-model support
133+
return [None], [prompt.blob.read_url()]
134+
return [None], [prompt]
152135

153136
prompt_context: List[str | None] = []
154137
series_list: List[series.Series] = []
@@ -166,7 +149,10 @@ def _(
166149
series_list.append(item)
167150

168151
else:
169-
raise ValueError(f"Unsupported type in prompt: {type(item)}")
152+
raise TypeError(f"Unsupported type in prompt: {type(item)}")
153+
154+
if not series_list:
155+
raise ValueError("Please provide at least one Series in the prompt")
170156

171157
return prompt_context, series_list
172158

0 commit comments

Comments
 (0)