Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backend/src/billing/billing.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ describe('BillingController', () => {
createCheckoutSession: jest.fn(),
getSubscription: jest.fn(),
createPortalSession: jest.fn(),
previewUpgrade: jest.fn(),
upgradeSubscription: jest.fn(),
constructWebhookEvent: jest.fn(),
handleWebhookEvent: jest.fn(),
};
Expand Down Expand Up @@ -98,6 +100,46 @@ describe('BillingController', () => {
});
});

describe('previewUpgrade', () => {
it('returns preview from service', async () => {
const preview = {
immediateAmountCents: 3350,
currency: 'usd',
targetPlan: 'team',
currentPeriodEnd: '2026-04-09T00:00:00.000Z',
};
mockBillingService.previewUpgrade.mockResolvedValue(preview);

const result = await controller.previewUpgrade(mockReq, {
plan: 'team',
});
expect(result).toEqual(preview);
expect(mockBillingService.previewUpgrade).toHaveBeenCalledWith(
userId,
'team',
);
});
});

describe('upgrade', () => {
it('returns upgraded subscription from service', async () => {
const upgraded = {
plan: 'team',
status: 'active',
currentPeriodEnd: '2026-04-09T00:00:00.000Z',
cancelAtPeriodEnd: false,
};
mockBillingService.upgradeSubscription.mockResolvedValue(upgraded);

const result = await controller.upgrade(mockReq, { plan: 'team' });
expect(result).toEqual(upgraded);
expect(mockBillingService.upgradeSubscription).toHaveBeenCalledWith(
userId,
'team',
);
});
});

