-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlambda_pricing.py
More file actions
147 lines (125 loc) · 4.85 KB
/
lambda_pricing.py
File metadata and controls
147 lines (125 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# lambda_pricing.py — Resolve GPU $/hour from Lambda Cloud public API (optional; gateway uses if LAMBDA_CLOUD_API_KEY is set).
# API reference: https://docs.lambda.ai/api/cloud#listInstanceTypes
from __future__ import annotations

import logging
import math
import os
from typing import Any

import httpx
_log = logging.getLogger("gateway.lambda_pricing")
LAMBDA_API_BASE = os.environ.get("LAMBDA_CLOUD_API_BASE", "https://cloud.lambdalabs.com/api/v1").rstrip("/")
def _auth_headers(api_key: str) -> dict[str, str]:
return {"Authorization": f"Bearer {api_key.strip()}"}
def _cents_to_usd_hour(cents: Any) -> float | None:
try:
v = float(cents)
except (TypeError, ValueError):
return None
if v < 0:
return None
return v / 100.0
def _merge_price_entry(prices: dict[str, float], name: str | None, cents: Any) -> None:
    """Record one instance-type price in ``prices`` (mutated in place).

    The entry is ignored when ``name`` is falsy, when ``str(name)`` is blank
    after stripping, or when ``_cents_to_usd_hour`` rejects ``cents``.
    The first price stored for a name wins. A later value that differs by more
    than 1e-6 USD is only reported via a debug log, and the stored price stays.
    """
    if not name:
        return
    dollars = _cents_to_usd_hour(cents)
    if dollars is None:
        return
    label = str(name).strip()
    if not label:
        return
    existing = prices.get(label)
    conflicting = existing is not None and abs(existing - dollars) > 1e-6
    if conflicting:
        _log.debug("Lambda price for %s: keeping %.4f (also saw %.4f)", label, existing, dollars)
    else:
        prices[label] = dollars
def _extract_prices_from_obj(obj: Any, prices: dict[str, float]) -> None:
    """Recursively collect name -> USD/hour pairs from decoded JSON into ``prices``.

    The response shape of GET /instance-types varies by API version, so two
    patterns are recognised at every dict encountered (depth-first, in
    iteration order):

    - a dict with a nested ``instance_type`` dict: the name comes from the
      nested dict. The price is the first truthy value among the outer
      ``price_cents_per_hour``, the nested ``price_cents_per_hour`` and the
      outer ``price_cents``.
    - any dict carrying both ``name`` and ``price_cents_per_hour``.

    Lists are walked element by element; scalars are ignored.
    """
    if isinstance(obj, list):
        children = obj
    elif isinstance(obj, dict):
        nested = obj.get("instance_type")
        if isinstance(nested, dict):
            # NOTE(review): `or` treats a 0-cent price as missing and falls
            # through to the next candidate field.
            price_cents = (
                obj.get("price_cents_per_hour")
                or nested.get("price_cents_per_hour")
                or obj.get("price_cents")
            )
            _merge_price_entry(prices, nested.get("name"), price_cents)
        if "name" in obj and "price_cents_per_hour" in obj:
            _merge_price_entry(prices, obj.get("name"), obj.get("price_cents_per_hour"))
        children = obj.values()
    else:
        return
    for child in children:
        _extract_prices_from_obj(child, prices)
def fetch_instance_type_prices_usd_per_hour(api_key: str, client: httpx.Client | None = None) -> dict[str, float]:
    """Return map instance type name -> USD/hour for the whole instance (Lambda's posted rate).

    Calls ``GET {LAMBDA_API_BASE}/instance-types`` and walks the decoded JSON
    with ``_extract_prices_from_obj``.

    Args:
        api_key: Lambda Cloud API key, sent as a Bearer token.
        client: Optional shared ``httpx.Client``. When omitted, a temporary
            client with a 30 s timeout is created and closed before returning.

    Returns:
        A possibly empty dict; empty means no prices could be parsed.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response (from ``raise_for_status``).
        Transport errors and JSON decoding errors from httpx propagate unchanged.
    """
    own = client is None
    if own:
        client = httpx.Client(timeout=30.0)
    try:
        url = f"{LAMBDA_API_BASE}/instance-types"
        r = client.get(url, headers=_auth_headers(api_key))
        r.raise_for_status()
        body = r.json()
    finally:
        if own:
            client.close()
    prices: dict[str, float] = {}
    # Results are normally wrapped in {"data": ...}. Also tolerate a bare
    # payload or a top-level JSON array instead of failing on ``.get``.
    data = body.get("data", body) if isinstance(body, dict) else body
    _extract_prices_from_obj(data, prices)
    # Re-walk the full body only when it differs from what was just walked;
    # otherwise the second pass would repeat the first.
    if not prices and data is not body:
        _extract_prices_from_obj(body, prices)
    return prices
def _running_instance_type_name(api_key: str, client: httpx.Client | None = None) -> str | None:
    """Return the instance type name of the first active-like instance, or None.

    Calls ``GET {LAMBDA_API_BASE}/instances`` and scans the ``data`` list in
    order:
    - Rows whose lowercased ``status`` is present and not one of
      active/running/booting/pending are skipped; rows without a status are
      accepted.
    - The name comes from ``instance_type.name`` or, failing that,
      ``instance_type_name``. It is stripped so it matches the (stripped) keys
      of the price map; rows whose name is blank are skipped.

    Args:
        api_key: Lambda Cloud API key, sent as a Bearer token.
        client: Optional shared ``httpx.Client``. When omitted, a temporary
            client with a 30 s timeout is created and closed before returning.

    Returns:
        The instance type name, or None when the payload has no ``data`` list
        or no row qualifies.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response (from ``raise_for_status``).
    """
    own = client is None
    if own:
        client = httpx.Client(timeout=30.0)
    try:
        url = f"{LAMBDA_API_BASE}/instances"
        r = client.get(url, headers=_auth_headers(api_key))
        r.raise_for_status()
        body = r.json()
    finally:
        if own:
            client.close()
    # Guard against a non-object payload (e.g. a bare JSON array) instead of
    # raising AttributeError on ``.get``.
    rows = body.get("data") if isinstance(body, dict) else None
    if not isinstance(rows, list):
        return None
    active_like = frozenset({"active", "running", "booting", "pending"})
    for inst in rows:
        if not isinstance(inst, dict):
            continue
        status = str(inst.get("status", "")).lower()
        if status and status not in active_like:
            continue
        it = inst.get("instance_type")
        raw = it.get("name") if isinstance(it, dict) else None
        raw = raw or inst.get("instance_type_name")
        name = str(raw).strip() if raw else ""
        if name:
            return name
    return None
def resolve_gpu_hourly_usd_from_lambda(
    api_key: str,
    *,
    instance_type: str | None = None,
) -> float | None:
    """
    Return Lambda's current list price (USD/hour) for one instance type, or None.

    The type name is ``instance_type`` if truthy, otherwise env
    ``LAMBDA_INSTANCE_TYPE``, stripped of surrounding whitespace. If that ends
    up blank, the type of the first active-like row from GET /instances is
    used instead.

    Returns None, after logging a warning, when:
    - no price map could be parsed,
    - no type name could be determined, or
    - the name has no listed price.

    HTTP/JSON errors from the underlying requests are not caught here.
    """
    requested = instance_type or os.environ.get("LAMBDA_INSTANCE_TYPE", "")
    explicit = requested.strip() or None
    with httpx.Client(timeout=30.0) as client:
        prices = fetch_instance_type_prices_usd_per_hour(api_key, client)
        if not prices:
            _log.warning("Lambda instance-types: no price map parsed (API shape may have changed).")
            return None
        # Only query /instances when neither the caller nor the env named a type.
        if explicit is not None:
            name = explicit
        else:
            name = _running_instance_type_name(api_key, client)
    if not name:
        _log.warning(
            "Set LAMBDA_INSTANCE_TYPE (e.g. gpu_1x_a10) or run at least one active instance to infer type."
        )
        return None
    try:
        return prices[name]
    except KeyError:
        _log.warning("Unknown Lambda instance type %r; known: %s", name, sorted(prices)[:12])
        return None