From dd3dbf6c59defebce831bd607ed1870fd25ed798 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Sun, 26 Apr 2026 02:30:05 +0000 Subject: [PATCH] fix(chat): ack unread pages --- internal/server/server.go | 39 +++++++++++++++++++--------------- internal/server/server_test.go | 11 ++++++---- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 55d80a2..7febfb4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -377,23 +377,13 @@ func (s *Server) MarkAsRead(ctx context.Context, req *chatv1.MarkAsReadRequest) } threadsCtx := identity.AppendToOutgoingContext(ctx, id) - messageIDs, err := s.listUnackedMessageIDs(threadsCtx, id.IdentityID, req.GetChatId()) - if err != nil { - return nil, mapThreadsError(err) - } - if len(messageIDs) == 0 { - return &chatv1.MarkAsReadResponse{ReadCount: 0}, nil - } - resp, err := s.threads.AckMessages(threadsCtx, &threadsv1.AckMessagesRequest{ - ParticipantId: id.IdentityID, - MessageIds: messageIDs, - }) + readCount, err := s.ackUnackedMessages(threadsCtx, id.IdentityID, req.GetChatId()) if err != nil { return nil, mapThreadsError(err) } return &chatv1.MarkAsReadResponse{ - ReadCount: resp.GetAckedCount(), + ReadCount: readCount, }, nil } @@ -426,8 +416,8 @@ func (s *Server) countUnread(ctx context.Context, participantID, chatID string) return count, nil } -func (s *Server) listUnackedMessageIDs(ctx context.Context, participantID, chatID string) ([]string, error) { - var messageIDs []string +func (s *Server) ackUnackedMessages(ctx context.Context, participantID, chatID string) (int32, error) { + var readCount int32 var pageToken string threadID := chatID @@ -439,11 +429,26 @@ func (s *Server) listUnackedMessageIDs(ctx context.Context, participantID, chatI PageToken: pageToken, }) if err != nil { - return nil, err + return 0, err } + messageIDs := make([]string, 0, len(resp.GetMessages())) for _, message := range resp.GetMessages() { - messageIDs = append(messageIDs, message.GetId()) + messageID := message.GetId() + if messageID == "" { + return 0, status.Error(codes.Internal, "threads returned message without id") + } + messageIDs = append(messageIDs, messageID) + } + if len(messageIDs) > 0 { + ackResp, err := s.threads.AckMessages(ctx, &threadsv1.AckMessagesRequest{ + ParticipantId: participantID, + MessageIds: messageIDs, + }) + if err != nil { + return 0, err + } + readCount += ackResp.GetAckedCount() } if resp.GetNextPageToken() == "" { @@ -452,7 +457,7 @@ func (s *Server) listUnackedMessageIDs(ctx context.Context, participantID, chatI pageToken = resp.GetNextPageToken() } - return messageIDs, nil + return readCount, nil } func mapThreadsError(err error) error { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 31e510b..74b1dbc 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1215,6 +1215,7 @@ func TestMarkAsReadValidation(t *testing.T) { func TestMarkAsReadDelegates(t *testing.T) { ctx := contextWithIdentity("user-1") pageTokens := []string{} + ackRequests := [][]string{} threads := &mockThreadsClient{ getUnackedMessagesFunc: func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error) { requireOutgoingIdentity(t, ctx, "user-1", "user") @@ -1246,10 +1247,9 @@ func TestMarkAsReadDelegates(t *testing.T) { if req.GetParticipantId() != "user-1" { return nil, status.Errorf(codes.InvalidArgument, "unexpected participant %q", req.GetParticipantId()) } - if !reflect.DeepEqual(req.GetMessageIds(), []string{"msg-1", "msg-2", "msg-3"}) { - return nil, status.Errorf(codes.InvalidArgument, "unexpected message ids %v", req.GetMessageIds()) - } - return &threadsv1.AckMessagesResponse{AckedCount: 3}, nil + messageIDs := append([]string(nil), req.GetMessageIds()...) + ackRequests = append(ackRequests, messageIDs) + return &threadsv1.AckMessagesResponse{AckedCount: int32(len(req.GetMessageIds()))}, nil }, } @@ -1264,6 +1264,9 @@ func TestMarkAsReadDelegates(t *testing.T) { if !reflect.DeepEqual(pageTokens, []string{"", "page-2"}) { t.Fatalf("unexpected page tokens %v", pageTokens) } + if !reflect.DeepEqual(ackRequests, [][]string{{"msg-1", "msg-2"}, {"msg-3"}}) { + t.Fatalf("unexpected ack requests %v", ackRequests) + } } func TestThreadToChat(t *testing.T) {