From d91d90602a446ab8f6aa9fe5794218656cae756c Mon Sep 17 00:00:00 2001 From: krakenhavoc Date: Thu, 12 Mar 2026 19:41:06 +0000 Subject: [PATCH 1/3] feat: add support for prorated billing --- .../src/billing/billing.controller.spec.ts | 42 ++++++ backend/src/billing/billing.controller.ts | 30 ++++ backend/src/billing/billing.service.spec.ts | 137 ++++++++++++++++- backend/src/billing/billing.service.ts | 117 ++++++++++++++- backend/src/billing/dto/upgrade-plan.dto.ts | 13 ++ frontend/src/pages/Billing.tsx | 140 ++++++++++++++++-- frontend/src/services/billingService.ts | 21 +++ shared/src/constants/routes.ts | 2 + shared/src/types/subscription.ts | 14 ++ 9 files changed, 502 insertions(+), 14 deletions(-) create mode 100644 backend/src/billing/dto/upgrade-plan.dto.ts diff --git a/backend/src/billing/billing.controller.spec.ts b/backend/src/billing/billing.controller.spec.ts index 5bc534e..3812386 100644 --- a/backend/src/billing/billing.controller.spec.ts +++ b/backend/src/billing/billing.controller.spec.ts @@ -17,6 +17,8 @@ describe('BillingController', () => { createCheckoutSession: jest.fn(), getSubscription: jest.fn(), createPortalSession: jest.fn(), + previewUpgrade: jest.fn(), + upgradeSubscription: jest.fn(), constructWebhookEvent: jest.fn(), handleWebhookEvent: jest.fn(), }; @@ -98,6 +100,46 @@ describe('BillingController', () => { }); }); + describe('previewUpgrade', () => { + it('returns preview from service', async () => { + const preview = { + immediateAmountCents: 3350, + currency: 'usd', + targetPlan: 'team', + currentPeriodEnd: '2026-04-09T00:00:00.000Z', + }; + mockBillingService.previewUpgrade.mockResolvedValue(preview); + + const result = await controller.previewUpgrade(mockReq, { + plan: 'team', + }); + expect(result).toEqual(preview); + expect(mockBillingService.previewUpgrade).toHaveBeenCalledWith( + userId, + 'team', + ); + }); + }); + + describe('upgrade', () => { + it('returns upgraded subscription from service', async () => { + const upgraded = { + plan: 'team', + status: 'active', + currentPeriodEnd: '2026-04-09T00:00:00.000Z', + cancelAtPeriodEnd: false, + }; + mockBillingService.upgradeSubscription.mockResolvedValue(upgraded); + + const result = await controller.upgrade(mockReq, { plan: 'team' }); + expect(result).toEqual(upgraded); + expect(mockBillingService.upgradeSubscription).toHaveBeenCalledWith( + userId, + 'team', + ); + }); + }); + describe('webhook', () => { it('throws on missing signature', async () => { const req = { rawBody: Buffer.from('{}') } as any; diff --git a/backend/src/billing/billing.controller.ts b/backend/src/billing/billing.controller.ts index f87172b..731631d 100644 --- a/backend/src/billing/billing.controller.ts +++ b/backend/src/billing/billing.controller.ts @@ -19,6 +19,7 @@ import { import { SkipThrottle } from '@nestjs/throttler'; import { BillingService } from './billing.service'; import { CreateCheckoutDto } from './dto/create-checkout.dto'; +import { UpgradePlanDto } from './dto/upgrade-plan.dto'; import { JwtOrApiKeyGuard } from '../auth/guards/jwt-or-api-key.guard'; import type { RequestWithUser } from '../auth/interfaces/request-with-user.interface'; import { RateLimitCategoryDecorator } from '../throttler/decorators/rate-limit-category.decorator'; @@ -91,6 +92,35 @@ export class BillingController { return { portalUrl }; } + @Post('upgrade/preview') + @UseGuards(JwtOrApiKeyGuard) + @ApiBearerAuth() + @ApiOperation({ summary: 'Preview prorated cost for a plan upgrade' }) + @ApiResponse({ status: 200, description: 'Upgrade preview details' }) + @ApiResponse({ status: 400, description: 'Invalid upgrade path' }) + @ApiResponse({ status: 401, description: 'Unauthorized' }) + @ApiResponse({ status: 404, description: 'No active subscription' }) + @RateLimitCategoryDecorator(RateLimitCategory.AUTHENTICATED_READ) + async previewUpgrade( + @Request() req: RequestWithUser, + @Body() dto: UpgradePlanDto, + ) { + return this.billingService.previewUpgrade(req.user.userId, dto.plan); + } + + @Post('upgrade') + @UseGuards(JwtOrApiKeyGuard) + @ApiBearerAuth() + @ApiOperation({ summary: 'Upgrade subscription with proration' }) + @ApiResponse({ status: 201, description: 'Subscription upgraded' }) + @ApiResponse({ status: 400, description: 'Invalid upgrade path' }) + @ApiResponse({ status: 401, description: 'Unauthorized' }) + @ApiResponse({ status: 404, description: 'No active subscription' }) + @RateLimitCategoryDecorator(RateLimitCategory.AUTHENTICATED_WRITE) + async upgrade(@Request() req: RequestWithUser, @Body() dto: UpgradePlanDto) { + return this.billingService.upgradeSubscription(req.user.userId, dto.plan); + } + @Post('webhook') @SkipThrottle() @ApiExcludeEndpoint() diff --git a/backend/src/billing/billing.service.spec.ts b/backend/src/billing/billing.service.spec.ts index a26059e..36d35eb 100644 --- a/backend/src/billing/billing.service.spec.ts +++ b/backend/src/billing/billing.service.spec.ts @@ -10,7 +10,8 @@ const mockStripe = { customers: { create: jest.fn() }, checkout: { sessions: { create: jest.fn() } }, billingPortal: { sessions: { create: jest.fn() } }, - subscriptions: { retrieve: jest.fn() }, + subscriptions: { retrieve: jest.fn(), update: jest.fn() }, + invoices: { createPreview: jest.fn() }, webhooks: { constructEvent: jest.fn() }, }; jest.mock('stripe', () => { @@ -59,6 +60,7 @@ describe('BillingService', () => { KK_STRIPE_SECRET_KEY: 'sk_test_123', KK_STRIPE_WEBHOOK_SECRET: 'whsec_test', KK_STRIPE_PRICE_STARTER: 'price_starter_123', + KK_STRIPE_PRICE_TEAM: 'price_team_456', KK_APP_DOMAIN: 'app.krakenkey.io', }; return config[key] ?? defaultValue ?? ''; @@ -289,5 +291,138 @@ describe('BillingService', () => { }), ); }); + + it('syncs plan from price ID on subscription.updated', async () => { + mockRepository.findOne.mockResolvedValue({ ...mockSubscription }); + + await service.handleWebhookEvent({ + type: 'customer.subscription.updated', + data: { + object: { + id: 'sub_test123', + status: 'active', + items: { + data: [ + { + current_period_end: + Math.floor(Date.now() / 1000) + 86400 * 30, + price: { id: 'price_team_456' }, + }, + ], + }, + cancel_at_period_end: false, + }, + }, + } as any); + + expect(mockRepository.save).toHaveBeenCalledWith( + expect.objectContaining({ + plan: 'team', + }), + ); + }); + }); + + describe('previewUpgrade', () => { + it('throws NotFoundException when no subscription exists', async () => { + mockRepository.findOne.mockResolvedValue(null); + await expect(service.previewUpgrade(userId, 'team')).rejects.toThrow( + NotFoundException, + ); + }); + + it('throws BadRequestException when subscription is not active', async () => { + mockRepository.findOne.mockResolvedValue({ + ...mockSubscription, + status: 'past_due', + }); + await expect(service.previewUpgrade(userId, 'team')).rejects.toThrow( + BadRequestException, + ); + }); + + it('throws BadRequestException when target plan is not higher', async () => { + mockRepository.findOne.mockResolvedValue({ + ...mockSubscription, + plan: 'team', + }); + await expect(service.previewUpgrade(userId, 'starter')).rejects.toThrow( + BadRequestException, + ); + }); + + it('throws BadRequestException for same plan', async () => { + mockRepository.findOne.mockResolvedValue(mockSubscription); + await expect(service.previewUpgrade(userId, 'starter')).rejects.toThrow( + BadRequestException, + ); + }); + + it('returns preview for valid upgrade', async () => { + mockRepository.findOne.mockResolvedValue(mockSubscription); + mockStripe.subscriptions.retrieve.mockResolvedValue({ + items: { data: [{ id: 'si_item1' }] }, + }); + mockStripe.invoices.createPreview.mockResolvedValue({ + amount_due: 3350, + currency: 'usd', + }); + + const result = await service.previewUpgrade(userId, 'team'); + expect(result).toEqual({ + immediateAmountCents: 3350, + currency: 'usd', + targetPlan: 'team', + currentPeriodEnd: mockSubscription.currentPeriodEnd!.toISOString(), + }); + expect(mockStripe.invoices.createPreview).toHaveBeenCalledWith({ + subscription: 'sub_test123', + subscription_details: { + items: [{ id: 'si_item1', price: 'price_team_456' }], + proration_behavior: 'create_prorations', + }, + }); + }); + }); + + describe('upgradeSubscription', () => { + it('throws NotFoundException when no subscription exists', async () => { + mockRepository.findOne.mockResolvedValue(null); + await expect(service.upgradeSubscription(userId, 'team')).rejects.toThrow( + NotFoundException, + ); + }); + + it('throws BadRequestException when target plan is not higher', async () => { + mockRepository.findOne.mockResolvedValue(mockSubscription); + await expect( + service.upgradeSubscription(userId, 'starter'), + ).rejects.toThrow(BadRequestException); + }); + + it('upgrades subscription with proration', async () => { + mockRepository.findOne.mockResolvedValue({ ...mockSubscription }); + mockStripe.subscriptions.retrieve.mockResolvedValue({ + items: { data: [{ id: 'si_item1' }] }, + }); + mockStripe.subscriptions.update.mockResolvedValue({}); + mockRepository.save.mockResolvedValue({ + ...mockSubscription, + plan: 'team', + }); + + const result = await service.upgradeSubscription(userId, 'team'); + expect(result.plan).toBe('team'); + expect(mockStripe.subscriptions.update).toHaveBeenCalledWith( + 'sub_test123', + { + items: [{ id: 'si_item1', price: 'price_team_456' }], + proration_behavior: 'create_prorations', + }, + ); + expect(mockRepository.save).toHaveBeenCalledWith( + expect.objectContaining({ plan: 'team' }), + ); + }); }); }); diff --git a/backend/src/billing/billing.service.ts b/backend/src/billing/billing.service.ts index 71a4d34..457f01c 100644 --- a/backend/src/billing/billing.service.ts +++ b/backend/src/billing/billing.service.ts @@ -10,11 +10,20 @@ import { ConfigService } from '@nestjs/config'; import Stripe from 'stripe'; import { Subscription } from './entities/subscription.entity'; +const PLAN_ORDER: Record = { + free: 0, + starter: 1, + team: 2, + business: 3, + enterprise: 4, +}; + @Injectable() export class BillingService { private readonly logger = new Logger(BillingService.name); private readonly stripe: Stripe; private readonly priceMap: Record; + private readonly reversePriceMap: Record; private readonly webhookSecret: string; constructor( @@ -40,6 +49,11 @@ export class BillingService { const teamPrice = this.configService.get('KK_STRIPE_PRICE_TEAM'); if (teamPrice) this.priceMap['team'] = teamPrice; + + this.reversePriceMap = {}; + for (const [plan, priceId] of Object.entries(this.priceMap)) { + this.reversePriceMap[priceId] = plan; + } } async getOrCreateCustomer( @@ -118,6 +132,102 @@ export class BillingService { return session.url; } + async previewUpgrade( + userId: string, + newPlan: string, + ): Promise<{ + immediateAmountCents: number; + currency: string; + targetPlan: string; + currentPeriodEnd: string; + }> { + const sub = await this.validateUpgrade(userId, newPlan); + const newPriceId = this.priceMap[newPlan]; + + const stripeSub = await this.stripe.subscriptions.retrieve( + sub.stripeSubscriptionId!, + ); + const itemId = stripeSub.items.data[0].id; + + const upcomingInvoice = await this.stripe.invoices.createPreview({ + subscription: sub.stripeSubscriptionId!, + subscription_details: { + items: [{ id: itemId, price: newPriceId }], + proration_behavior: 'create_prorations', + }, + }); + + return { + immediateAmountCents: upcomingInvoice.amount_due, + currency: upcomingInvoice.currency, + targetPlan: newPlan, + currentPeriodEnd: sub.currentPeriodEnd!.toISOString(), + }; + } + + async upgradeSubscription( + userId: string, + newPlan: string, + ): Promise<{ + plan: string; + status: string; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; + }> { + const sub = await this.validateUpgrade(userId, newPlan); + const newPriceId = this.priceMap[newPlan]; + + const stripeSub = await this.stripe.subscriptions.retrieve( + sub.stripeSubscriptionId!, + ); + const itemId = stripeSub.items.data[0].id; + + await this.stripe.subscriptions.update(sub.stripeSubscriptionId!, { + items: [{ id: itemId, price: newPriceId }], + proration_behavior: 'create_prorations', + }); + + sub.plan = newPlan; + await this.subscriptionRepository.save(sub); + + this.logger.log(`Subscription upgraded: user=${userId} newPlan=${newPlan}`); + + return { + plan: sub.plan, + status: sub.status, + currentPeriodEnd: sub.currentPeriodEnd?.toISOString() ?? null, + cancelAtPeriodEnd: sub.cancelAtPeriodEnd, + }; + } + + private async validateUpgrade( + userId: string, + newPlan: string, + ): Promise { + const sub = await this.subscriptionRepository.findOne({ + where: { userId }, + }); + if (!sub || !sub.stripeSubscriptionId) { + throw new NotFoundException('No active subscription found'); + } + if (sub.status !== 'active') { + throw new BadRequestException( + 'Cannot upgrade a subscription that is not active', + ); + } + if ((PLAN_ORDER[newPlan] ?? 0) <= (PLAN_ORDER[sub.plan] ?? 0)) { + throw new BadRequestException( + `Cannot upgrade from ${sub.plan} to ${newPlan}`, + ); + } + if (!this.priceMap[newPlan]) { + throw new BadRequestException( + `Plan "${newPlan}" is not available for purchase`, + ); + } + return sub; + } + constructWebhookEvent(rawBody: Buffer, signature: string): Stripe.Event { return this.stripe.webhooks.constructEvent( rawBody, @@ -213,9 +323,14 @@ export class BillingService { sub.currentPeriodEnd = new Date(this.extractPeriodEnd(stripeSub) * 1000); sub.cancelAtPeriodEnd = stripeSub.cancel_at_period_end; + const priceId = stripeSub.items?.data?.[0]?.price?.id; + if (priceId && this.reversePriceMap[priceId]) { + sub.plan = this.reversePriceMap[priceId]; + } + await this.subscriptionRepository.save(sub); this.logger.log( - `Subscription updated: user=${sub.userId} status=${sub.status}`, + `Subscription updated: user=${sub.userId} plan=${sub.plan} status=${sub.status}`, ); } diff --git a/backend/src/billing/dto/upgrade-plan.dto.ts b/backend/src/billing/dto/upgrade-plan.dto.ts new file mode 100644 index 0000000..790bd86 --- /dev/null +++ b/backend/src/billing/dto/upgrade-plan.dto.ts @@ -0,0 +1,13 @@ +import { IsNotEmpty, IsIn } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class UpgradePlanDto { + @ApiProperty({ + description: 'Plan to upgrade to', + example: 'team', + enum: ['starter', 'team', 'business', 'enterprise'], + }) + @IsNotEmpty() + @IsIn(['starter', 'team', 'business', 'enterprise']) + plan: string; +} diff --git a/frontend/src/pages/Billing.tsx b/frontend/src/pages/Billing.tsx index bcd2a98..bcfe2f0 100644 --- a/frontend/src/pages/Billing.tsx +++ b/frontend/src/pages/Billing.tsx @@ -3,18 +3,26 @@ import { CreditCard, ExternalLink } from 'lucide-react'; import { PageHeader } from '../components/ui/PageHeader'; import { Card } from '../components/ui/Card'; import { Button } from '../components/ui/Button'; +import { Modal } from '../components/ui/Modal'; import { PlanBadge } from '../components/ui/PlanBadge'; import { fetchSubscription, createCheckout, createPortalSession, + previewUpgrade, + upgradeSubscription, } from '../services/billingService'; -import type { Subscription } from '@krakenkey/shared'; +import type { Subscription, UpgradePreviewResponse } from '@krakenkey/shared'; export default function Billing() { const [subscription, setSubscription] = useState(null); const [loading, setLoading] = useState(true); const [actionLoading, setActionLoading] = useState(false); + const [upgradeModal, setUpgradeModal] = useState<{ + plan: string; + preview: UpgradePreviewResponse; + } | null>(null); + const [upgradeError, setUpgradeError] = useState(null); useEffect(() => { loadSubscription(); @@ -31,12 +39,41 @@ export default function Billing() { } } + const plan = subscription?.plan ?? 'free'; + const isPaid = plan !== 'free'; + const isPastDue = subscription?.status === 'past_due'; + const isCanceling = subscription?.cancelAtPeriodEnd === true; + async function handleUpgrade(targetPlan: string) { setActionLoading(true); + setUpgradeError(null); + try { + if (isPaid) { + const preview = await previewUpgrade(targetPlan); + setUpgradeModal({ plan: targetPlan, preview }); + } else { + const { sessionUrl } = await createCheckout(targetPlan); + window.location.href = sessionUrl; + return; + } + } catch { + setUpgradeError('Failed to load upgrade details. Please try again.'); + } finally { + setActionLoading(false); + } + } + + async function handleConfirmUpgrade() { + if (!upgradeModal) return; + setActionLoading(true); + setUpgradeError(null); try { - const { sessionUrl } = await createCheckout(targetPlan); - window.location.href = sessionUrl; + await upgradeSubscription(upgradeModal.plan); + setUpgradeModal(null); + await loadSubscription(); } catch { + setUpgradeError('Upgrade failed. Please try again.'); + } finally { setActionLoading(false); } } @@ -51,10 +88,31 @@ export default function Billing() { } } - const plan = subscription?.plan ?? 'free'; - const isPaid = plan !== 'free'; - const isPastDue = subscription?.status === 'past_due'; - const isCanceling = subscription?.cancelAtPeriodEnd === true; + function formatAmount(cents: number): string { + return `$${(cents / 100).toFixed(2)}`; + } + + function formatDate(dateStr: string): string { + return new Date(dateStr).toLocaleDateString(undefined, { + year: 'numeric', + month: 'long', + day: 'numeric', + }); + } + + const PLAN_LABELS: Record = { + starter: 'Starter', + team: 'Team', + business: 'Business', + enterprise: 'Enterprise', + }; + + const PLAN_PRICES: Record = { + starter: '$29/mo', + team: '$79/mo', + business: '$199/mo', + enterprise: 'Custom', + }; if (loading) { return ( @@ -114,11 +172,7 @@ export default function Billing() { {isPaid && subscription?.currentPeriodEnd && (

- Current period ends{' '} - {new Date(subscription.currentPeriodEnd).toLocaleDateString( - undefined, - { year: 'numeric', month: 'long', day: 'numeric' }, - )} + Current period ends {formatDate(subscription.currentPeriodEnd)}

)} @@ -213,6 +267,68 @@ export default function Billing() { )} + + {/* Upgrade Confirmation Modal */} + { + setUpgradeModal(null); + setUpgradeError(null); + }} + title={`Upgrade to ${upgradeModal ? PLAN_LABELS[upgradeModal.plan] : ''}`} + > + {upgradeModal && ( +
+

+ You'll be charged{' '} + + {formatAmount(upgradeModal.preview.immediateAmountCents)} + {' '} + immediately for the remaining days in your current billing period + (through {formatDate(upgradeModal.preview.currentPeriodEnd)}). +

+

+ Your subscription will then renew at the{' '} + + {PLAN_LABELS[upgradeModal.plan]} + {' '} + rate of{' '} + + {PLAN_PRICES[upgradeModal.plan]} + {' '} + on your next billing date. +

+ {isCanceling && ( +

+ Upgrading will also resume your subscription (scheduled + cancellation will be cleared). +

+ )} + {upgradeError && ( +

{upgradeError}

+ )} +
+ + +
+
+ )} +
); } diff --git a/frontend/src/services/billingService.ts b/frontend/src/services/billingService.ts index bbf94e9..75b155c 100644 --- a/frontend/src/services/billingService.ts +++ b/frontend/src/services/billingService.ts @@ -4,6 +4,8 @@ import type { Subscription, CheckoutResponse, PortalResponse, + UpgradePreviewResponse, + UpgradeResponse, } from '@krakenkey/shared'; export async function fetchSubscription(): Promise { @@ -23,3 +25,22 @@ export async function createPortalSession(): Promise { const response = await api.post(API_ROUTES.BILLING.PORTAL); return response.data; } + +export async function previewUpgrade( + plan: string, +): Promise { + const response = await api.post( + API_ROUTES.BILLING.UPGRADE_PREVIEW, + { plan }, + ); + return response.data; +} + +export async function upgradeSubscription( + plan: string, +): Promise { + const response = await api.post(API_ROUTES.BILLING.UPGRADE, { + plan, + }); + return response.data; +} diff --git a/shared/src/constants/routes.ts b/shared/src/constants/routes.ts index fa01c5e..b55ae4f 100644 --- a/shared/src/constants/routes.ts +++ b/shared/src/constants/routes.ts @@ -34,6 +34,8 @@ export const API_ROUTES = { SUBSCRIPTION: '/billing/subscription', PORTAL: '/billing/portal', WEBHOOK: '/billing/webhook', + UPGRADE_PREVIEW: '/billing/upgrade/preview', + UPGRADE: '/billing/upgrade', }, ORGANIZATIONS: { BASE: '/organizations', diff --git a/shared/src/types/subscription.ts b/shared/src/types/subscription.ts index 1db207a..e0d7340 100644 --- a/shared/src/types/subscription.ts +++ b/shared/src/types/subscription.ts @@ -32,3 +32,17 @@ export interface CheckoutResponse { export interface PortalResponse { portalUrl: string; } + +export interface UpgradePreviewResponse { + immediateAmountCents: number; + currency: string; + targetPlan: SubscriptionPlan; + currentPeriodEnd: string; +} + +export interface UpgradeResponse { + plan: SubscriptionPlan; + status: SubscriptionStatus; + currentPeriodEnd: string | null; + cancelAtPeriodEnd: boolean; +} From e2ca0fdabbea08a024f6c3166d3f4b50e2cd6239 Mon Sep 17 00:00:00 2001 From: krakenhavoc Date: Thu, 12 Mar 2026 20:46:07 +0000 Subject: [PATCH 2/3] fix: proration math --- backend/src/billing/billing.service.spec.ts | 26 +++++++++++++++++++-- backend/src/billing/billing.service.ts | 16 ++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/backend/src/billing/billing.service.spec.ts b/backend/src/billing/billing.service.spec.ts index 36d35eb..e99cfe0 100644 --- a/backend/src/billing/billing.service.spec.ts +++ b/backend/src/billing/billing.service.spec.ts @@ -364,13 +364,35 @@ describe('BillingService', () => { items: { data: [{ id: 'si_item1' }] }, }); mockStripe.invoices.createPreview.mockResolvedValue({ - amount_due: 3350, + amount_due: 12900, currency: 'usd', + lines: { + data: [ + { + amount: -1935, + parent: { + subscription_item_details: { proration: true }, + }, + }, + { + amount: 5270, + parent: { + subscription_item_details: { proration: true }, + }, + }, + { + amount: 7900, + parent: { + subscription_item_details: { proration: false }, + }, + }, + ], + }, }); const result = await service.previewUpgrade(userId, 'team'); expect(result).toEqual({ - immediateAmountCents: 3350, + immediateAmountCents: 3335, currency: 'usd', targetPlan: 'team', currentPeriodEnd: mockSubscription.currentPeriodEnd!.toISOString(), diff --git a/backend/src/billing/billing.service.ts b/backend/src/billing/billing.service.ts index 457f01c..88afa76 100644 --- a/backend/src/billing/billing.service.ts +++ b/backend/src/billing/billing.service.ts @@ -157,8 +157,22 @@ export class BillingService { }, }); + // Sum only the proration line items — amount_due includes the next + // full billing cycle which isn't charged immediately. + const isProration = (line: Stripe.InvoiceLineItem): boolean => { + const parent = line.parent; + if (!parent) return false; + return ( + parent.invoice_item_details?.proration === true || + parent.subscription_item_details?.proration === true + ); + }; + const prorationAmount = upcomingInvoice.lines.data + .filter(isProration) + .reduce((sum, line) => sum + line.amount, 0); + return { - immediateAmountCents: upcomingInvoice.amount_due, + immediateAmountCents: prorationAmount, currency: upcomingInvoice.currency, targetPlan: newPlan, currentPeriodEnd: sub.currentPeriodEnd!.toISOString(), From a8039646171e6aa596faba49dc66eb0df656f8ab Mon Sep 17 00:00:00 2001 From: krakenhavoc Date: Sat, 14 Mar 2026 12:55:14 +0000 Subject: [PATCH 3/3] fix: price difference calc --- backend/src/billing/billing.service.spec.ts | 70 ++++++++----------- backend/src/billing/billing.service.ts | 77 +++++++++++++-------- 2 files changed, 77 insertions(+), 70 deletions(-) diff --git a/backend/src/billing/billing.service.spec.ts b/backend/src/billing/billing.service.spec.ts index e99cfe0..0c73b03 100644 --- a/backend/src/billing/billing.service.spec.ts +++ b/backend/src/billing/billing.service.spec.ts @@ -11,7 +11,9 @@ const mockStripe = { checkout: { sessions: { create: jest.fn() } }, billingPortal: { sessions: { create: jest.fn() } }, subscriptions: { retrieve: jest.fn(), update: jest.fn() }, - invoices: { createPreview: jest.fn() }, + prices: { retrieve: jest.fn() }, + invoiceItems: { create: jest.fn() }, + invoices: { create: jest.fn(), pay: jest.fn() }, webhooks: { constructEvent: jest.fn() }, }; jest.mock('stripe', () => { @@ -360,50 +362,21 @@ describe('BillingService', () => { it('returns preview for valid upgrade', async () => { mockRepository.findOne.mockResolvedValue(mockSubscription); - mockStripe.subscriptions.retrieve.mockResolvedValue({ - items: { data: [{ id: 'si_item1' }] }, - }); - mockStripe.invoices.createPreview.mockResolvedValue({ - amount_due: 12900, - currency: 'usd', - lines: { - data: [ - { - amount: -1935, - parent: { - subscription_item_details: { proration: true }, - }, - }, - { - amount: 5270, - parent: { - subscription_item_details: { proration: true }, - }, - }, - { - amount: 7900, - parent: { - subscription_item_details: { proration: false }, - }, - }, - ], - }, - }); + mockStripe.prices.retrieve + .mockResolvedValueOnce({ unit_amount: 2900, currency: 'usd' }) + .mockResolvedValueOnce({ unit_amount: 7900, currency: 'usd' }); const result = await service.previewUpgrade(userId, 'team'); expect(result).toEqual({ - immediateAmountCents: 3335, + immediateAmountCents: 5000, currency: 'usd', targetPlan: 'team', currentPeriodEnd: mockSubscription.currentPeriodEnd!.toISOString(), }); - expect(mockStripe.invoices.createPreview).toHaveBeenCalledWith({ - subscription: 'sub_test123', - subscription_details: { - items: [{ id: 'si_item1', price: 'price_team_456' }], - proration_behavior: 'create_prorations', - }, - }); + expect(mockStripe.prices.retrieve).toHaveBeenCalledWith( + 'price_starter_123', + ); + expect(mockStripe.prices.retrieve).toHaveBeenCalledWith('price_team_456'); }); }); @@ -422,12 +395,18 @@ describe('BillingService', () => { ).rejects.toThrow(BadRequestException); }); - it('upgrades subscription with proration', async () => { + it('upgrades subscription with flat difference', async () => { mockRepository.findOne.mockResolvedValue({ ...mockSubscription }); + mockStripe.prices.retrieve + .mockResolvedValueOnce({ unit_amount: 2900, currency: 'usd' }) + .mockResolvedValueOnce({ unit_amount: 7900, currency: 'usd' }); mockStripe.subscriptions.retrieve.mockResolvedValue({ items: { data: [{ id: 'si_item1' }] }, }); mockStripe.subscriptions.update.mockResolvedValue({}); + mockStripe.invoiceItems.create.mockResolvedValue({}); + mockStripe.invoices.create.mockResolvedValue({ id: 'inv_123' }); + mockStripe.invoices.pay.mockResolvedValue({}); mockRepository.save.mockResolvedValue({ ...mockSubscription, plan: 'team', @@ -439,9 +418,20 @@ describe('BillingService', () => { 'sub_test123', { items: [{ id: 'si_item1', price: 'price_team_456' }], - proration_behavior: 'create_prorations', + proration_behavior: 'none', }, ); + expect(mockStripe.invoiceItems.create).toHaveBeenCalledWith({ + customer: 'cus_test123', + amount: 5000, + currency: 'usd', + description: 'Plan upgrade: starter → team', + }); + expect(mockStripe.invoices.create).toHaveBeenCalledWith({ + customer: 'cus_test123', + auto_advance: true, + }); + expect(mockStripe.invoices.pay).toHaveBeenCalledWith('inv_123'); expect(mockRepository.save).toHaveBeenCalledWith( expect.objectContaining({ plan: 'team' }), ); diff --git a/backend/src/billing/billing.service.ts b/backend/src/billing/billing.service.ts index 88afa76..8f197b8 100644 --- a/backend/src/billing/billing.service.ts +++ b/backend/src/billing/billing.service.ts @@ -142,38 +142,13 @@ export class BillingService { currentPeriodEnd: string; }> { const sub = await this.validateUpgrade(userId, newPlan); - const newPriceId = this.priceMap[newPlan]; - - const stripeSub = await this.stripe.subscriptions.retrieve( - sub.stripeSubscriptionId!, - ); - const itemId = stripeSub.items.data[0].id; - const upcomingInvoice = await this.stripe.invoices.createPreview({ - subscription: sub.stripeSubscriptionId!, - subscription_details: { - items: [{ id: itemId, price: newPriceId }], - proration_behavior: 'create_prorations', - }, - }); - - // Sum only the proration line items — amount_due includes the next - // full billing cycle which isn't charged immediately. - const isProration = (line: Stripe.InvoiceLineItem): boolean => { - const parent = line.parent; - if (!parent) return false; - return ( - parent.invoice_item_details?.proration === true || - parent.subscription_item_details?.proration === true - ); - }; - const prorationAmount = upcomingInvoice.lines.data - .filter(isProration) - .reduce((sum, line) => sum + line.amount, 0); + const { currentPriceCents, newPriceCents, currency } = + await this.getPriceDifference(sub.plan, newPlan); return { - immediateAmountCents: prorationAmount, - currency: upcomingInvoice.currency, + immediateAmountCents: newPriceCents - currentPriceCents, + currency, targetPlan: newPlan, currentPeriodEnd: sub.currentPeriodEnd!.toISOString(), }; @@ -191,16 +166,35 @@ export class BillingService { const sub = await this.validateUpgrade(userId, newPlan); const newPriceId = this.priceMap[newPlan]; + const { currentPriceCents, newPriceCents, currency } = + await this.getPriceDifference(sub.plan, newPlan); + const differenceCents = newPriceCents - currentPriceCents; + const stripeSub = await this.stripe.subscriptions.retrieve( sub.stripeSubscriptionId!, ); const itemId = stripeSub.items.data[0].id; + // Switch the plan without Stripe's day-based proration. await this.stripe.subscriptions.update(sub.stripeSubscriptionId!, { items: [{ id: itemId, price: newPriceId }], - proration_behavior: 'create_prorations', + proration_behavior: 'none', + }); + + // Charge the flat price difference immediately. + await this.stripe.invoiceItems.create({ + customer: sub.stripeCustomerId, + amount: differenceCents, + currency, + description: `Plan upgrade: ${sub.plan} → ${newPlan}`, }); + const invoice = await this.stripe.invoices.create({ + customer: sub.stripeCustomerId, + auto_advance: true, + }); + await this.stripe.invoices.pay(invoice.id); + sub.plan = newPlan; await this.subscriptionRepository.save(sub); @@ -214,6 +208,29 @@ export class BillingService { }; } + private async getPriceDifference( + currentPlan: string, + newPlan: string, + ): Promise<{ + currentPriceCents: number; + newPriceCents: number; + currency: string; + }> { + const currentPriceId = this.priceMap[currentPlan]; + const newPriceId = this.priceMap[newPlan]; + + const [currentPrice, newPrice] = await Promise.all([ + this.stripe.prices.retrieve(currentPriceId), + this.stripe.prices.retrieve(newPriceId), + ]); + + return { + currentPriceCents: currentPrice.unit_amount ?? 0, + newPriceCents: newPrice.unit_amount ?? 0, + currency: newPrice.currency, + }; + } + private async validateUpgrade( userId: string, newPlan: string,