Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 7fb3cdc

Browse files
make StringMethods generically typed
1 parent f055619 commit 7fb3cdc

3 files changed

Lines changed: 49 additions & 44 deletions

File tree

bigframes/operations/strings.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,18 @@
1515
from __future__ import annotations
1616

1717
import re
18-
from typing import Literal, Optional, TYPE_CHECKING, Union
18+
from typing import Generic, Literal, Optional, TypeVar, Union
1919

2020
import bigframes_vendored.constants as constants
2121
import bigframes_vendored.pandas.core.strings.accessor as vendorstr
2222

2323
from bigframes.core import log_adapter
24+
import bigframes.core.indexes.base as indices
2425
import bigframes.dataframe as df
2526
import bigframes.operations as ops
2627
from bigframes.operations._op_converters import convert_index, convert_slice
2728
import bigframes.operations.aggregations as agg_ops
28-
29-
if TYPE_CHECKING:
30-
import bigframes.core.indexes.base as indices
31-
import bigframes.series as series
29+
import bigframes.series as series
3230

3331
# Maps from python to re2
3432
REGEXP_FLAGS = {
@@ -37,15 +35,17 @@
3735
re.DOTALL: "s",
3836
}
3937

38+
T = TypeVar("T", series.Series, indices.Index)
39+
4040

4141
@log_adapter.class_logger
42-
class StringMethods(vendorstr.StringMethods):
42+
class StringMethods(vendorstr.StringMethods, Generic[T]):
4343
__doc__ = vendorstr.StringMethods.__doc__
4444

45-
def __init__(self, data: Union[series.Series, indices.Index]):
46-
self._data = data
45+
def __init__(self, data: T):
46+
self._data: T = data
4747

48-
def __getitem__(self, key: Union[int, slice]) -> series.Series:
48+
def __getitem__(self, key: Union[int, slice]) -> T:
4949
if isinstance(key, int):
5050
return self._data._apply_unary_op(convert_index(key))
5151
elif isinstance(key, slice):
@@ -58,18 +58,18 @@ def find(
5858
sub: str,
5959
start: Optional[int] = None,
6060
end: Optional[int] = None,
61-
) -> series.Series:
61+
) -> T:
6262
return self._data._apply_unary_op(
6363
ops.StrFindOp(substr=sub, start=start, end=end)
6464
)
6565

66-
def len(self) -> series.Series:
66+
def len(self) -> T:
6767
return self._data._apply_unary_op(ops.len_op)
6868

69-
def lower(self) -> series.Series:
69+
def lower(self) -> T:
7070
return self._data._apply_unary_op(ops.lower_op)
7171