describe('webhook', () => {
it('throws on missing signature', async () => {
const req = { rawBody: Buffer.from('{}') } as any;
Expand Down
30 changes: 30 additions & 0 deletions backend/src/billing/billing.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
import { SkipThrottle } from '@nestjs/throttler';
import { BillingService } from './billing.service';
import { CreateCheckoutDto } from './dto/create-checkout.dto';
import { UpgradePlanDto } from './dto/upgrade-plan.dto';
import { JwtOrApiKeyGuard } from '../auth/guards/jwt-or-api-key.guard';
import type { RequestWithUser } from '../auth/interfaces/request-with-user.interface';
import { RateLimitCategoryDecorator } from '../throttler/decorators/rate-limit-category.decorator';
Expand Down Expand Up @@ -91,6 +92,35 @@ export class BillingController {
return { portalUrl };
}

@Post('upgrade/preview')
@UseGuards(JwtOrApiKeyGuard)
@ApiBearerAuth()
@ApiOperation({ summary: 'Preview prorated cost for a plan upgrade' })
@ApiResponse({ status: 200, description: 'Upgrade preview details' })
@ApiResponse({ status: 400, description: 'Invalid upgrade path' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
@ApiResponse({ status: 404, description: 'No active subscription' })
@RateLimitCategoryDecorator(RateLimitCategory.AUTHENTICATED_READ)
async previewUpgrade(
@Request() req: RequestWithUser,
@Body() dto: UpgradePlanDto,
) {
return this.billingService.previewUpgrade(req.user.userId, dto.plan);
}

@Post('upgrade')
@UseGuards(JwtOrApiKeyGuard)
@ApiBearerAuth()
@ApiOperation({ summary: 'Upgrade subscription with proration' })
@ApiResponse({ status: 201, description: 'Subscription upgraded' })
@ApiResponse({ status: 400, description: 'Invalid upgrade path' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
@ApiResponse({ status: 404, description: 'No active subscription' })
@RateLimitCategoryDecorator(RateLimitCategory.AUTHENTICATED_WRITE)
async upgrade(@Request() req: RequestWithUser, @Body() dto: UpgradePlanDto) {
return this.billingService.upgradeSubscription(req.user.userId, dto.plan);
}

@Post('webhook')
@SkipThrottle()
@ApiExcludeEndpoint()
Expand Down
149 changes: 148 additions & 1 deletion backend/src/billing/billing.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ const mockStripe = {
customers: { create: jest.fn() },
checkout: { sessions: { create: jest.fn() } },
billingPortal: { sessions: { create: jest.fn() } },
subscriptions: { retrieve: jest.fn() },
subscriptions: { retrieve: jest.fn(), update: jest.fn() },
prices: { retrieve: jest.fn() },
invoiceItems: { create: jest.fn() },
invoices: { create: jest.fn(), pay: jest.fn() },
webhooks: { constructEvent: jest.fn() },
};
jest.mock('stripe', () => {
Expand Down Expand Up @@ -59,6 +62,7 @@ describe('BillingService', () => {
KK_STRIPE_SECRET_KEY: 'sk_test_123',
KK_STRIPE_WEBHOOK_SECRET: 'whsec_test',
KK_STRIPE_PRICE_STARTER: 'price_starter_123',
KK_STRIPE_PRICE_TEAM: 'price_team_456',
KK_APP_DOMAIN: 'app.krakenkey.io',
};
return config[key] ?? defaultValue ?? '';
Expand Down Expand Up @@ -289,5 +293,148 @@ describe('BillingService', () => {
}),
);
});

it('syncs plan from price ID on subscription.updated', async () => {
mockRepository.findOne.mockResolvedValue({ ...mockSubscription });

await service.handleWebhookEvent({
type: 'customer.subscription.updated',
data: {
object: {
id: 'sub_test123',
status: 'active',
items: {
data: [
{
current_period_end:
Math.floor(Date.now() / 1000) + 86400 * 30,
price: { id: 'price_team_456' },
},
],
},
cancel_at_period_end: false,
},
},
} as any);

expect(mockRepository.save).toHaveBeenCalledWith(
expect.objectContaining({
plan: 'team',
}),
);
});
});

describe('previewUpgrade', () => {
it('throws NotFoundException when no subscription exists', async () => {
mockRepository.findOne.mockResolvedValue(null);
await expect(service.previewUpgrade(userId, 'team')).rejects.toThrow(
NotFoundException,
);
});

it('throws BadRequestException when subscription is not active', async () => {
mockRepository.findOne.mockResolvedValue({
...mockSubscription,
status: 'past_due',
});
await expect(service.previewUpgrade(userId, 'team')).rejects.toThrow(
BadRequestException,
);
});

it('throws BadRequestException when target plan is not higher', async () => {
mockRepository.findOne.mockResolvedValue({
...mockSubscription,
plan: 'team',
});
await expect(service.previewUpgrade(userId, 'starter')).rejects.toThrow(
BadRequestException,
);
});

it('throws BadRequestException for same plan', async () => {
mockRepository.findOne.mockResolvedValue(mockSubscription);
await expect(service.previewUpgrade(userId, 'starter')).rejects.toThrow(
BadRequestException,
);
});

it('returns preview for valid upgrade', async () => {
mockRepository.findOne.mockResolvedValue(mockSubscription);
mockStripe.prices.retrieve
.mockResolvedValueOnce({ unit_amount: 2900, currency: 'usd' })
.mockResolvedValueOnce({ unit_amount: 7900, currency: 'usd' });

const result = await service.previewUpgrade(userId, 'team');
expect(result).toEqual({
immediateAmountCents: 5000,
currency: 'usd',
targetPlan: 'team',
currentPeriodEnd: mockSubscription.currentPeriodEnd!.toISOString(),
});
expect(mockStripe.prices.retrieve).toHaveBeenCalledWith(
'price_starter_123',
);
expect(mockStripe.prices.retrieve).toHaveBeenCalledWith('price_team_456');
});
});

describe('upgradeSubscription', () => {
it('throws NotFoundException when no subscription exists', async () => {
mockRepository.findOne.mockResolvedValue(null);
await expect(service.upgradeSubscription(userId, 'team')).rejects.toThrow(
NotFoundException,
);
});

it('throws BadRequestException when target plan is not higher', async () => {
mockRepository.findOne.mockResolvedValue(mockSubscription);
await expect(
service.upgradeSubscription(userId, 'starter'),
).rejects.toThrow(BadRequestException);
});

it('upgrades subscription with flat difference', async () => {
mockRepository.findOne.mockResolvedValue({ ...mockSubscription });
mockStripe.prices.retrieve
.mockResolvedValueOnce({ unit_amount: 2900, currency: 'usd' })
.mockResolvedValueOnce({ unit_amount: 7900, currency: 'usd' });
mockStripe.subscriptions.retrieve.mockResolvedValue({
items: { data: [{ id: 'si_item1' }] },
});
mockStripe.subscriptions.update.mockResolvedValue({});
mockStripe.invoiceItems.create.mockResolvedValue({});
mockStripe.invoices.create.mockResolvedValue({ id: 'inv_123' });
mockStripe.invoices.pay.mockResolvedValue({});
mockRepository.save.mockResolvedValue({
...mockSubscription,
plan: 'team',
});

const result = await service.upgradeSubscription(userId, 'team');
expect(result.plan).toBe('team');
expect(mockStripe.subscriptions.update).toHaveBeenCalledWith(
'sub_test123',
{
items: [{ id: 'si_item1', price: 'price_team_456' }],
proration_behavior: 'none',
},
);
expect(mockStripe.invoiceItems.create).toHaveBeenCalledWith({
customer: 'cus_test123',
amount: 5000,
currency: 'usd',
description: 'Plan upgrade: starter → team',
});
expect(mockStripe.invoices.create).toHaveBeenCalledWith({
customer: 'cus_test123',
auto_advance: true,
});
expect(mockStripe.invoices.pay).toHaveBeenCalledWith('inv_123');
expect(mockRepository.save).toHaveBeenCalledWith(
expect.objectContaining({ plan: 'team' }),
);
});
});
});
Loading
Loading