Skip to content

Commit abb567e

Browse files
authored
Update the wallet&statistic apis (#86)
1 parent a279609 commit abb567e

3 files changed

Lines changed: 46 additions & 15 deletions

File tree

app/api/routes/statistic.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from fastapi import APIRouter, Depends, Query
22
from sqlalchemy.ext.asyncio import AsyncSession
3-
from sqlalchemy import select, desc, func
3+
from sqlalchemy import select, desc, func, case
44
from sqlalchemy.sql.functions import coalesce
55
from sqlalchemy import or_
66
from datetime import datetime, timedelta, UTC
@@ -27,12 +27,12 @@
2727

2828

2929
# I want a query parameter called "offset: <int>" and "limit: <int>"
30-
@router.get("/usage/realtime", response_model=list[UsageRealtimeResponse])
30+
@router.get("/usage/realtime", response_model=UsageRealtimeResponse)
3131
async def get_usage_realtime(
3232
current_user: User = Depends(get_user_by_api_key),
3333
db: AsyncSession = Depends(get_async_db),
34-
offset: int = Query(0, ge=0),
35-
limit: int = Query(10, ge=1),
34+
page_index: int = Query(0, ge=0),
35+
page_size: int = Query(10, ge=1),
3636
forge_key: str = Query(None, min_length=1),
3737
provider_name: str = Query(None, min_length=1),
3838
model_name: str = Query(None, min_length=1),
@@ -66,9 +66,11 @@ async def get_usage_realtime(
6666
UsageTracker.output_tokens.label("output_tokens"),
6767
UsageTracker.cached_tokens.label("cached_tokens"),
6868
UsageTracker.cost.label("cost"),
69+
UsageTracker.billable.label("billable"),
6970
func.extract(
7071
"epoch", UsageTracker.updated_at - UsageTracker.created_at
7172
).label("duration"),
73+
func.count().over().label("total"),
7274
)
7375
.join(ProviderKey, UsageTracker.provider_key_id == ProviderKey.id)
7476
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
@@ -87,18 +89,19 @@ async def get_usage_realtime(
8789
UsageTracker.updated_at.is_not(None),
8890
)
8991
.order_by(desc(UsageTracker.created_at))
90-
.offset(offset)
91-
.limit(limit)
92+
.offset(page_index * page_size)
93+
.limit(page_size)
9294
)
9395

9496
# Execute the query
9597
result = await db.execute(query)
9698
rows = result.fetchall()
9799

98100
# Convert to list of dictionaries
99-
usage_stats = []
101+
items = []
102+
total = 0
100103
for row in rows:
101-
usage_stats.append(
104+
items.append(
102105
{
103106
"timestamp": row.timestamp,
104107
"forge_key": row.forge_key,
@@ -109,20 +112,27 @@ async def get_usage_realtime(
109112
"output_tokens": row.output_tokens,
110113
"cached_tokens": row.cached_tokens,
111114
"cost": decimal.Decimal(row.cost).normalize(),
115+
"billable": row.billable,
112116
"duration": round(float(row.duration), 2)
113117
if row.duration is not None
114118
else 0.0,
115119
}
116120
)
117-
return [UsageRealtimeResponse(**usage_stat) for usage_stat in usage_stats]
121+
total = row.total
122+
return UsageRealtimeResponse(
123+
items=items,
124+
total=total,
125+
page_size=page_size,
126+
page_index=page_index,
127+
)
118128

119129

120-
@router.get("/usage/realtime/clerk", response_model=list[UsageRealtimeResponse])
130+
@router.get("/usage/realtime/clerk", response_model=UsageRealtimeResponse)
121131
async def get_usage_realtime_clerk(
122132
current_user: User = Depends(get_current_active_user_from_clerk),
123133
db: AsyncSession = Depends(get_async_db),
124-
offset: int = Query(0, ge=0),
125-
limit: int = Query(10, ge=1),
134+
page_index: int = Query(0, ge=0),
135+
page_size: int = Query(10, ge=1),
126136
forge_key: str = Query(None, min_length=1),
127137
provider_name: str = Query(None, min_length=1),
128138
model_name: str = Query(None, min_length=1),
@@ -132,8 +142,8 @@ async def get_usage_realtime_clerk(
132142
return await get_usage_realtime(
133143
current_user,
134144
db,
135-
offset,
136-
limit,
145+
page_index,
146+
page_size,
137147
forge_key,
138148
provider_name,
139149
model_name,
@@ -186,6 +196,7 @@ async def get_usage_summary(
186196
func.sum(UsageTracker.output_tokens).label("output_tokens"),
187197
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
188198
func.sum(UsageTracker.cost).label("cost"),
199+
func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"),
189200
)
190201
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
191202
.where(
@@ -208,6 +219,7 @@ async def get_usage_summary(
208219
"breakdown": [],
209220
"total_tokens": 0,
210221
"total_cost": 0,
222+
"total_charged_cost": 0,
211223
"total_input_tokens": 0,
212224
"total_output_tokens": 0,
213225
"total_cached_tokens": 0,
@@ -217,6 +229,7 @@ async def get_usage_summary(
217229
"forge_key": row.forge_key,
218230
"tokens": row.tokens,
219231
"cost": decimal.Decimal(row.cost).normalize(),
232+
"charged_cost": decimal.Decimal(row.charged_cost).normalize(),
220233
"input_tokens": row.input_tokens,
221234
"output_tokens": row.output_tokens,
222235
"cached_tokens": row.cached_tokens,
@@ -226,6 +239,9 @@ async def get_usage_summary(
226239
data_points[row.time_point]["total_cost"] += decimal.Decimal(
227240
row.cost
228241
).normalize()
242+
data_points[row.time_point]["total_charged_cost"] += decimal.Decimal(
243+
row.charged_cost
244+
).normalize()
229245
data_points[row.time_point]["total_input_tokens"] += row.input_tokens
230246
data_points[row.time_point]["total_output_tokens"] += row.output_tokens
231247
data_points[row.time_point]["total_cached_tokens"] += row.cached_tokens
@@ -236,6 +252,7 @@ async def get_usage_summary(
236252
breakdown=data_point["breakdown"],
237253
total_tokens=data_point["total_tokens"],
238254
total_cost=data_point["total_cost"],
255+
total_charged_cost=data_point["total_charged_cost"],
239256
total_input_tokens=data_point["total_input_tokens"],
240257
total_output_tokens=data_point["total_output_tokens"],
241258
total_cached_tokens=data_point["total_cached_tokens"],
@@ -292,6 +309,7 @@ async def get_forge_keys_usage(
292309
func.sum(UsageTracker.output_tokens).label("output_tokens"),
293310
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
294311
func.sum(UsageTracker.cost).label("cost"),
312+
func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"),
295313
)
296314
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
297315
.where(
@@ -311,6 +329,7 @@ async def get_forge_keys_usage(
311329
forge_key=row.forge_key,
312330
tokens=row.tokens,
313331
cost=decimal.Decimal(row.cost).normalize(),
332+
charged_cost=decimal.Decimal(row.charged_cost).normalize(),
314333
input_tokens=row.input_tokens,
315334
output_tokens=row.output_tokens,
316335
cached_tokens=row.cached_tokens,

app/api/routes/wallet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ async def get_wallet_balance_clerk(
5151
return await get_wallet_balance(user, db)
5252

5353
class TransactionHistoryItem(BaseModel):
54+
transaction_id: str
5455
currency: str
5556
amount: Decimal
5657
status: str
@@ -75,6 +76,7 @@ async def get_wallet_transactions_history(
7576
# I would also want to get the total count of the transactions within one sql query
7677
query = (
7778
select(
79+
StripePayment.id,
7880
StripePayment.currency,
7981
StripePayment.amount,
8082
StripePayment.status,
@@ -92,6 +94,7 @@ async def get_wallet_transactions_history(
9294
return TransactionHistoryResponse(
9395
items=[
9496
TransactionHistoryItem(
97+
transaction_id=transaction.id,
9598
currency=transaction.currency,
9699
# Convert cents to dollars for USD
97100
amount=transaction.amount / 100.0 if transaction.currency == "USD" else transaction.amount,

app/api/schemas/statistic.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def mask_forge_name_or_key(v: str) -> str:
1212
# Otherwise, return the original value (user customized name)
1313
return v
1414

15-
class UsageRealtimeResponse(BaseModel):
15+
class UsageRealtimeItem(BaseModel):
1616
timestamp: datetime | str
1717
forge_key: str
1818
provider_name: str
@@ -23,6 +23,7 @@ class UsageRealtimeResponse(BaseModel):
2323
cached_tokens: int
2424
duration: float
2525
cost: decimal.Decimal
26+
billable: bool
2627

2728
@field_validator('forge_key')
2829
@classmethod
@@ -36,11 +37,17 @@ def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
3637
return v
3738
return v.isoformat()
3839

40+
class UsageRealtimeResponse(BaseModel):
41+
total: int
42+
items: list[UsageRealtimeItem]
43+
page_size: int
44+
page_index: int
3945

4046
class UsageSummaryBreakdown(BaseModel):
4147
forge_key: str
4248
tokens: int
4349
cost: decimal.Decimal
50+
charged_cost: decimal.Decimal
4451
input_tokens: int
4552
output_tokens: int
4653
cached_tokens: int
@@ -56,6 +63,7 @@ class UsageSummaryResponse(BaseModel):
5663
breakdown: list[UsageSummaryBreakdown]
5764
total_tokens: int
5865
total_cost: decimal.Decimal
66+
total_charged_cost: decimal.Decimal
5967
total_input_tokens: int
6068
total_output_tokens: int
6169
total_cached_tokens: int
@@ -72,6 +80,7 @@ class ForgeKeysUsageSummaryResponse(BaseModel):
7280
forge_key: str
7381
tokens: int
7482
cost: decimal.Decimal
83+
charged_cost: decimal.Decimal
7584
input_tokens: int
7685
output_tokens: int
7786
cached_tokens: int

0 commit comments

Comments
 (0)