Skip to content

Commit 7a64276

Browse files
authored
Autoset UTC timezone for datetimes loaded from the db (#2922)
* Load datetimes with utc timezone from db * Use tz aware datetimes for sqlalchemy filters
1 parent 67d9b27 commit 7a64276

21 files changed

+71
-87
lines changed

src/dstack/_internal/server/background/tasks/process_fleets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def _process_next_fleet():
4040
FleetModel.deleted == False,
4141
FleetModel.id.not_in(lockset),
4242
FleetModel.last_processed_at
43-
< get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
43+
< get_current_datetime() - MIN_PROCESSING_INTERVAL,
4444
)
4545
.order_by(FleetModel.last_processed_at.asc())
4646
.limit(1)

src/dstack/_internal/server/background/tasks/process_idle_volumes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def _should_delete_volume(volume: VolumeModel) -> bool:
8282

8383
def _get_idle_time(volume: VolumeModel) -> datetime.timedelta:
8484
last_used = volume.last_job_processed_at or volume.created_at
85-
last_used_utc = last_used.replace(tzinfo=datetime.timezone.utc)
86-
idle_time = get_current_datetime() - last_used_utc
85+
idle_time = get_current_datetime() - last_used
8786
return max(idle_time, datetime.timedelta(0))
8887

8988

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@
104104
from dstack._internal.server.services.runner import client as runner_client
105105
from dstack._internal.server.services.runner.client import HealthStatus
106106
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
107-
from dstack._internal.utils.common import get_current_datetime, run_async
107+
from dstack._internal.utils.common import (
108+
get_current_datetime,
109+
run_async,
110+
)
108111
from dstack._internal.utils.logging import get_logger
109112
from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
110113
from dstack._internal.utils.ssh import (
@@ -149,7 +152,7 @@ async def _process_next_instance():
149152
),
150153
InstanceModel.id.not_in(lockset),
151154
InstanceModel.last_processed_at
152-
< get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
155+
< get_current_datetime() - MIN_PROCESSING_INTERVAL,
153156
)
154157
.options(lazyload(InstanceModel.jobs))
155158
.order_by(InstanceModel.last_processed_at.asc())
@@ -461,7 +464,7 @@ def _deploy_instance(
461464

462465
async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
463466
if instance.last_retry_at is not None:
464-
last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc)
467+
last_retry = instance.last_retry_at
465468
if get_current_datetime() < last_retry + timedelta(minutes=1):
466469
return
467470

@@ -801,7 +804,7 @@ async def _check_instance(instance: InstanceModel) -> None:
801804
instance.name,
802805
extra={"instance_name": instance.name},
803806
)
804-
deadline = instance.termination_deadline.replace(tzinfo=datetime.timezone.utc)
807+
deadline = instance.termination_deadline
805808
if get_current_datetime() > deadline:
806809
instance.status = InstanceStatus.TERMINATING
807810
instance.termination_reason = "Termination deadline"
@@ -956,18 +959,12 @@ async def _terminate(instance: InstanceModel) -> None:
956959

957960
def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
958961
assert instance.last_termination_retry_at is not None
959-
return (
960-
instance.last_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
961-
+ TERMINATION_RETRY_TIMEOUT
962-
)
962+
return instance.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT
963963

964964

965965
def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
966966
assert instance.first_termination_retry_at is not None
967-
return (
968-
instance.first_termination_retry_at.replace(tzinfo=datetime.timezone.utc)
969-
+ TERMINATION_RETRY_MAX_DURATION
970-
)
967+
return instance.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION
971968

972969

973970
def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
@@ -1102,27 +1099,26 @@ async def _create_placement_group(
11021099

11031100

11041101
def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
1105-
last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)
1102+
last_time = instance.created_at
11061103
if instance.last_job_processed_at is not None:
1107-
last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc)
1104+
last_time = instance.last_job_processed_at
11081105
return get_current_datetime() - last_time
11091106

11101107

11111108
def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime:
1112-
return instance.created_at.replace(tzinfo=datetime.timezone.utc) + timedelta(
1113-
seconds=retry.duration
1114-
)
1109+
return instance.created_at + timedelta(seconds=retry.duration)
11151110

11161111

11171112
def _get_provisioning_deadline(
11181113
instance: InstanceModel,
11191114
job_provisioning_data: JobProvisioningData,
11201115
) -> datetime.datetime:
1116+
assert instance.started_at is not None
11211117
timeout_interval = get_provisioning_timeout(
11221118
backend_type=job_provisioning_data.get_base_backend(),
11231119
instance_type_name=job_provisioning_data.instance_type.name,
11241120
)
1125-
return instance.started_at.replace(tzinfo=datetime.timezone.utc) + timeout_interval
1121+
return instance.started_at + timeout_interval
11261122

11271123

11281124
def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
import uuid
44
from collections.abc import Iterable
5-
from datetime import timedelta, timezone
5+
from datetime import timedelta
66
from typing import Dict, List, Optional
77

88
from sqlalchemy import select
@@ -108,8 +108,7 @@ async def _process_next_running_job():
108108
RunModel.status.not_in([RunStatus.TERMINATING]),
109109
JobModel.id.not_in(lockset),
110110
JobModel.last_processed_at
111-
< common_utils.get_current_datetime().replace(tzinfo=None)
112-
- MIN_PROCESSING_INTERVAL,
111+
< common_utils.get_current_datetime() - MIN_PROCESSING_INTERVAL,
113112
)
114113
.order_by(JobModel.last_processed_at.asc())
115114
.limit(1)
@@ -801,7 +800,7 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
801800
return False
802801
return (
803802
common_utils.get_current_datetime()
804-
> job_model.disconnected_at.replace(tzinfo=timezone.utc) + JOB_DISCONNECTED_RETRY_TIMEOUT
803+
> job_model.disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT
805804
)
806805

