1414
1515from __future__ import annotations
1616
17- import functools
1817import json
19- from typing import Any , List , Literal , Mapping , Sequence , Tuple
18+ from typing import Any , List , Literal , Mapping , Tuple
2019
2120from bigframes import clients , dtypes , series
2221from bigframes .operations import ai_ops
@@ -101,13 +100,8 @@ def ai_generate_bool(
101100 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
102101 """
103102
104- if request_type not in ("dedicated" , "shared" , "unspecified" ):
105- raise ValueError (f"Unsupported request type: { request_type } " )
106-
107103 prompt_context , series_list = _separate_context_and_series (prompt )
108-
109- if not series_list :
110- raise ValueError ("Please provide at least one Series in the prompt" )
104+ assert len (series_list ) > 0
111105
112106 operator = ai_ops .AIGenerateBool (
113107 prompt_context = tuple (prompt_context ),
@@ -120,9 +114,8 @@ def ai_generate_bool(
120114 return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
121115
122116
123- @functools .singledispatch
124117def _separate_context_and_series (
125- prompt : Any ,
118+ prompt : series . Series | List [ str | series . Series ] | Tuple [ str | series . Series , ...] ,
126119) -> Tuple [List [str | None ], List [series .Series ]]:
127120 """
128121 Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
@@ -131,24 +124,14 @@ def _separate_context_and_series(
131124 Input: ("str1", series1, "str2", "str3", series2)
132125 Output: ["str1", None, "str2", "str3", None], [series1, series2]
133126 """
134- raise ValueError ( f"Unsupported prompt type: { type ( prompt ) } " )
135-
127+ if not isinstance ( prompt , ( list , tuple , series . Series )):
128+ raise ValueError ( f"Unsupported prompt type: { type ( prompt ) } " )
136129
137- @_separate_context_and_series .register
138- def _ (
139- prompt : series .Series ,
140- ) -> Tuple [List [str | None ], List [series .Series ]]:
141- if prompt .dtype == dtypes .OBJ_REF_DTYPE :
142- # Multi-model support
143- return [None ], [prompt .blob .read_url ()]
144- return [None ], [prompt ]
145-
146-
147- @_separate_context_and_series .register (list )
148- @_separate_context_and_series .register (tuple )
149- def _ (
150- prompt : Sequence [str | series .Series ],
151- ) -> Tuple [List [str | None ], List [series .Series ]]:
130+ if isinstance (prompt , series .Series ):
131+ if prompt .dtype == dtypes .OBJ_REF_DTYPE :
132+ # Multi-model support
133+ return [None ], [prompt .blob .read_url ()]
134+ return [None ], [prompt ]
152135
153136 prompt_context : List [str | None ] = []
154137 series_list : List [series .Series ] = []
@@ -166,7 +149,10 @@ def _(
166149 series_list .append (item )
167150
168151 else :
169- raise ValueError (f"Unsupported type in prompt: { type (item )} " )
152+ raise TypeError (f"Unsupported type in prompt: { type (item )} " )
153+
154+ if not series_list :
155+ raise ValueError ("Please provide at least one Series in the prompt" )
170156
171157 return prompt_context , series_list
172158
0 commit comments