72-
def reverse(self) -> series.Series:
72+
def reverse(self) -> T:
7373
"""Reverse strings in the Series.
7474
7575
**Examples:**
@@ -94,103 +94,103 @@ def slice(
9494
self,
9595
start: Optional[int] = None,
9696
stop: Optional[int] = None,
97-
) -> series.Series:
97+
) -> T:
9898
return self._data._apply_unary_op(ops.StrSliceOp(start=start, end=stop))
9999

100-
def strip(self, to_strip: Optional[str] = None) -> series.Series:
100+
def strip(self, to_strip: Optional[str] = None) -> T:
101101
return self._data._apply_unary_op(
102102
ops.StrStripOp(to_strip=" \n\t" if to_strip is None else to_strip)
103103
)
104104

105-
def upper(self) -> series.Series:
105+
def upper(self) -> T:
106106
return self._data._apply_unary_op(ops.upper_op)
107107

108-
def isnumeric(self) -> series.Series:
108+
def isnumeric(self) -> T:
109109
return self._data._apply_unary_op(ops.isnumeric_op)
110110

111111
def isalpha(
112112
self,
113-
) -> series.Series:
113+
) -> T:
114114
return self._data._apply_unary_op(ops.isalpha_op)
115115

116116
def isdigit(
117117
self,
118-
) -> series.Series:
118+
) -> T:
119119
return self._data._apply_unary_op(ops.isdigit_op)
120120

121121
def isdecimal(
122122
self,
123-
) -> series.Series:
123+
) -> T:
124124
return self._data._apply_unary_op(ops.isdecimal_op)
125125

126126
def isalnum(
127127
self,
128-
) -> series.Series:
128+
) -> T:
129129
return self._data._apply_unary_op(ops.isalnum_op)
130130

131131
def isspace(
132132
self,
133-
) -> series.Series:
133+
) -> T:
134134
return self._data._apply_unary_op(ops.isspace_op)
135135

136136
def islower(
137137
self,
138-
) -> series.Series:
138+
) -> T:
139139
return self._data._apply_unary_op(ops.islower_op)
140140

141141
def isupper(
142142
self,
143-
) -> series.Series:
143+
) -> T:
144144
return self._data._apply_unary_op(ops.isupper_op)
145145

146-
def rstrip(self, to_strip: Optional[str] = None) -> series.Series:
146+
def rstrip(self, to_strip: Optional[str] = None) -> T:
147147
return self._data._apply_unary_op(
148148
ops.StrRstripOp(to_strip=" \n\t" if to_strip is None else to_strip)
149149
)
150150

151-
def lstrip(self, to_strip: Optional[str] = None) -> series.Series:
151+
def lstrip(self, to_strip: Optional[str] = None) -> T:
152152
return self._data._apply_unary_op(
153153
ops.StrLstripOp(to_strip=" \n\t" if to_strip is None else to_strip)
154154
)
155155

156-
def repeat(self, repeats: int) -> series.Series:
156+
def repeat(self, repeats: int) -> T:
157157
return self._data._apply_unary_op(ops.StrRepeatOp(repeats=repeats))
158158

159-
def capitalize(self) -> series.Series:
159+
def capitalize(self) -> T:
160160
return self._data._apply_unary_op(ops.capitalize_op)
161161

162-
def match(self, pat, case=True, flags=0) -> series.Series:
162+
def match(self, pat, case=True, flags=0) -> T:
163163
# \A anchors start of entire string rather than start of any line in multiline mode
164164
adj_pat = rf"\A{pat}"
165165
return self.contains(pat=adj_pat, case=case, flags=flags)
166166

167-
def fullmatch(self, pat, case=True, flags=0) -> series.Series:
167+
def fullmatch(self, pat, case=True, flags=0) -> T:
168168
# \A anchors start of entire string rather than start of any line in multiline mode
169169
# \z likewise anchors to the end of the entire multiline string
170170
adj_pat = rf"\A{pat}\z"
171171
return self.contains(pat=adj_pat, case=case, flags=flags)
172172

173-
def get(self, i: int) -> series.Series:
173+
def get(self, i: int) -> T:
174174
return self._data._apply_unary_op(ops.StrGetOp(i=i))
175175

176-
def pad(self, width, side="left", fillchar=" ") -> series.Series:
176+
def pad(self, width, side="left", fillchar=" ") -> T:
177177
return self._data._apply_unary_op(
178178
ops.StrPadOp(length=width, fillchar=fillchar, side=side)
179179
)
180180

181-
def ljust(self, width, fillchar=" ") -> series.Series:
181+
def ljust(self, width, fillchar=" ") -> T:
182182
return self._data._apply_unary_op(
183183
ops.StrPadOp(length=width, fillchar=fillchar, side="right")
184184
)
185185

186-
def rjust(self, width, fillchar=" ") -> series.Series:
186+
def rjust(self, width, fillchar=" ") -> T:
187187
return self._data._apply_unary_op(
188188
ops.StrPadOp(length=width, fillchar=fillchar, side="left")
189189
)
190190

191191
def contains(
192192
self, pat, case: bool = True, flags: int = 0, *, regex: bool = True
193-
) -> series.Series:
193+
) -> T:
194194
if not case:
195195
return self.contains(pat=pat, flags=flags | re.IGNORECASE, regex=True)
196196
if regex:
@@ -235,7 +235,7 @@ def replace(
235235
case: Optional[bool] = None,
236236
flags: int = 0,
237237
regex: bool = False,
238-
) -> series.Series:
238+
) -> T:
239239
if isinstance(pat, re.Pattern):
240240
assert isinstance(pat.pattern, str)
241241
pat_str = pat.pattern
@@ -262,15 +262,15 @@ def replace(
262262
def startswith(
263263
self,
264264
pat: Union[str, tuple[str, ...]],
265-
) -> series.Series:
265+
) -> T:
266266
if not isinstance(pat, tuple):
267267
pat = (pat,)
268268
return self._data._apply_unary_op(ops.StartsWithOp(pat=pat))
269269

270270
def endswith(
271271
self,
272272
pat: Union[str, tuple[str, ...]],
273-
) -> series.Series:
273+
) -> T:
274274
if not isinstance(pat, tuple):
275275
pat = (pat,)
276276
return self._data._apply_unary_op(ops.EndsWithOp(pat=pat))
@@ -279,7 +279,7 @@ def split(
279279
self,
280280
pat: str = " ",
281281
regex: Union[bool, None] = None,
282-
) -> series.Series:
282+
) -> T:
283283
if regex is True or (regex is None and len(pat) > 1):
284284
raise NotImplementedError(
285285
"Regular expressions aren't currently supported. Please set "
@@ -297,18 +297,18 @@ def center(self, width: int, fillchar: str = " ") -> series.Series:
297297

298298
def cat(
299299
self,
300-
others: Union[str, series.Series],
300+
others: Union[str, indices.Index, series.Series],
301301
*,
302302
join: Literal["outer", "left"] = "left",
303-
) -> series.Series:
303+
) -> T:
304304
return self._data._apply_binary_op(others, ops.strconcat_op, alignment=join)
305305

306-
def join(self, sep: str) -> series.Series:
306+
def join(self, sep: str) -> T:
307307
return self._data._apply_unary_op(
308308
ops.ArrayReduceOp(aggregation=agg_ops.StringAggOp(sep=sep))
309309
)
310310

311-
def to_blob(self, connection: Optional[str] = None) -> series.Series:
311+
def to_blob(self, connection: Optional[str] = None) -> T:
312312
"""Create a BigFrames Blob series from a series of URIs.
313313
314314
.. note::

