Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ Base Path: `/api/auth`
* `GET /me`: Current authenticated user details. (Requires `Authorization: Bearer <JWT_TOKEN>`)
* `POST /set-username`: Set username (e.g., after Google OAuth).
* `GET /check-username/{username}`: Check username availability.
* `POST /ws-ticket`: Generate a one-time-use WebSocket connection ticket. (Requires `Authorization: Bearer <JWT_TOKEN>`)

### Chat
Base Path: `/api`
Expand All @@ -212,7 +213,7 @@ Base Path: `/api`
### WebSocket
* **Endpoint:** `/ws/chat`
* **Connection:**
* **Authenticated Users:** `ws://localhost:8080/ws/chat?token=<YOUR_JWT_TOKEN>`
* **Authenticated Users:** Use a one-time ticket (see secure connection flow below)
* **Anonymous Users:** Connect with `anonSessionId` in a cookie.
* **Message Payload:** `WebSocketMessagePayload` JSON
```json
Expand All @@ -226,18 +227,47 @@ Base Path: `/api`
}
```

#### Secure WebSocket Connection Flow (Authenticated Users)

For enhanced security, authenticated users should follow this two-step process:

1. **Get a One-Time Ticket:**
```http
POST /api/auth/ws-ticket
Authorization: Bearer <YOUR_JWT_TOKEN>
```
Response:
```json
{
"success": true,
"message": "WebSocket ticket generated successfully",
"ticket": "550e8400-e29b-41d4-a716-446655440000"
}
```

2. **Connect to WebSocket:**
```
ws://localhost:8080/ws/chat?ticket=<TICKET>
```

**Important Notes:**
- Tickets are valid for **30 seconds** only
- Tickets can only be used **once** (immediately invalidated upon use)
- This prevents the security risks of passing long-lived JWTs in URLs (logging, browser history, etc.)

---

## Usage

1. **Anonymous User:**
* Frontend generates a unique `anonSessionId`.
* Connect to `/ws/chat` with `anonSessionId` (e.g., via cookie).
* Connect to `/ws/chat` with `anonSessionId` in a cookie.
* Send messages with `type: "ANON_TO_USER"`.

2. **Registered User:**
* Register/Login to get a JWT.
* Connect to `/ws/chat?token=<JWT>`.
* Request a WebSocket ticket via `POST /api/auth/ws-ticket` with your JWT.
* Connect to `/ws/chat?ticket=<TICKET>` (use ticket within 30 seconds).
* Fetch sessions via `GET /api/chats`.
* Reply via WebSocket with `type: "USER_TO_ANON"`.
* View history/mark read via `GET /api/chat/session_history`.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.fredmaina.chatapp.Auth.Dtos;

import lombok.*;

@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class WebSocketTicketResponse {
private boolean success;
private String message;
private String ticket;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fredmaina.chatapp.Auth.Repositories.UserRepository;
import com.fredmaina.chatapp.Auth.services.AuthService;
import com.fredmaina.chatapp.Auth.services.JWTService;
import com.fredmaina.chatapp.Auth.services.WebSocketTicketService;
import jakarta.validation.Valid;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -27,6 +28,8 @@ public class AuthController {
JWTService jwtService;
@Autowired
UserRepository userRepository;
@Autowired
WebSocketTicketService webSocketTicketService;


@PostMapping("/login")
Expand Down Expand Up @@ -128,4 +131,56 @@ public ResponseEntity<Map<String, Object>> checkUsername(@PathVariable String us
));
}
}

/**
* Generates a one-time-use WebSocket connection ticket.
* This endpoint requires authentication via Authorization header.
* The ticket is valid for 30 seconds and can only be used once.
*/
@PostMapping("/ws-ticket")
public ResponseEntity<WebSocketTicketResponse> generateWebSocketTicket(
@RequestHeader("Authorization") String authHeader) {
if (authHeader == null || !authHeader.startsWith("Bearer ")) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(WebSocketTicketResponse.builder()
.success(false)
.message("Missing or invalid Authorization header")
.build());
}