807806

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def process_runs(batch_size: int = 1):
5656
async def _process_next_run():
5757
run_lock, run_lockset = get_locker(get_db().dialect_name).get_lockset(RunModel.__tablename__)
5858
job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
59-
now = common.get_current_datetime().replace(tzinfo=None)
59+
now = common.get_current_datetime()
6060
async with get_session_ctx() as session:
6161
async with run_lock, job_lock:
6262
res = await session.execute(
@@ -370,9 +370,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
370370
)
371371
if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING:
372372
current_time = common.get_current_datetime()
373-
submit_to_provision_duration = (
374-
current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc)
375-
).total_seconds()
373+
submit_to_provision_duration = (current_time - run_model.submitted_at).total_seconds()
376374
logger.info(
377375
"%s: run took %.2f seconds from submission to provisioning.",
378376
fmt(run_model),

src/dstack/_internal/server/background/tasks/process_terminating_jobs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
)
1919
from dstack._internal.server.services.locking import get_locker
2020
from dstack._internal.server.services.logging import fmt
21-
from dstack._internal.utils.common import get_current_datetime, get_or_error
21+
from dstack._internal.utils.common import (
22+
get_current_datetime,
23+
get_or_error,
24+
)
2225
from dstack._internal.utils.logging import get_logger
2326

2427
logger = get_logger(__name__)
@@ -43,7 +46,10 @@ async def _process_next_terminating_job():
4346
.where(
4447
JobModel.id.not_in(job_lockset),
4548
JobModel.status == JobStatus.TERMINATING,
46-
or_(JobModel.remove_at.is_(None), JobModel.remove_at < get_current_datetime()),
49+
or_(
50+
JobModel.remove_at.is_(None),
51+
JobModel.remove_at < get_current_datetime(),
52+
),
4753
)
4854
.order_by(JobModel.last_processed_at.asc())
4955
.limit(1)

src/dstack/_internal/server/models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
22
import uuid
3-
from datetime import datetime
3+
from datetime import datetime, timezone
44
from typing import Callable, List, Optional, Union
55

66
from sqlalchemy import (
@@ -51,9 +51,10 @@
5151

5252
class NaiveDateTime(TypeDecorator):
5353
"""
54-
A custom type decorator that ensures datetime objects are offset-naive when stored in the database.
55-
This is needed because we use datetimes in UTC only and store them as offset-naive.
56-
Some databases (e.g. Postgres) throw an error if the timezone is set.
54+
A custom type decorator that ensures datetime objects are offset-naive when stored in the database
55+
and offset-aware with UTC timezone when loaded from the database.
56+
This is because we use datetimes in UTC everywhere, and
57+
some databases (e.g. Postgres) throw an error if the timezone is set.
5758
"""
5859

5960
impl = DateTime
@@ -65,7 +66,9 @@ def process_bind_param(self, value, dialect):
6566
return value
6667

6768
def process_result_value(self, value, dialect):
68-
return value
69+
if value is None:
70+
return None
71+
return value.replace(tzinfo=timezone.utc)
6972

7073

7174
class DecryptedString(CoreModel):

src/dstack/_internal/server/services/fleets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from collections.abc import Callable
3-
from datetime import datetime, timezone
3+
from datetime import datetime
44
from functools import wraps
55
from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast
66

@@ -600,7 +600,7 @@ def fleet_model_to_fleet(
600600
name=fleet_model.name,
601601
project_name=fleet_model.project.name,
602602
spec=spec,
603-
created_at=fleet_model.created_at.replace(tzinfo=timezone.utc),
603+
created_at=fleet_model.created_at,
604604
status=fleet_model.status,
605605
status_message=fleet_model.status_message,
606606
instances=instances,

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import datetime
33
import uuid
4-
from datetime import timedelta, timezone
4+
from datetime import timedelta
55
from functools import partial
66
from typing import List, Optional, Sequence
77

@@ -557,7 +557,7 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway:
557557
region=gateway_model.region,
558558
wildcard_domain=gateway_model.wildcard_domain,
559559
default=gateway_model.project.default_gateway_id == gateway_model.id,
560-
created_at=gateway_model.created_at.replace(tzinfo=timezone.utc),
560+
created_at=gateway_model.created_at,
561561
status=gateway_model.status,
562562
status_message=gateway_model.status_message,
563563
configuration=configuration,

src/dstack/_internal/server/services/instances.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from collections.abc import Container, Iterable
3-
from datetime import datetime, timezone
3+
from datetime import datetime
44
from typing import Dict, List, Literal, Optional, Union
55

66
import gpuhunt
@@ -62,7 +62,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
6262
status=instance_model.status,
6363
unreachable=instance_model.unreachable,
6464
termination_reason=instance_model.termination_reason,
65-
created=instance_model.created_at.replace(tzinfo=timezone.utc),
65+
created=instance_model.created_at,
6666
total_blocks=instance_model.total_blocks,
6767
busy_blocks=instance_model.busy_blocks,
6868
)

0 commit comments

Comments
 (0)