bigframes/series.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,13 @@
7474
import bigframes.operations.datetimes as dt
7575
import bigframes.operations.lists as lists
7676
import bigframes.operations.plotting as plotting
77-
import bigframes.operations.strings as strings
7877
import bigframes.operations.structs as structs
7978
import bigframes.session
8079

8180
if typing.TYPE_CHECKING:
8281
import bigframes.geopandas.geoseries
82+
import bigframes.operations.strings as strings
83+
8384

8485
LevelType = typing.Union[str, int]
8586
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
@@ -2649,6 +2650,8 @@ def _cached(self, *, force: bool = True, session_aware: bool = True) -> Series:
26492650
# confusing type checker by overriding str
26502651
@property
26512652
def str(self) -> strings.StringMethods:
2653+
import bigframes.operations.strings as strings
2654+
26522655
return strings.StringMethods(self)
26532656

26542657
@property

tests/system/small/test_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,8 @@ def test_index_str_accessor_unary(scalars_df_index, scalars_pandas_df_index):
698698

699699

700700
def test_index_str_accessor_binary(scalars_df_index, scalars_pandas_df_index):
701+
if pd.__version__.startswith("1."):
702+
pytest.skip("doesn't work in pandas 1.x.")
701703
bf_index = scalars_df_index.set_index("string_col").index
702704
pd_index = scalars_pandas_df_index.set_index("string_col").index
703705

0 commit comments

Comments
 (0)