diff --git a/src/auth/dto/auth.dto.ts b/src/auth/dto/auth.dto.ts index ba420e7..0db261a 100644 --- a/src/auth/dto/auth.dto.ts +++ b/src/auth/dto/auth.dto.ts @@ -1,7 +1,7 @@ import { IsEmail, IsString, IsEnum, IsOptional, IsNotEmpty } from 'class-validator'; import { ApiProperty } from '@nestjs/swagger'; import { UserRole } from '../../users/entities/user.entity'; -import { IsStrongPassword } from '../../common/validators/is-strong-password.validator'; +import { IsStrongPassword } from '../../common/validators/password.validator'; export class RegisterDto { @ApiProperty({ example: 'john.doe@example.com' }) diff --git a/src/common/interceptors/global-exception.filter.spec.ts b/src/common/interceptors/global-exception.filter.spec.ts new file mode 100644 index 0000000..445bd3c --- /dev/null +++ b/src/common/interceptors/global-exception.filter.spec.ts @@ -0,0 +1,38 @@ +import { GlobalExceptionFilter } from './global-exception.filter'; +import { HttpStatus } from '@nestjs/common'; +import { runWithCorrelationId } from '../utils/correlation.utils'; + +describe('GlobalExceptionFilter', () => { + it('adds correlation ID to error response and header', () => { + const filter = new GlobalExceptionFilter(); + + const req: any = { method: 'GET', url: '/test' }; + const responseHeaders: Record = {}; + const res: any = { + status: (code: number) => { + res.statusCode = code; + return res; + }, + json: (body: any) => { + res.body = body; + return res; + }, + setHeader: (name: string, value: string) => { + responseHeaders[name.toLowerCase()] = value; + }, + getHeader: (name: string) => responseHeaders[name.toLowerCase()], + }; + + runWithCorrelationId(() => { + filter.catch(new Error('Test error'), { + switchToHttp: () => ({ getRequest: () => req, getResponse: () => res }), + } as any); + }, 'cid-123'); + + const body = res.body; + + expect(res.statusCode).toBe(HttpStatus.INTERNAL_SERVER_ERROR); + expect(body.correlationId).toBe('cid-123'); + expect(body.message).toBe('Test error'); + }); +}); diff --git a/src/common/interceptors/global-exception.filter.ts b/src/common/interceptors/global-exception.filter.ts index db85ff1..4d9c390 100644 --- a/src/common/interceptors/global-exception.filter.ts +++ b/src/common/interceptors/global-exception.filter.ts @@ -9,6 +9,7 @@ import { import { Request, Response } from 'express'; import { QueryFailedError, EntityNotFoundError } from 'typeorm'; import { ApiError, ValidationErrorDetail } from '../../interfaces/api-error.interface'; +import { CORRELATION_ID_HEADER, getCorrelationId } from '../utils/correlation.utils'; @Catch() export class GlobalExceptionFilter implements ExceptionFilter { @@ -22,16 +23,23 @@ export class GlobalExceptionFilter implements ExceptionFilter { const { statusCode, message, error, details, stack } = this.resolveException(exception); + const correlationId = getCorrelationId(); + const errorResponse: ApiError = { statusCode, message, error, timestamp: new Date().toISOString(), path: request.url, + correlationId, ...(details?.length && { details }), ...(!this.isProduction && stack && { stack }), }; + if (correlationId) { + response.setHeader(CORRELATION_ID_HEADER, correlationId); + } + this.logger.error( `[${request.method}] ${request.url} → ${statusCode} ${error}: ${ Array.isArray(message) ? message.join(', ') : message diff --git a/src/common/interceptors/logging.interceptor.spec.ts b/src/common/interceptors/logging.interceptor.spec.ts new file mode 100644 index 0000000..7c03d01 --- /dev/null +++ b/src/common/interceptors/logging.interceptor.spec.ts @@ -0,0 +1,33 @@ +import { LoggingInterceptor } from './logging.interceptor'; +import { of, firstValueFrom } from 'rxjs'; + +describe('LoggingInterceptor', () => { + it('attaches and propagates correlation ID header', async () => { + const interceptor = new LoggingInterceptor(); + + const req: any = { method: 'GET', url: '/spam', headers: {} }; + const headers: Record = {}; + const res: any = { + statusCode: 200, + setHeader: (name: string, value: string) => { + headers[name.toLowerCase()] = value; + }, + getHeader: (name: string) => headers[name.toLowerCase()], + }; + + const context: any = { + getType: () => 'http', + switchToHttp: () => ({ getRequest: () => req, getResponse: () => res }), + }; + + const next: any = { + handle: () => of({ success: true }), + }; + + await firstValueFrom(interceptor.intercept(context, next)); + + const correlationId = res.getHeader('x-request-id'); + expect(typeof correlationId).toBe('string'); + expect(correlationId).toMatch(/^cid-/); + }); +}); diff --git a/src/common/interceptors/logging.interceptor.ts b/src/common/interceptors/logging.interceptor.ts index 21b75b7..ed8341f 100644 --- a/src/common/interceptors/logging.interceptor.ts +++ b/src/common/interceptors/logging.interceptor.ts @@ -2,6 +2,7 @@ import { Injectable, NestInterceptor, ExecutionContext, CallHandler, Logger } fr import { Observable, throwError } from 'rxjs'; import { tap, catchError } from 'rxjs/operators'; import { Request, Response } from 'express'; +import { CORRELATION_ID_HEADER, getCorrelationId } from '../utils/correlation.utils'; export interface RequestLog { requestId: string; @@ -49,7 +50,10 @@ export class LoggingInterceptor implements NestInterceptor { } const startTime = Date.now(); - const requestId = this.generateRequestId(); + const requestId = getCorrelationId() || this.generateRequestId(); + + const response = httpCtx.getResponse(); + response?.setHeader(CORRELATION_ID_HEADER, requestId); const baseLog: RequestLog = { requestId, @@ -67,12 +71,12 @@ export class LoggingInterceptor implements NestInterceptor { return next.handle().pipe( tap(() => { - const response = httpCtx.getResponse(); + const res = httpCtx.getResponse(); this.logOutgoing({ ...baseLog, - statusCode: response.statusCode, + statusCode: res.statusCode, responseTimeMs: Date.now() - startTime, - contentLength: this.getContentLength(response), + contentLength: this.getContentLength(res), }); }), catchError((error: unknown) => { diff --git a/src/common/utils/correlation.utils.spec.ts b/src/common/utils/correlation.utils.spec.ts new file mode 100644 index 0000000..cc31312 --- /dev/null +++ b/src/common/utils/correlation.utils.spec.ts @@ -0,0 +1,50 @@ +import { + correlationMiddleware, + getCorrelationId, + injectCorrelationIdToHeaders, + CORRELATION_ID_HEADER, +} from './correlation.utils'; + +describe('correlation.utils', () => { + it('generates and propagates correlation ID through middleware', (done) => { + const req: any = { method: 'GET', url: '/test', headers: {} }; + const headers: Record = {}; + const res: any = { + setHeader: (name: string, value: string) => { + headers[name.toLowerCase()] = value; + }, + getHeader: (name: string) => headers[name.toLowerCase()], + }; + + correlationMiddleware(req, res, () => { + const id = getCorrelationId(); + expect(typeof id).toBe('string'); + expect(res.getHeader(CORRELATION_ID_HEADER)).toBe(id); + done(); + }); + }); + + it('respects incoming x-request-id header', (done) => { + const incomingId = 'test-correlation-id'; + const req: any = { method: 'GET', url: '/test', headers: { 'x-request-id': incomingId } }; + const headers: Record = {}; + const res: any = { + setHeader: (name: string, value: string) => { + headers[name.toLowerCase()] = value; + }, + getHeader: (name: string) => headers[name.toLowerCase()], + }; + + correlationMiddleware(req, res, () => { + expect(getCorrelationId()).toBe(incomingId); + expect(res.getHeader(CORRELATION_ID_HEADER)).toBe(incomingId); + done(); + }); + }); + + it('injects correlation header into outgoing request headers', () => { + const custom = injectCorrelationIdToHeaders({ Authorization: 'Bearer token' }, 'cid-1'); + expect(custom[CORRELATION_ID_HEADER]).toBe('cid-1'); + expect(custom.Authorization).toBe('Bearer token'); + }); +}); diff --git a/src/common/utils/correlation.utils.ts b/src/common/utils/correlation.utils.ts new file mode 100644 index 0000000..fb8a171 --- /dev/null +++ b/src/common/utils/correlation.utils.ts @@ -0,0 +1,51 @@ +import { AsyncLocalStorage } from 'async_hooks'; +import { Request, Response, NextFunction } from 'express'; + +export const CORRELATION_ID_HEADER = 'x-request-id'; + +export interface CorrelationContext { + correlationId: string; +} + +const correlationStorage = new AsyncLocalStorage(); + +export function generateCorrelationId(): string { + return `cid-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`; +} + +export function getCorrelationId(): string | undefined { + const store = correlationStorage.getStore(); + return store?.correlationId; +} + +export function setCorrelationId(req: Request, res: Response, correlationId: string): void { + (req as Request & { correlationId?: string }).correlationId = correlationId; + res.setHeader(CORRELATION_ID_HEADER, correlationId); +} + +export function correlationMiddleware(req: Request, res: Response, next: NextFunction): void { + const incoming = + (req.headers[CORRELATION_ID_HEADER] as string) || (req.headers['x-correlation-id'] as string); + const correlationId = incoming || generateCorrelationId(); + + correlationStorage.run({ correlationId }, () => { + setCorrelationId(req, res, correlationId); + next(); + }); +} + +export function runWithCorrelationId(callback: () => T, correlationId?: string): T { + const id = correlationId || generateCorrelationId(); + return correlationStorage.run({ correlationId: id }, callback); +} + +export function injectCorrelationIdToHeaders( + headers: Record = {}, + correlationId?: string, +): Record { + const id = correlationId || getCorrelationId() || generateCorrelationId(); + return { + ...headers, + [CORRELATION_ID_HEADER]: id, + }; +} diff --git a/src/common/utils/sanitization.utils.spec.ts b/src/common/utils/sanitization.utils.spec.ts new file mode 100644 index 0000000..b851293 --- /dev/null +++ b/src/common/utils/sanitization.utils.spec.ts @@ -0,0 +1,32 @@ +import { sanitizeSqlLike, enforceWhitelistedValue } from './sanitization.utils'; + +describe('sanitization.utils', () => { + describe('sanitizeSqlLike', () => { + it('trims whitespace and escapes %, _, and \\', () => { + const raw = " test%_\\' OR 1=1 -- "; + const escaped = sanitizeSqlLike(raw); + + expect(escaped).toBe("test\\%\\_\\\\' OR 1=1 --"); + }); + + it('normalizes control characters to space', () => { + const raw = 'foo\nbar\tbaz\rqux'; + const escaped = sanitizeSqlLike(raw); + + expect(escaped).toBe('foo bar baz qux'); + }); + }); + + describe('enforceWhitelistedValue', () => { + it('returns value from allowlist', () => { + const value = enforceWhitelistedValue('active', ['active', 'inactive'] as const, 'status'); + expect(value).toBe('active'); + }); + + it('throws if value is not allowlisted', () => { + expect(() => + enforceWhitelistedValue('hacked' as any, ['active', 'inactive'] as const, 'status'), + ).toThrow(/Invalid value for status/); + }); + }); +}); diff --git a/src/common/utils/sanitization.utils.ts b/src/common/utils/sanitization.utils.ts new file mode 100644 index 0000000..fa9381c --- /dev/null +++ b/src/common/utils/sanitization.utils.ts @@ -0,0 +1,32 @@ +export function sanitizeSqlLike(input: string): string { + if (typeof input !== 'string') { + throw new TypeError('Expected a string for SQL LIKE sanitization'); + } + + const trimmed = input.trim(); + + // Prevent CR/LF/Tab injection and normalize whitespace + const normalized = trimmed.replace(/[\r\n\t]+/g, ' '); + + // Escape SQL wildcard and escape characters for LIKE operators. + // This makes sure user-supplied `%`, `_`, and `\\` are treated literally. + return normalized.replace(/[\\%_]/g, (char) => `\\${char}`); +} + +export function enforceWhitelistedValue( + value: T | undefined, + allowlist: readonly T[], + fieldName: string, +): T | undefined { + if (value === undefined || value === null || value === '') { + return undefined; + } + + if (!allowlist.includes(value as T)) { + throw new Error( + `Invalid value for ${fieldName}: ${value}. Allowed values are ${allowlist.join(', ')}`, + ); + } + + return value as T; +} diff --git a/src/common/validators/is-strong-password.validator.ts b/src/common/validators/is-strong-password.validator.ts index c954d18..69f4227 100644 --- a/src/common/validators/is-strong-password.validator.ts +++ b/src/common/validators/is-strong-password.validator.ts @@ -1,35 +1,5 @@ -import { - registerDecorator, - ValidationOptions, - ValidatorConstraint, - ValidatorConstraintInterface, -} from 'class-validator'; - -@ValidatorConstraint({ name: 'isStrongPassword', async: false }) -export class IsStrongPasswordConstraint implements ValidatorConstraintInterface { - validate(password: string) { - if (typeof password !== 'string') return false; - const hasUpperCase = /[A-Z]/.test(password); - const hasLowerCase = /[a-z]/.test(password); - const hasNumber = /[0-9]/.test(password); - const hasSymbol = /[!@#$%^&*(),.?":{}|<>]/.test(password); - - return password.length >= 8 && hasUpperCase && hasLowerCase && hasNumber && hasSymbol; - } - - defaultMessage() { - return 'Password must be at least 8 characters long and contain at least one uppercase letter, one lowercase letter, one number, and one special character.'; - } -} - -export function IsStrongPassword(validationOptions?: ValidationOptions) { - return function (object: object, propertyName: string) { - registerDecorator({ - target: object.constructor, - propertyName, - options: validationOptions, - constraints: [], - validator: IsStrongPasswordConstraint, - }); - }; -} +export { + IsStrongPassword, + calculatePasswordStrength, + PasswordStrengthResult, +} from './password.validator'; diff --git a/src/common/validators/password.validator.spec.ts b/src/common/validators/password.validator.spec.ts new file mode 100644 index 0000000..78c398f --- /dev/null +++ b/src/common/validators/password.validator.spec.ts @@ -0,0 +1,42 @@ +import { calculatePasswordStrength, PasswordConstraint } from './password.validator'; + +describe('Password Validator', () => { + describe('calculatePasswordStrength', () => { + it('recognizes weak passwords', () => { + const result = calculatePasswordStrength('abc'); + expect(result.isValid).toBe(false); + expect(result.level).toBe('weak'); + expect(result.errors).toEqual( + expect.arrayContaining([ + 'Password must be at least 8 characters long', + 'Password must contain at least one uppercase letter', + 'Password must contain at least one number', + ]), + ); + }); + + it('recognizes strong passwords', () => { + const result = calculatePasswordStrength('StrongPass123!'); + expect(result.isValid).toBe(true); + expect(result.level).toBe('strong'); + expect(result.errors).toEqual([]); + }); + }); + + describe('PasswordConstraint', () => { + const constraint = new PasswordConstraint(); + + it('validates strong password as valid', () => { + expect(constraint.validate('StrongPass123!')).toBe(true); + }); + + it('validates weak password as invalid', () => { + expect(constraint.validate('weak')).toBe(false); + }); + + it('returns detailed message for weak password', () => { + const message = constraint.defaultMessage({ value: 'weak' } as any); + expect(message).toContain('Password must be at least 8 characters long'); + }); + }); +}); diff --git a/src/common/validators/password.validator.ts b/src/common/validators/password.validator.ts new file mode 100644 index 0000000..e74d64e --- /dev/null +++ b/src/common/validators/password.validator.ts @@ -0,0 +1,86 @@ +import { + registerDecorator, + ValidationOptions, + ValidatorConstraint, + ValidatorConstraintInterface, + ValidationArguments, +} from 'class-validator'; + +export interface PasswordStrengthResult { + isValid: boolean; + errors: string[]; + score: number; + level: 'weak' | 'medium' | 'strong'; +} + +export const PASSWORD_REQUIREMENTS = { + minLength: 8, + uppercase: /[A-Z]/, + lowercase: /[a-z]/, + number: /\d/, + special: /[!@#$%^&*()_+\-=[\]{};':"\\|,.<>/?]/, +}; + +export function calculatePasswordStrength(password: string): PasswordStrengthResult { + const errors: string[] = []; + + if (typeof password !== 'string') { + errors.push('Password must be a string'); + return { isValid: false, errors, score: 0, level: 'weak' }; + } + + if (password.length < PASSWORD_REQUIREMENTS.minLength) { + errors.push(`Password must be at least ${PASSWORD_REQUIREMENTS.minLength} characters long`); + } + if (!PASSWORD_REQUIREMENTS.uppercase.test(password)) { + errors.push('Password must contain at least one uppercase letter'); + } + if (!PASSWORD_REQUIREMENTS.lowercase.test(password)) { + errors.push('Password must contain at least one lowercase letter'); + } + if (!PASSWORD_REQUIREMENTS.number.test(password)) { + errors.push('Password must contain at least one number'); + } + if (!PASSWORD_REQUIREMENTS.special.test(password)) { + errors.push('Password must contain at least one special character'); + } + + const score = 5 - errors.length; + const level = score <= 2 ? 'weak' : score === 3 || score === 4 ? 'medium' : 'strong'; + + return { + isValid: errors.length === 0, + errors, + score: Math.max(0, score), + level, + }; +} + +@ValidatorConstraint({ name: 'password', async: false }) +export class PasswordConstraint implements ValidatorConstraintInterface { + validate(password: string): boolean { + const result = calculatePasswordStrength(password); + return result.isValid; + } + + defaultMessage(args: ValidationArguments): string { + const password = args.value as string; + const result = calculatePasswordStrength(password); + if (result.errors.length === 0) { + return 'Password does not meet strength requirements'; + } + return result.errors.join('; '); + } +} + +export function IsStrongPassword(validationOptions?: ValidationOptions) { + return function (object: object, propertyName: string) { + registerDecorator({ + target: object.constructor, + propertyName, + options: validationOptions, + constraints: [], + validator: PasswordConstraint, + }); + }; +} diff --git a/src/courses/courses.service.ts b/src/courses/courses.service.ts index 0aec6b8..31bad41 100644 --- a/src/courses/courses.service.ts +++ b/src/courses/courses.service.ts @@ -4,6 +4,7 @@ import { Repository } from 'typeorm'; import { Course } from './entities/course.entity'; import { UpdateCourseDto } from './dto/update-course.dto'; import { paginate, PaginatedResponse } from '../common/utils/pagination.util'; +import { sanitizeSqlLike, enforceWhitelistedValue } from '../common/utils/sanitization.utils'; import { CourseSearchDto } from './dto/course-search.dto'; import { CachingService } from '../caching/caching.service'; import { CacheInvalidationService } from '../caching/invalidation/invalidation.service'; @@ -40,13 +41,17 @@ export class CoursesService { query.leftJoinAndSelect('course.instructor', 'instructor'); if (filter?.search) { - query.andWhere('(course.title ILIKE :search OR course.description ILIKE :search)', { - search: `%${filter.search}%`, - }); + const safeSearch = sanitizeSqlLike(filter.search); + query.andWhere( + "(course.title ILIKE :search ESCAPE '\\' OR course.description ILIKE :search ESCAPE '\\')", + { search: `%${safeSearch}%` }, + ); } if (filter?.status) { - query.andWhere('course.status = :status', { status: filter.status }); + const allowedStatuses = ['draft', 'published', 'archived'] as const; + const status = enforceWhitelistedValue(filter.status, allowedStatuses, 'status'); + query.andWhere('course.status = :status', { status }); } if (filter?.instructorId) { diff --git a/src/main.ts b/src/main.ts index fbd65a2..028e878 100644 --- a/src/main.ts +++ b/src/main.ts @@ -9,6 +9,8 @@ import Redis from 'ioredis'; import { AppModule } from './app.module'; import { GlobalExceptionFilter } from './common/interceptors/global-exception.filter'; import { ResponseTransformInterceptor } from './common/interceptors/response-transform.interceptor'; +import { LoggingInterceptor } from './common/interceptors/logging.interceptor'; +import { correlationMiddleware } from './common/utils/correlation.utils'; import { sessionConfig } from './config/cache.config'; import { SESSION_REDIS_CLIENT } from './session/session.constants'; @@ -26,6 +28,8 @@ async function bootstrapWorker() { expressApp.set('trust proxy', 1); } + app.use(correlationMiddleware); + app.use( session({ store: new RedisStore({ @@ -50,6 +54,9 @@ async function bootstrapWorker() { // ─── Global Exception Filter ────────────────────────────────────────────── app.useGlobalFilters(new GlobalExceptionFilter()); + // ─── Global Logging Interceptor ─────────────────────────────────────────── + app.useGlobalInterceptors(new LoggingInterceptor()); + // ─── Global Response Transform Interceptor ─────────────────────────────── app.useGlobalInterceptors(new ResponseTransformInterceptor()); diff --git a/src/orchestration/service-mesh/service-mesh.service.spec.ts b/src/orchestration/service-mesh/service-mesh.service.spec.ts new file mode 100644 index 0000000..40d5f99 --- /dev/null +++ b/src/orchestration/service-mesh/service-mesh.service.spec.ts @@ -0,0 +1,24 @@ +import { ServiceMeshService } from './service-mesh.service'; +import { of } from 'rxjs'; + +describe('ServiceMeshService', () => { + it('propagates correlation ID to external API call headers', async () => { + const serviceDiscovery: any = { + getService: jest.fn().mockResolvedValue({ baseUrl: 'http://localhost' }), + markUnhealthy: jest.fn(), + }; + const httpService: any = { + request: jest.fn().mockReturnValue(of({ data: { ok: true } })), + }; + + const service = new ServiceMeshService(serviceDiscovery, httpService); + + await expect(service.request('dummy', '/ping', 'GET')).resolves.toEqual({ ok: true }); + + expect(httpService.request).toHaveBeenCalledWith( + expect.objectContaining({ + headers: expect.objectContaining({ 'x-request-id': expect.any(String) }), + }), + ); + }); +}); diff --git a/src/orchestration/service-mesh/service-mesh.service.ts b/src/orchestration/service-mesh/service-mesh.service.ts index 802124c..f439a2b 100644 --- a/src/orchestration/service-mesh/service-mesh.service.ts +++ b/src/orchestration/service-mesh/service-mesh.service.ts @@ -3,6 +3,10 @@ import { HttpService } from '@nestjs/axios'; import { firstValueFrom } from 'rxjs'; import { AxiosResponse } from 'axios'; import { ServiceDiscoveryService } from '../discovery/service-discovery.service'; +import { + injectCorrelationIdToHeaders, + getCorrelationId, +} from '../../common/utils/correlation.utils'; @Injectable() export class ServiceMeshService { @@ -20,6 +24,8 @@ export class ServiceMeshService { const service = await this.discovery.getService(serviceName); const url = `${service.baseUrl}${path}`; + const correlationId = getCorrelationId(); + try { const response: AxiosResponse = await firstValueFrom( this.httpService.request({ @@ -27,6 +33,7 @@ export class ServiceMeshService { method, data, timeout: 5000, + headers: injectCorrelationIdToHeaders(undefined, correlationId), }), ); diff --git a/src/tenancy/admin/tenant-admin.service.ts b/src/tenancy/admin/tenant-admin.service.ts index db52d88..59d7881 100644 --- a/src/tenancy/admin/tenant-admin.service.ts +++ b/src/tenancy/admin/tenant-admin.service.ts @@ -1,6 +1,7 @@ import { Injectable, NotFoundException } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; +import { sanitizeSqlLike } from '../../common/utils/sanitization.utils'; import { Tenant, TenantStatus, TenantPlan } from '../entities/tenant.entity'; import { TenantConfig } from '../entities/tenant-config.entity'; import { TenantBilling } from '../entities/tenant-billing.entity'; @@ -240,11 +241,13 @@ export class TenantAdminService { * Search tenants */ async searchTenants(query: string): Promise { + const safeQuery = sanitizeSqlLike(query); + return await this.tenantRepository .createQueryBuilder('tenant') - .where('tenant.name ILIKE :query', { query: `%${query}%` }) - .orWhere('tenant.slug ILIKE :query', { query: `%${query}%` }) - .orWhere('tenant.domain ILIKE :query', { query: `%${query}%` }) + .where("tenant.name ILIKE :query ESCAPE '\\'", { query: `%${safeQuery}%` }) + .orWhere("tenant.slug ILIKE :query ESCAPE '\\'", { query: `%${safeQuery}%` }) + .orWhere("tenant.domain ILIKE :query ESCAPE '\\'", { query: `%${safeQuery}%` }) .getMany(); } diff --git a/src/users/dto/create-user.dto.ts b/src/users/dto/create-user.dto.ts index bf15d22..c6fa115 100644 --- a/src/users/dto/create-user.dto.ts +++ b/src/users/dto/create-user.dto.ts @@ -1,7 +1,7 @@ import { IsEmail, IsString, IsOptional, IsEnum, IsNotEmpty } from 'class-validator'; import { ApiProperty } from '@nestjs/swagger'; import { UserRole } from '../entities/user.entity'; -import { IsStrongPassword } from '../../common/validators/is-strong-password.validator'; +import { IsStrongPassword } from '../../common/validators/password.validator'; export class CreateUserDto { @ApiProperty({ example: 'john.doe@example.com' }) diff --git a/src/users/entities/user.entity.ts b/src/users/entities/user.entity.ts index 115ed28..f27e99d 100644 --- a/src/users/entities/user.entity.ts +++ b/src/users/entities/user.entity.ts @@ -83,6 +83,9 @@ export class User { @Column({ nullable: true }) refreshToken?: string; + @Column('text', { array: true, default: [] }) + passwordHistory: string[]; + @Column({ type: 'timestamp', nullable: true }) lastLoginAt?: Date; diff --git a/src/users/users.service.spec.ts b/src/users/users.service.spec.ts new file mode 100644 index 0000000..6e3519c --- /dev/null +++ b/src/users/users.service.spec.ts @@ -0,0 +1,114 @@ +import { UsersService } from './users.service'; +import { UserRole, UserStatus } from './entities/user.entity'; +import * as bcrypt from 'bcryptjs'; + +describe('UsersService', () => { + let service: UsersService; + let queryBuilder: any; + let userRepository: any; + let cachingService: any; + + beforeEach(() => { + queryBuilder = { + andWhere: jest.fn().mockReturnThis(), + getCount: jest.fn().mockResolvedValue(1), + skip: jest.fn().mockReturnThis(), + take: jest.fn().mockReturnThis(), + getMany: jest.fn().mockResolvedValue([{ id: 'user-1', email: 'test@example.com' }]), + }; + + userRepository = { + createQueryBuilder: jest.fn().mockReturnValue(queryBuilder), + }; + + cachingService = { + getOrSet: jest.fn().mockImplementation(async (_key: string, handler: any) => handler()), + }; + + service = new UsersService(userRepository, cachingService, { emit: jest.fn() } as any); + }); + + it('sanitizes search input and uses parameterized ILIKE', async () => { + const maliciousSearch = "a%_b\\ test' OR 1=1 --"; + + await expect( + service.findAll({ + search: maliciousSearch, + role: UserRole.STUDENT, + status: UserStatus.ACTIVE, + }), + ).resolves.toBeDefined(); + + expect(userRepository.createQueryBuilder).toHaveBeenCalledWith('user'); + + expect(queryBuilder.andWhere).toHaveBeenCalledWith( + "(user.email ILIKE :search ESCAPE '\\' OR user.firstName ILIKE :search ESCAPE '\\' OR user.lastName ILIKE :search ESCAPE '\\')", + { + search: "%a\\%\\_b\\\\ test' OR 1=1 --%", + }, + ); + }); + + it('blocks non-whitelisted role values', async () => { + await expect(service.findAll({ role: 'hacker' as any })).rejects.toThrow( + /Invalid value for role/, + ); + }); + + it('blocks non-whitelisted status values', async () => { + await expect(service.findAll({ status: 'hacked' as any })).rejects.toThrow( + /Invalid value for status/, + ); + }); + + it('rejects password reuse via current password', async () => { + const currentHash = await bcrypt.hash('CurrentPass1!', 10); + + userRepository.findOne = jest.fn().mockResolvedValue({ + id: 'user-1', + email: 'test@example.com', + password: currentHash, + passwordHistory: [], + }); + userRepository.save = jest.fn().mockResolvedValue(true); + + await expect(service.update('user-1', { password: 'CurrentPass1!' })).rejects.toThrow( + /New password must be different from the current password/, + ); + }); + + it('rejects password reuse via history', async () => { + const currentHash = await bcrypt.hash('CurrentPass1!', 10); + const oldHash = await bcrypt.hash('OldPass1!', 10); + + userRepository.findOne = jest.fn().mockResolvedValue({ + id: 'user-1', + email: 'test@example.com', + password: currentHash, + passwordHistory: [oldHash], + }); + userRepository.save = jest.fn().mockResolvedValue(true); + + await expect(service.update('user-1', { password: 'OldPass1!' })).rejects.toThrow( + /New password must not match your last 5 passwords/, + ); + }); + + it('updates password and appends current password to history', async () => { + const currentHash = await bcrypt.hash('CurrentPass1!', 10); + + userRepository.findOne = jest.fn().mockResolvedValue({ + id: 'user-1', + email: 'test@example.com', + password: currentHash, + passwordHistory: [], + }); + userRepository.save = jest.fn().mockImplementation(async (user: any) => user); + + const result = await service.update('user-1', { password: 'NewPass1!' }); + + expect(result.password).not.toBe('NewPass1!'); + expect(await bcrypt.compare('NewPass1!', result.password)).toBe(true); + expect(result.passwordHistory).toEqual([currentHash]); + }); +}); diff --git a/src/users/users.service.ts b/src/users/users.service.ts index 50cb018..b365518 100644 --- a/src/users/users.service.ts +++ b/src/users/users.service.ts @@ -1,13 +1,14 @@ -import { Injectable } from '@nestjs/common'; +import { BadRequestException, Injectable } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; -import { User } from './entities/user.entity'; +import { User, UserRole, UserStatus } from './entities/user.entity'; import { CreateUserDto } from './dto/create-user.dto'; import { UpdateUserDto } from './dto/update-user.dto'; import * as bcrypt from 'bcryptjs'; import { ensureUserExists, ensureUserDoesNotExist } from '../common/utils/user.utils'; import { paginate, PaginatedResponse } from '../common/utils/pagination.util'; import { PaginationQueryDto } from '../common/dto/pagination.dto'; +import { sanitizeSqlLike, enforceWhitelistedValue } from '../common/utils/sanitization.utils'; import { GetUsersDto } from './dto/get-users.dto'; import { CachingService } from '../caching/caching.service'; import { CACHE_TTL, CACHE_PREFIXES, CACHE_EVENTS } from '../caching/caching.constants'; @@ -50,17 +51,24 @@ export class UsersService { const query = this.userRepository.createQueryBuilder('user'); if (filter?.role) { - query.andWhere('user.role = :role', { role: filter.role }); + const role = enforceWhitelistedValue(filter.role, Object.values(UserRole), 'role'); + query.andWhere('user.role = :role', { role }); } if (filter?.status) { - query.andWhere('user.status = :status', { status: filter.status }); + const status = enforceWhitelistedValue( + filter.status, + Object.values(UserStatus), + 'status', + ); + query.andWhere('user.status = :status', { status }); } if (filter?.search) { + const safeSearch = sanitizeSqlLike(filter.search); query.andWhere( - '(user.email ILIKE :search OR user.firstName ILIKE :search OR user.lastName ILIKE :search)', - { search: `%${filter.search}%` }, + "(user.email ILIKE :search ESCAPE '\\' OR user.firstName ILIKE :search ESCAPE '\\' OR user.lastName ILIKE :search ESCAPE '\\')", + { search: `%${safeSearch}%` }, ); } @@ -115,9 +123,24 @@ export class UsersService { async update(id: string, updateUserDto: UpdateUserDto): Promise { const user = await this.findUserOrThrow(id); - // If updating password, hash it if (updateUserDto.password) { - updateUserDto.password = await bcrypt.hash(updateUserDto.password, 10); + const plainPassword = updateUserDto.password; + + if (await bcrypt.compare(plainPassword, user.password)) { + throw new BadRequestException('New password must be different from the current password'); + } + + const recentPasswords = user.passwordHistory ?? []; + for (const oldHash of recentPasswords.slice(-5)) { + if (await bcrypt.compare(plainPassword, oldHash)) { + throw new BadRequestException('New password must not match your last 5 passwords'); + } + } + + // Append current, maintain last 5 entries + user.passwordHistory = [...recentPasswords, user.password].slice(-5); + + updateUserDto.password = await bcrypt.hash(plainPassword, 10); } Object.assign(user, updateUserDto);