String token = authHeader.replace("Bearer ", "");
String email;
try {
email = jwtService.getUsernameFromToken(token);
if (!jwtService.isTokenValid(token)) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(WebSocketTicketResponse.builder()
.success(false)
.message("Invalid or expired token")
.build());
}
} catch (Exception e) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.body(WebSocketTicketResponse.builder()
.success(false)
.message("Invalid token")
.build());
}

User user = userRepository.findByEmail(email).orElse(null);
if (user == null) {
return ResponseEntity.status(HttpStatus.NOT_FOUND)
.body(WebSocketTicketResponse.builder()
.success(false)
.message("User not found")
.build());
}

String ticket = webSocketTicketService.generateTicket(email);
return ResponseEntity.ok(WebSocketTicketResponse.builder()
.success(true)
.message("WebSocket ticket generated successfully")
.ticket(ticket)
.build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.fredmaina.chatapp.Auth.services;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Service;

import java.util.UUID;
import java.util.concurrent.TimeUnit;

/**
* Service for managing one-time-use WebSocket connection tickets.
* Tickets are stored in Redis with a short TTL (30 seconds) and are
* immediately evicted upon validation to prevent reuse.
*/
@Service
@RequiredArgsConstructor
@Slf4j
public class WebSocketTicketService {

private static final String TICKET_PREFIX = "ws-ticket:";
private static final long TICKET_TTL_SECONDS = 30;

private final RedisTemplate<String, Object> redisTemplate;

/**
* Generates a one-time-use ticket for WebSocket connection.
*
* @param userEmail the email of the authenticated user
* @return a unique ticket string
*/
public String generateTicket(String userEmail) {
String ticket = UUID.randomUUID().toString();
String key = TICKET_PREFIX + ticket;

redisTemplate.opsForValue().set(key, userEmail, TICKET_TTL_SECONDS, TimeUnit.SECONDS);
log.info("Generated WebSocket ticket for user: {}", userEmail);

return ticket;
}

/**
* Validates and consumes a WebSocket ticket.
* The ticket is immediately evicted upon successful validation to prevent reuse.
*
* @param ticket the ticket to validate
* @return the user email if valid, null otherwise
*/
public String validateAndConsumeTicket(String ticket) {
if (ticket == null || ticket.isBlank()) {
return null;
}

String key = TICKET_PREFIX + ticket;
Object value = redisTemplate.opsForValue().get(key);

if (value != null) {
// Immediately evict the ticket to prevent reuse
Boolean deleted = redisTemplate.delete(key);
if (Boolean.TRUE.equals(deleted)) {
String userEmail = value.toString();
log.info("Validated and consumed WebSocket ticket for user: {}", userEmail);
return userEmail;
}
}

log.warn("Invalid or expired WebSocket ticket: {}", ticket);
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fredmaina.chatapp.Auth.Models.User;
import com.fredmaina.chatapp.Auth.Repositories.UserRepository;
import com.fredmaina.chatapp.Auth.services.JWTService;
import com.fredmaina.chatapp.Auth.services.WebSocketTicketService;
import com.fredmaina.chatapp.core.DTOs.WebSocketMessagePayload;
import com.fredmaina.chatapp.core.Services.MessagingService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
Expand All @@ -28,29 +27,35 @@ public class ChatWebSocketHandler extends TextWebSocketHandler {

private final MessagingService messagingService;

private final JWTService jwtService;
private final WebSocketTicketService webSocketTicketService;

private final Map<String, WebSocketSession> userSessions = new ConcurrentHashMap<>();
private final Map<String, WebSocketSession> anonymousSessions = new ConcurrentHashMap<>();

// Store the authenticated email for each session (for message handling)
private final Map<String, String> sessionEmailMap = new ConcurrentHashMap<>();

private final ObjectMapper objectMapper = new ObjectMapper();
private final UserRepository userRepository;
private final UserRepository userRepository;

@Override
public void afterConnectionEstablished(WebSocketSession session) {
String email = extractUsernameFromJWT(session);
String email = extractUserFromTicket(session);
String anonId = extractAnonSessionId(session);
Optional<String> userName = userRepository.findByEmail(email).map(User::getUsername);
Optional<String> userName = email != null
? userRepository.findByEmail(email).map(User::getUsername)
: Optional.empty();

if (userName.isPresent()) {
String user =userName.get();
String user = userName.get();
userSessions.put(user, session);
log.info("Authenticated user connected: {}, with nickname: {}", email,user);
sessionEmailMap.put(session.getId(), email);
log.info("Authenticated user connected: {}, with nickname: {}", email, user);
} else if (anonId != null) {
anonymousSessions.put(anonId, session);
log.info("Anonymous user connected: {}", anonId);
} else {
log.warn("Connection rejected: No valid token or session cookie.");
log.warn("Connection rejected: No valid ticket or session cookie.");
try {
session.close();
} catch (Exception ignored) {}
Expand Down Expand Up @@ -78,7 +83,8 @@ public void afterConnectionEstablished(WebSocketSession session) {
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
userSessions.values().remove(session);
anonymousSessions.values().remove(session);
log.info("Connection closed: {}, reason: {}", session.getId(),status.getReason());
sessionEmailMap.remove(session.getId());
log.info("Connection closed: {}, reason: {}", session.getId(), status.getReason());
}

@Override
Expand All @@ -94,10 +100,10 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
}
}
case USER_TO_ANON -> {
String username = extractUsernameFromJWT(session);
if (username != null) {
String email = sessionEmailMap.get(session.getId());
if (email != null) {
log.info(payload.toString());
messagingService.sendMessageFromUser(username, payload.getTo(), payload.getContent());
messagingService.sendMessageFromUser(email, payload.getTo(), payload.getContent());
}
}
case MARK_AS_READ -> {
Expand All @@ -108,56 +114,50 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message)
}
}

// Helper to extract JWT from query params
private String extractUsernameFromJWT(WebSocketSession session) {
/**
* Extracts user email from a one-time-use WebSocket ticket.
* The ticket is validated and immediately consumed (evicted from cache)
* to prevent reuse.
*/
private String extractUserFromTicket(WebSocketSession session) {
try {
URI uri = session.getUri();
if (uri != null && uri.getQuery() != null) {
String[] pairs = uri.getQuery().split("&");
for (String pair : pairs) {
String[] keyValue = pair.split("=");
if (keyValue.length == 2 && keyValue[0].equals("token")) {
String token = URLDecoder.decode(keyValue[1], StandardCharsets.UTF_8);
return jwtService.getUsernameFromToken(token);
if (keyValue.length == 2 && keyValue[0].equals("ticket")) {
String ticket = URLDecoder.decode(keyValue[1], StandardCharsets.UTF_8);
return webSocketTicketService.validateAndConsumeTicket(ticket);
}
}
}
} catch (Exception e) {
log.warn("JWT extraction failed: {}", e.getMessage());
log.warn("Ticket extraction failed: {}", e.getMessage());
}
return null;
}

// Helper to extract session ID from cookies
/**
* Extracts anonymous session ID from cookies only.
* Query parameter fallback has been removed for security reasons.
*/
private String extractAnonSessionId(WebSocketSession session) {
try {
// Try to get from cookies
// Get from cookies only (query param fallback removed for security)
List<String> cookies = session.getHandshakeHeaders().get("cookie");
if (cookies != null) {
for (String header : cookies) {
String[] parts = header.split(";");
for (String part : parts) {
String[] keyValue = part.trim().split("=");
if (keyValue.length == 2 && keyValue[0].equals("anonSessionId")) {
log.info("extracted form cookies");
log.info("extracted from cookies");
return keyValue[1];
}
}
}
}
// Fallback: Try to get from URI query param
URI uri = session.getUri();
if (uri != null && uri.getQuery() != null) {
String[] queryParams = uri.getQuery().split("&");
for (String param : queryParams) {
String[] keyValue = param.split("=");
if (keyValue.length == 2 && keyValue[0].equals("anonSessionId")) {
log.info("extracted from URI");
return keyValue[1];
}
}
}

} catch (Exception e) {
log.warn("Anon session extraction failed: {}", e.getMessage());
}
Expand Down
Loading