diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py
index 9c86bace56b42..d1a4c7c904868 100644
--- a/airflow-core/src/airflow/partition_mappers/temporal.py
+++ b/airflow-core/src/airflow/partition_mappers/temporal.py
@@ -18,10 +18,14 @@
 
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
+from airflow._shared.timezones.timezone import make_aware, parse_timezone
 from airflow.partition_mappers.base import PartitionMapper
 
+if TYPE_CHECKING:
+    from pendulum import FixedTimezone, Timezone
+
 
 class _BaseTemporalMapper(PartitionMapper, ABC):
     """Base class for Temporal Partition Mappers."""
@@ -30,14 +34,39 @@ class _BaseTemporalMapper(PartitionMapper, ABC):
 
     def __init__(
         self,
+        *,
+        timezone: str | Timezone | FixedTimezone = "UTC",
         input_format: str = "%Y-%m-%dT%H:%M:%S",
         output_format: str | None = None,
     ):
+        """
+        Configure how partition keys are parsed, localized and formatted.
+
+        :param timezone: Timezone used to localize naive keys and to compute partition
+            boundaries; a name such as ``"America/New_York"`` or a pendulum timezone object.
+        :param input_format: ``strptime`` format of the incoming partition key.
+        :param output_format: ``strftime`` format of the downstream key; falls back to
+            the class's ``default_output_format`` when not given.
+        """
         self.input_format = input_format
         self.output_format = output_format or self.default_output_format
+        if isinstance(timezone, str):
+            timezone = parse_timezone(timezone)
+        self._timezone = timezone
 
     def to_downstream(self, key: str) -> str:
+        """
+        Map an upstream partition key to its normalized downstream key.
+
+        Naive keys are read as local time in the mapper's timezone; keys that carry
+        an offset are converted into it, so boundaries always use the configured zone.
+        """
         dt = datetime.strptime(key, self.input_format)
+        if dt.tzinfo is None:
+            dt = make_aware(dt, self._timezone)
+        else:
+            # Aware keys are shifted into the mapper's zone so boundaries don't follow the key's offset.
+            dt = dt.astimezone(self._timezone)
         normalized = self.normalize(dt)
         return self.format(normalized)
 
@@ -50,7 +79,11 @@ def format(self, dt: datetime) -> str:
         return dt.strftime(self.output_format)
 
     def serialize(self) -> dict[str, Any]:
+        """Return the mapper's configuration as a plain dict, including its timezone."""
+        from airflow.serialization.encoders import encode_timezone
+
         return {
+            "timezone": encode_timezone(self._timezone),
             "input_format": self.input_format,
             "output_format": self.output_format,
         }
@@ -58,6 +91,8 @@ def serialize(self) -> dict[str, Any]:
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+        """Rebuild a mapper from ``serialize()`` output; a missing timezone means UTC."""
         return cls(
+            timezone=parse_timezone(data.get("timezone", "UTC")),
             input_format=data["input_format"],
             output_format=data["output_format"],
         )
diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py
index 89ad98cebbccc..c6b5b1d760d71 100644
--- a/airflow-core/tests/unit/partition_mappers/test_temporal.py
+++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py
@@ -46,7 +46,7 @@ def test_to_downstream(
         mapper_cls: type[_BaseTemporalMapper],
         expected_downstream_key: str,
     ):
-        pm = mapper_cls()
+        pm = mapper_cls(timezone="UTC")
         assert pm.to_downstream("2026-02-10T14:30:45") == expected_downstream_key
 
     @pytest.mark.parametrize(
@@ -61,8 +61,9 @@ def test_to_downstream(
         ],
     )
     def test_serialize(self, mapper_cls: type[_BaseTemporalMapper], expected_outut_format: str):
-        pm = mapper_cls()
+        pm = mapper_cls(timezone="UTC")
         assert pm.serialize() == {
+            "timezone": "UTC",
             "input_format": "%Y-%m-%dT%H:%M:%S",
             "output_format": expected_outut_format,
         }
@@ -81,6 +82,7 @@ def test_serialize(self, mapper_cls: type[_BaseTemporalMapper], expected_outut_f
     def test_deserialize(self, mapper_cls):
         pm = mapper_cls.deserialize(
             {
+                "timezone": "UTC",
                 "input_format": "%Y-%m-%dT%H:%M:%S",
                 "output_format": "customized-format",
             }
@@ -88,3 +90,39 @@ def test_deserialize(self, mapper_cls):
         assert isinstance(pm, mapper_cls)
         assert pm.input_format == "%Y-%m-%dT%H:%M:%S"
         assert pm.output_format == "customized-format"
+
+    @pytest.mark.parametrize(
+        "mapper_cls",
+        [
+            StartOfHourMapper,
+            StartOfDayMapper,
+            StartOfWeekMapper,
+            StartOfMonthMapper,
+            StartOfQuarterMapper,
+            StartOfYearMapper,
+        ],
+    )
+    def test_deserialize_legacy_no_timezone(self, mapper_cls):
+        """Deserializing data without a timezone key defaults to UTC."""
+        pm = mapper_cls.deserialize(
+            {
+                "input_format": "%Y-%m-%dT%H:%M:%S",
+                "output_format": "customized-format",
+            }
+        )
+        assert isinstance(pm, mapper_cls)
+        assert pm.serialize()["timezone"] == "UTC"
+
+    def test_to_downstream_timezone_aware(self):
+        """Input is interpreted as local time in the given timezone."""
+        pm = StartOfDayMapper(timezone="America/New_York")
+        # 2026-02-10T23:00:00 in New York local time → start-of-day is 2026-02-10
+        assert pm.to_downstream("2026-02-10T23:00:00") == "2026-02-10"
+        # 2026-02-11T01:00:00 in New York local time → start-of-day is 2026-02-11
+        assert pm.to_downstream("2026-02-11T01:00:00") == "2026-02-11"
+
+    def test_to_downstream_converts_aware_key(self):
+        """Keys that carry a UTC offset are converted into the mapper's timezone."""
+        pm = StartOfDayMapper(timezone="America/New_York", input_format="%Y-%m-%dT%H:%M:%S%z")
+        # 03:00 UTC on Feb 11 is 22:00 on Feb 10 in New York (UTC-5 in winter).
+        assert pm.to_downstream("2026-02-11T03:00:00+0000") == "2026-02-10"