From ab470293324d9ef67893ac14dba980b0dcb1896e Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Mon, 20 Apr 2026 00:16:23 +0000 Subject: [PATCH 1/2] feat(chat): forward threads metadata --- internal/identity/identity.go | 10 ++++ internal/server/server.go | 41 ++++++++------ internal/server/server_test.go | 97 +++++++++++++++++++++++++++++----- 3 files changed, 121 insertions(+), 27 deletions(-) diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 67cebf1..81d47fe 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -39,6 +39,16 @@ func FromContext(ctx context.Context) (Identity, error) { }, nil } +func AppendToOutgoingContext(ctx context.Context, identity Identity) context.Context { + return metadata.AppendToOutgoingContext( + ctx, + MetadataKeyIdentityID, + identity.IdentityID, + MetadataKeyIdentityType, + identity.IdentityType, + ) +} + func singleValue(md metadata.MD, key string) (string, error) { values := md.Get(key) if len(values) == 0 { diff --git a/internal/server/server.go b/internal/server/server.go index c86fdb2..352c8f2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -44,6 +44,7 @@ func (s *Server) CreateChat(ctx context.Context, req *chatv1.CreateChatRequest) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "identity: %v", err) } + threadsCtx := identity.AppendToOutgoingContext(ctx, id) organizationID, err := parseUUID(req.GetOrganizationId()) if err != nil { @@ -62,9 +63,17 @@ func (s *Server) CreateChat(ctx context.Context, req *chatv1.CreateChatRequest) } participantIDs = append(participantIDs, pid) } + participants := make([]*threadsv1.ParticipantIdentifier, len(participantIDs)) + for i, pid := range participantIDs { + participants[i] = &threadsv1.ParticipantIdentifier{ + Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: pid}, + } + } + orgIDValue := organizationID.String() - resp, err := s.threads.CreateThread(ctx, &threadsv1.CreateThreadRequest{ - ParticipantIds: participantIDs, + resp, err := s.threads.CreateThread(threadsCtx, &threadsv1.CreateThreadRequest{ + Participants: participants, + OrganizationId: &orgIDValue, }) if err != nil { return nil, mapThreadsError(err) @@ -95,6 +104,7 @@ func (s *Server) GetChats(ctx context.Context, req *chatv1.GetChatsRequest) (*ch if err != nil { return nil, status.Errorf(codes.Unauthenticated, "identity: %v", err) } + threadsCtx := identity.AppendToOutgoingContext(ctx, id) organizationID, err := parseUUID(req.GetOrganizationId()) if err != nil { @@ -130,7 +140,7 @@ func (s *Server) GetChats(ctx context.Context, req *chatv1.GetChatsRequest) (*ch threadIDs[i] = chat.ThreadID } - threadsByID, err := s.fetchThreads(ctx, id.IdentityID, threadIDs) + threadsByID, err := s.fetchThreads(threadsCtx, id.IdentityID, threadIDs) if err != nil { return nil, mapThreadsError(err) } @@ -192,7 +202,8 @@ func (s *Server) UpdateChat(ctx context.Context, req *chatv1.UpdateChatRequest) } } - threadsByID, err := s.fetchThreads(ctx, id.IdentityID, []uuid.UUID{threadID}) + threadsCtx := identity.AppendToOutgoingContext(ctx, id) + threadsByID, err := s.fetchThreads(threadsCtx, id.IdentityID, []uuid.UUID{threadID}) if err != nil { return nil, mapThreadsError(err) } @@ -221,7 +232,8 @@ func (s *Server) GetMessages(ctx context.Context, req *chatv1.GetMessagesRequest return nil, status.Error(codes.InvalidArgument, "chat_id is required") } - msgResp, err := s.threads.GetMessages(ctx, &threadsv1.GetMessagesRequest{ + threadsCtx := identity.AppendToOutgoingContext(ctx, id) + msgResp, err := s.threads.GetMessages(threadsCtx, &threadsv1.GetMessagesRequest{ ThreadId: req.GetChatId(), PageSize: req.GetPageSize(), PageToken: req.GetPageToken(), @@ -230,7 +242,7 @@ func (s *Server) GetMessages(ctx context.Context, req *chatv1.GetMessagesRequest return nil, mapThreadsError(err) } - unreadCount, err := s.countUnread(ctx, id.IdentityID, req.GetChatId()) + unreadCount, err := s.countUnread(threadsCtx, id.IdentityID, req.GetChatId()) if err != nil { return nil, mapThreadsError(err) } @@ -260,7 +272,8 @@ func (s *Server) SendMessage(ctx context.Context, req *chatv1.SendMessageRequest return nil, status.Error(codes.InvalidArgument, "body or file_ids must be provided") } - resp, err := s.threads.SendMessage(ctx, &threadsv1.SendMessageRequest{ + threadsCtx := identity.AppendToOutgoingContext(ctx, id) + resp, err := s.threads.SendMessage(threadsCtx, &threadsv1.SendMessageRequest{ ThreadId: req.GetChatId(), SenderId: id.IdentityID, Body: req.GetBody(), @@ -288,7 +301,8 @@ func (s *Server) MarkAsRead(ctx context.Context, req *chatv1.MarkAsReadRequest) return nil, status.Error(codes.InvalidArgument, "message_ids must not be empty") } - resp, err := s.threads.AckMessages(ctx, &threadsv1.AckMessagesRequest{ + threadsCtx := identity.AppendToOutgoingContext(ctx, id) + resp, err := s.threads.AckMessages(threadsCtx, &threadsv1.AckMessagesRequest{ ParticipantId: id.IdentityID, MessageIds: req.GetMessageIds(), }) @@ -305,12 +319,11 @@ func (s *Server) countUnread(ctx context.Context, participantID, chatID string) var count int32 var pageToken string - // TODO: Threads.GetUnackedMessages lacks a thread-scoped filter, so we scan - // all unacked messages across chats; a thread filter upstream would avoid - // this full scan. + threadID := chatID for { resp, err := s.threads.GetUnackedMessages(ctx, &threadsv1.GetUnackedMessagesRequest{ ParticipantId: participantID, + ThreadId: &threadID, PageSize: unackedPageSize, PageToken: pageToken, }) @@ -318,10 +331,8 @@ func (s *Server) countUnread(ctx context.Context, participantID, chatID string) return 0, err } - for _, msg := range resp.GetMessages() { - if msg.GetThreadId() == chatID { - count++ - } + for range resp.GetMessages() { + count++ } if resp.GetNextPageToken() == "" { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index c9d8f94..179c467 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -20,14 +20,16 @@ import ( ) type mockThreadsClient struct { - createThreadFunc func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) - archiveThreadFunc func(ctx context.Context, req *threadsv1.ArchiveThreadRequest, opts ...grpc.CallOption) (*threadsv1.ArchiveThreadResponse, error) - addParticipantFunc func(ctx context.Context, req *threadsv1.AddParticipantRequest, opts ...grpc.CallOption) (*threadsv1.AddParticipantResponse, error) - sendMessageFunc func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error) - getThreadsFunc func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error) - getMessagesFunc func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) - getUnackedMessagesFunc func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error) - ackMessagesFunc func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error) + createThreadFunc func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) + archiveThreadFunc func(ctx context.Context, req *threadsv1.ArchiveThreadRequest, opts ...grpc.CallOption) (*threadsv1.ArchiveThreadResponse, error) + addParticipantFunc func(ctx context.Context, req *threadsv1.AddParticipantRequest, opts ...grpc.CallOption) (*threadsv1.AddParticipantResponse, error) + sendMessageFunc func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error) + getThreadsFunc func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error) + getOrganizationThreadsFunc func(ctx context.Context, req *threadsv1.GetOrganizationThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetOrganizationThreadsResponse, error) + getThreadFunc func(ctx context.Context, req *threadsv1.GetThreadRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadResponse, error) + getMessagesFunc func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) + getUnackedMessagesFunc func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error) + ackMessagesFunc func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error) } func (m *mockThreadsClient) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) { @@ -65,6 +67,20 @@ func (m *mockThreadsClient) GetThreads(ctx context.Context, req *threadsv1.GetTh return m.getThreadsFunc(ctx, req, opts...) } +func (m *mockThreadsClient) GetOrganizationThreads(ctx context.Context, req *threadsv1.GetOrganizationThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetOrganizationThreadsResponse, error) { + if m.getOrganizationThreadsFunc == nil { + return nil, unexpectedCall("GetOrganizationThreads") + } + return m.getOrganizationThreadsFunc(ctx, req, opts...) +} + +func (m *mockThreadsClient) GetThread(ctx context.Context, req *threadsv1.GetThreadRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadResponse, error) { + if m.getThreadFunc == nil { + return nil, unexpectedCall("GetThread") + } + return m.getThreadFunc(ctx, req, opts...) +} + func (m *mockThreadsClient) GetMessages(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) { if m.getMessagesFunc == nil { return nil, unexpectedCall("GetMessages") @@ -137,6 +153,20 @@ func contextWithIdentity(identityID string) context.Context { return metadata.NewIncomingContext(context.Background(), md) } +func requireOutgoingIdentity(t *testing.T, ctx context.Context, identityID, identityType string) { + t.Helper() + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("expected outgoing metadata") + } + if got := md.Get(identity.MetadataKeyIdentityID); len(got) != 1 || got[0] != identityID { + t.Fatalf("expected outgoing identity id %q, got %v", identityID, got) + } + if got := md.Get(identity.MetadataKeyIdentityType); len(got) != 1 || got[0] != identityType { + t.Fatalf("expected outgoing identity type %q, got %v", identityType, got) + } +} + func requireStatusCode(t *testing.T, err error, code codes.Code) { t.Helper() if err == nil { @@ -194,6 +224,7 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) { var storedOrgID uuid.UUID threads := &mockThreadsClient{ createThreadFunc: func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") gotRequest = req return &threadsv1.CreateThreadResponse{Thread: thread}, nil }, @@ -215,8 +246,27 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) { t.Fatalf("CreateThread was not called") } expectedParticipants := []string{"user-1", "user-2", "user-3"} - if !reflect.DeepEqual(gotRequest.ParticipantIds, expectedParticipants) { - t.Fatalf("expected participants %v, got %v", expectedParticipants, gotRequest.ParticipantIds) + if len(gotRequest.GetParticipantIds()) != 0 { + t.Fatalf("expected participant_ids to be empty, got %v", gotRequest.GetParticipantIds()) + } + if gotRequest.OrganizationId == nil { + t.Fatalf("expected organization_id to be set") + } + if gotRequest.GetOrganizationId() != orgID.String() { + t.Fatalf("expected organization id %q, got %q", orgID, gotRequest.GetOrganizationId()) + } + gotParticipants := make([]string, len(gotRequest.GetParticipants())) + for i, participant := range gotRequest.GetParticipants() { + if participant == nil { + t.Fatalf("expected participant %d to be set", i) + } + gotParticipants[i] = participant.GetParticipantId() + if gotParticipants[i] == "" { + t.Fatalf("expected participant %d to have id", i) + } + } + if !reflect.DeepEqual(gotParticipants, expectedParticipants) { + t.Fatalf("expected participants %v, got %v", expectedParticipants, gotParticipants) } if resp.GetChat().GetId() != thread.GetId() { t.Fatalf("expected chat id %q, got %q", thread.GetId(), resp.GetChat().GetId()) @@ -243,6 +293,19 @@ func TestCreateChatReturnsThreadWhenStoreFails(t *testing.T) { var storedThreadID uuid.UUID threads := &mockThreadsClient{ createThreadFunc: func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") + if req.OrganizationId == nil { + return nil, status.Error(codes.InvalidArgument, "organization_id missing") + } + if req.GetOrganizationId() != orgID.String() { + return nil, status.Errorf(codes.InvalidArgument, "unexpected organization id %q", req.GetOrganizationId()) + } + if len(req.GetParticipantIds()) != 0 { + return nil, status.Errorf(codes.InvalidArgument, "unexpected participant_ids %v", req.GetParticipantIds()) + } + if len(req.GetParticipants()) != 2 { + return nil, status.Errorf(codes.InvalidArgument, "unexpected participants %v", req.GetParticipants()) + } return &threadsv1.CreateThreadResponse{Thread: thread}, nil }, } @@ -318,6 +381,7 @@ func TestGetChatsUsesStoreAndThreads(t *testing.T) { threads := &mockThreadsClient{ getThreadsFunc: func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") if req.GetParticipantId() != "user-1" { return nil, status.Errorf(codes.InvalidArgument, "unexpected participant %q", req.GetParticipantId()) } @@ -750,13 +814,21 @@ func TestGetMessagesAggregatesUnread(t *testing.T) { var gotPageTokens []string threads := &mockThreadsClient{ getMessagesFunc: func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") gotMessagesReq = req return &threadsv1.GetMessagesResponse{Messages: threadsMessages, NextPageToken: "next-token"}, nil }, getUnackedMessagesFunc: func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") if req.GetParticipantId() != "user-1" { t.Fatalf("expected participant user-1, got %q", req.GetParticipantId()) } + if req.ThreadId == nil { + t.Fatal("expected thread_id to be set") + } + if req.GetThreadId() != chatID { + t.Fatalf("expected thread id %q, got %q", chatID, req.GetThreadId()) + } if req.GetPageSize() != unackedPageSize { t.Fatalf("expected page size %d, got %d", unackedPageSize, req.GetPageSize()) } @@ -766,7 +838,7 @@ func TestGetMessagesAggregatesUnread(t *testing.T) { return &threadsv1.GetUnackedMessagesResponse{ Messages: []*threadsv1.Message{ {Id: "u1", ThreadId: chatID}, - {Id: "u2", ThreadId: "chat-2"}, + {Id: "u2", ThreadId: chatID}, }, NextPageToken: "page-2", }, nil @@ -774,7 +846,6 @@ func TestGetMessagesAggregatesUnread(t *testing.T) { return &threadsv1.GetUnackedMessagesResponse{ Messages: []*threadsv1.Message{ {Id: "u3", ThreadId: chatID}, - {Id: "u4", ThreadId: chatID}, }, }, nil default: @@ -834,6 +905,7 @@ func TestSendMessageDelegates(t *testing.T) { msgTime := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC) threads := &mockThreadsClient{ sendMessageFunc: func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") if req.GetThreadId() != "chat-1" { return nil, status.Errorf(codes.InvalidArgument, "unexpected thread id %q", req.GetThreadId()) } @@ -886,6 +958,7 @@ func TestMarkAsReadDelegates(t *testing.T) { ctx := contextWithIdentity("user-1") threads := &mockThreadsClient{ ackMessagesFunc: func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error) { + requireOutgoingIdentity(t, ctx, "user-1", "user") if req.GetParticipantId() != "user-1" { return nil, status.Errorf(codes.InvalidArgument, "unexpected participant %q", req.GetParticipantId()) } From 8a751c8aea9438f1a902c733730700bbbf2e50f1 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Mon, 20 Apr 2026 00:29:51 +0000 Subject: [PATCH 2/2] fix(chat): omit initiator participants --- internal/server/server.go | 3 +-- internal/server/server_test.go | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 352c8f2..6194f99 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -55,8 +55,7 @@ func (s *Server) CreateChat(ctx context.Context, req *chatv1.CreateChatRequest) return nil, status.Error(codes.InvalidArgument, "participant_ids must not be empty") } - participantIDs := make([]string, 0, len(req.GetParticipantIds())+1) - participantIDs = append(participantIDs, id.IdentityID) + participantIDs := make([]string, 0, len(req.GetParticipantIds())) for _, pid := range req.GetParticipantIds() { if pid == id.IdentityID { continue diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 179c467..fdd6900 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -245,7 +245,7 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) { if gotRequest == nil { t.Fatalf("CreateThread was not called") } - expectedParticipants := []string{"user-1", "user-2", "user-3"} + expectedParticipants := []string{"user-2", "user-3"} if len(gotRequest.GetParticipantIds()) != 0 { t.Fatalf("expected participant_ids to be empty, got %v", gotRequest.GetParticipantIds()) } @@ -303,7 +303,7 @@ func TestCreateChatReturnsThreadWhenStoreFails(t *testing.T) { if len(req.GetParticipantIds()) != 0 { return nil, status.Errorf(codes.InvalidArgument, "unexpected participant_ids %v", req.GetParticipantIds()) } - if len(req.GetParticipants()) != 2 { + if len(req.GetParticipants()) != 1 { return nil, status.Errorf(codes.InvalidArgument, "unexpected participants %v", req.GetParticipants()) } return &threadsv1.CreateThreadResponse{Thread: thread}, nil