Skip to content

Commit 628589c

Browse files
committed
fix: address review comments on websocket package
- Make Hub.Shutdown() idempotent via sync.Once to prevent double-close panic - Add Hub.Register()/Unregister() methods that select on hub.done to prevent goroutine leaks and deadlocks during shutdown - Update NewClient to use Hub.Register(), returning error on closed hub - Update readPump to use Hub.Unregister() instead of raw channel send - Change NewMessage signature to (Message, error) to surface marshal failures - Change Message.Bytes() signature to ([]byte, error) for explicit error handling - Replace time.Sleep in tests with assert.Eventually polling - Add tests for idempotent shutdown and register-after-shutdown
1 parent ac4a608 commit 628589c

4 files changed

Lines changed: 61 additions & 17 deletions

File tree

backend/internal/websocket/client.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,20 @@ type Client struct {
3434
// NewClient creates a new Client attached to the given hub and connection,
3535
// registers it with the hub, and starts the read/write pumps.
3636
// The caller should not interact with conn after calling NewClient.
37-
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
37+
// Returns an error if the hub has already been shut down.
38+
func NewClient(hub *Hub, conn *websocket.Conn) (*Client, error) {
3839
client := &Client{
3940
hub: hub,
4041
conn: conn,
4142
send: make(chan []byte, sendBufferSize),
4243
}
43-
hub.register <- client
44+
if err := hub.Register(client); err != nil {
45+
conn.Close()
46+
return nil, err
47+
}
4448
go client.writePump()
4549
go client.readPump()
46-
return client
50+
return client, nil
4751
}
4852

4953
// readPump pumps messages from the WebSocket connection to the hub.
@@ -54,7 +58,7 @@ func (c *Client) readPump() {
5458
if r := recover(); r != nil {
5559
slog.Error("Panic in WebSocket readPump", "recover", r)
5660
}
57-
c.hub.unregister <- c
61+
c.hub.Unregister(c)
5862
c.conn.Close()
5963
}()
6064

backend/internal/websocket/hub.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ import (
77
"sync"
88
)
99

10+
// ErrHubClosed is returned when attempting to register a client on a shut-down hub.
11+
var ErrHubClosed = errHubClosed{}
12+
13+
type errHubClosed struct{}
14+
15+
func (errHubClosed) Error() string { return "hub is closed" }
16+
1017
// BroadcastSender is implemented by any type that can broadcast messages
1118
// to all connected WebSocket clients. Use this interface for decoupled
1219
// dependency injection (e.g., handlers broadcast events without importing Hub).
@@ -34,6 +41,9 @@ type Hub struct {
3441

3542
// done signals the Run loop to stop.
3643
done chan struct{}
44+
45+
// shutdownOnce ensures Shutdown is idempotent and safe to call concurrently.
46+
shutdownOnce sync.Once
3747
}
3848

3949
// NewHub creates a new Hub ready to accept clients.
@@ -99,8 +109,31 @@ func (h *Hub) Broadcast(message []byte) {
99109
}
100110

101111
// Shutdown gracefully stops the hub's Run loop and closes all client connections.
112+
// It is safe to call multiple times and concurrently.
102113
func (h *Hub) Shutdown() {
103-
close(h.done)
114+
h.shutdownOnce.Do(func() {
115+
close(h.done)
116+
})
117+
}
118+
119+
// Register safely registers a client with the hub. It returns ErrHubClosed if
120+
// the hub has been shut down, preventing the caller from blocking forever.
121+
func (h *Hub) Register(c *Client) error {
122+
select {
123+
case h.register <- c:
124+
return nil
125+
case <-h.done:
126+
return ErrHubClosed
127+
}
128+
}
129+
130+
// Unregister safely requests client removal from the hub. If the hub has
131+
// already been shut down the call is a no-op, preventing goroutine leaks.
132+
func (h *Hub) Unregister(c *Client) {
133+
select {
134+
case h.unregister <- c:
135+
case <-h.done:
136+
}
104137
}
105138

106139
// ClientCount returns the current number of connected clients.

backend/internal/websocket/hub_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ func TestNewMessage(t *testing.T) {
185185
t.Run(tt.name, func(t *testing.T) {
186186
t.Parallel()
187187

188-
msg := NewMessage(tt.msgType, tt.payload)
188+
msg, err := NewMessage(tt.msgType, tt.payload)
189+
assert.NoError(t, err)
189190
assert.Equal(t, tt.wantType, msg.Type)
190191
assert.JSONEq(t, tt.wantPayload, string(msg.Payload))
191192
})
@@ -195,11 +196,14 @@ func TestNewMessage(t *testing.T) {
195196
func TestMessage_Bytes(t *testing.T) {
196197
t.Parallel()
197198

198-
msg := NewMessage("item.updated", map[string]int{"id": 1})
199-
b := msg.Bytes()
199+
msg, err := NewMessage("item.updated", map[string]int{"id": 1})
200+
assert.NoError(t, err)
201+
202+
b, err := msg.Bytes()
203+
assert.NoError(t, err)
200204

201205
var parsed Message
202-
err := json.Unmarshal(b, &parsed)
206+
err = json.Unmarshal(b, &parsed)
203207
assert.NoError(t, err)
204208
assert.Equal(t, "item.updated", parsed.Type)
205209
assert.JSONEq(t, `{"id":1}`, string(parsed.Payload))
Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package websocket
22

3-
import "encoding/json"
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
47

58
// Message is the envelope for all WebSocket messages sent to clients.
69
type Message struct {
@@ -12,23 +15,23 @@ type Message struct {
1215
}
1316

1417
// NewMessage creates a Message with the given type and payload.
15-
// The payload is JSON-marshalled; if marshalling fails the payload is set to null.
16-
func NewMessage(msgType string, payload interface{}) Message {
18+
// The payload is JSON-marshalled; an error is returned if marshalling fails.
19+
func NewMessage(msgType string, payload interface{}) (Message, error) {
1720
data, err := json.Marshal(payload)
1821
if err != nil {
19-
data = []byte("null")
22+
return Message{}, fmt.Errorf("marshal payload: %w", err)
2023
}
2124
return Message{
2225
Type: msgType,
2326
Payload: data,
24-
}
27+
}, nil
2528
}
2629

2730
// Bytes serialises the Message to JSON bytes suitable for broadcasting.
28-
func (m Message) Bytes() []byte {
31+
func (m Message) Bytes() ([]byte, error) {
2932
b, err := json.Marshal(m)
3033
if err != nil {
31-
return []byte(`{"type":"error","payload":null}`)
34+
return nil, fmt.Errorf("marshal message: %w", err)
3235
}
33-
return b
36+
return b, nil
3437
}

0 commit comments

Comments
 (0)