Skip to content

Commit dbd2f63

Browse files
committed
fixup! feat: make temporal partition_mapper timezone aware
1 parent 9e201c0 commit dbd2f63

1 file changed

Lines changed: 17 additions & 12 deletions

File tree

airflow-core/src/airflow/partition_mappers/temporal.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818

1919
from abc import ABC, abstractmethod
2020
from datetime import datetime, timedelta
21-
from typing import Any
22-
23-
from pendulum import FixedTimezone, Timezone
21+
from typing import TYPE_CHECKING, Any
2422

2523
from airflow._shared.timezones.timezone import parse_timezone
2624
from airflow.partition_mappers.base import PartitionMapper
2725

26+
if TYPE_CHECKING:
27+
from pendulum import FixedTimezone, Timezone
28+
2829

2930
class _BaseTemporalMapper(PartitionMapper, ABC):
3031
"""Base class for Temporal Partition Mappers."""
@@ -33,10 +34,10 @@ class _BaseTemporalMapper(PartitionMapper, ABC):
3334

3435
def __init__(
3536
self,
36-
*
37-
input_format: str = "%Y-%m-%dT%H:%M:%S",
38-
output_format: str | None = None,
37+
*,
3938
timezone: str | Timezone | FixedTimezone,
39+
input_format: str = "%Y-%m-%dT%H:%M:%S%z",
40+
output_format: str | None = None,
4041
):
4142
self.input_format = input_format
4243
self.output_format = output_format or self.default_output_format
@@ -58,14 +59,18 @@ def format(self, dt: datetime) -> str:
5859
return dt.strftime(self.output_format)
5960

6061
def serialize(self) -> dict[str, Any]:
62+
from airflow.serialization.encoders import encode_timezone
63+
6164
return {
65+
"timezone": encode_timezone(self._timezone),
6266
"input_format": self.input_format,
6367
"output_format": self.output_format,
6468
}
6569

6670
@classmethod
6771
def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
6872
return cls(
73+
timezone=parse_timezone(data["timezone"]),
6974
input_format=data["input_format"],
7075
output_format=data["output_format"],
7176
)
@@ -74,7 +79,7 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
7479
class HourlyMapper(_BaseTemporalMapper):
7580
"""Map a time-based partition key to hour."""
7681

77-
default_output_format = "%Y-%m-%dT%H"
82+
default_output_format = "%Y-%m-%dT%H%z"
7883

7984
def normalize(self, dt: datetime) -> datetime:
8085
return dt.replace(minute=0, second=0, microsecond=0)
@@ -83,7 +88,7 @@ def normalize(self, dt: datetime) -> datetime:
8388
class DailyMapper(_BaseTemporalMapper):
8489
"""Map a time-based partition key to day."""
8590

86-
default_output_format = "%Y-%m-%d"
91+
default_output_format = "%Y-%m-%d%z"
8792

8893
def normalize(self, dt: datetime) -> datetime:
8994
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
@@ -92,7 +97,7 @@ def normalize(self, dt: datetime) -> datetime:
9297
class WeeklyMapper(_BaseTemporalMapper):
9398
"""Map a time-based partition key to week."""
9499

95-
default_output_format = "%Y-%m-%d (W%V)"
100+
default_output_format = "%Y-%m-%d (W%V)%z"
96101

97102
def normalize(self, dt: datetime) -> datetime:
98103
start = dt - timedelta(days=dt.weekday())
@@ -102,7 +107,7 @@ def normalize(self, dt: datetime) -> datetime:
102107
class MonthlyMapper(_BaseTemporalMapper):
103108
"""Map a time-based partition key to month."""
104109

105-
default_output_format = "%Y-%m"
110+
default_output_format = "%Y-%m%z"
106111

107112
def normalize(self, dt: datetime) -> datetime:
108113
return dt.replace(
@@ -117,7 +122,7 @@ def normalize(self, dt: datetime) -> datetime:
117122
class QuarterlyMapper(_BaseTemporalMapper):
118123
"""Map a time-based partition key to quarter."""
119124

120-
default_output_format = "%Y-Q{quarter}"
125+
default_output_format = "%Y-Q{quarter}%z"
121126

122127
def normalize(self, dt: datetime) -> datetime:
123128
quarter = (dt.month - 1) // 3
@@ -139,7 +144,7 @@ def format(self, dt: datetime) -> str:
139144
class YearlyMapper(_BaseTemporalMapper):
140145
"""Map a time-based partition key to year."""
141146

142-
default_output_format = "%Y"
147+
default_output_format = "%Y%z"
143148

144149
def normalize(self, dt: datetime) -> datetime:
145150
return dt.replace(

0 commit comments

Comments
 (0)