Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func FromContext(ctx context.Context) (Identity, error) {
}, nil
}

func AppendToOutgoingContext(ctx context.Context, identity Identity) context.Context {
return metadata.AppendToOutgoingContext(
ctx,
MetadataKeyIdentityID,
identity.IdentityID,
MetadataKeyIdentityType,
identity.IdentityType,
)
}

func singleValue(md metadata.MD, key string) (string, error) {
values := md.Get(key)
if len(values) == 0 {
Expand Down
44 changes: 27 additions & 17 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func (s *Server) CreateChat(ctx context.Context, req *chatv1.CreateChatRequest)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "identity: %v", err)
}
threadsCtx := identity.AppendToOutgoingContext(ctx, id)

organizationID, err := parseUUID(req.GetOrganizationId())
if err != nil {
Expand All @@ -54,17 +55,24 @@ func (s *Server) CreateChat(ctx context.Context, req *chatv1.CreateChatRequest)
return nil, status.Error(codes.InvalidArgument, "participant_ids must not be empty")
}

participantIDs := make([]string, 0, len(req.GetParticipantIds())+1)
participantIDs = append(participantIDs, id.IdentityID)
participantIDs := make([]string, 0, len(req.GetParticipantIds()))
for _, pid := range req.GetParticipantIds() {
if pid == id.IdentityID {
continue
}
participantIDs = append(participantIDs, pid)
}
participants := make([]*threadsv1.ParticipantIdentifier, len(participantIDs))
Comment thread
noa-lucent marked this conversation as resolved.
for i, pid := range participantIDs {
participants[i] = &threadsv1.ParticipantIdentifier{
Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: pid},
}
}
orgIDValue := organizationID.String()

resp, err := s.threads.CreateThread(ctx, &threadsv1.CreateThreadRequest{
ParticipantIds: participantIDs,
resp, err := s.threads.CreateThread(threadsCtx, &threadsv1.CreateThreadRequest{
Participants: participants,
OrganizationId: &orgIDValue,
})
if err != nil {
return nil, mapThreadsError(err)
Expand Down Expand Up @@ -95,6 +103,7 @@ func (s *Server) GetChats(ctx context.Context, req *chatv1.GetChatsRequest) (*ch
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "identity: %v", err)
}
threadsCtx := identity.AppendToOutgoingContext(ctx, id)

organizationID, err := parseUUID(req.GetOrganizationId())
if err != nil {
Expand Down Expand Up @@ -130,7 +139,7 @@ func (s *Server) GetChats(ctx context.Context, req *chatv1.GetChatsRequest) (*ch
threadIDs[i] = chat.ThreadID
}

threadsByID, err := s.fetchThreads(ctx, id.IdentityID, threadIDs)
threadsByID, err := s.fetchThreads(threadsCtx, id.IdentityID, threadIDs)
if err != nil {
return nil, mapThreadsError(err)
}
Expand Down Expand Up @@ -192,7 +201,8 @@ func (s *Server) UpdateChat(ctx context.Context, req *chatv1.UpdateChatRequest)
}
}

threadsByID, err := s.fetchThreads(ctx, id.IdentityID, []uuid.UUID{threadID})
threadsCtx := identity.AppendToOutgoingContext(ctx, id)
threadsByID, err := s.fetchThreads(threadsCtx, id.IdentityID, []uuid.UUID{threadID})
if err != nil {
return nil, mapThreadsError(err)
}
Expand Down Expand Up @@ -221,7 +231,8 @@ func (s *Server) GetMessages(ctx context.Context, req *chatv1.GetMessagesRequest
return nil, status.Error(codes.InvalidArgument, "chat_id is required")
}

msgResp, err := s.threads.GetMessages(ctx, &threadsv1.GetMessagesRequest{
threadsCtx := identity.AppendToOutgoingContext(ctx, id)
msgResp, err := s.threads.GetMessages(threadsCtx, &threadsv1.GetMessagesRequest{
ThreadId: req.GetChatId(),
PageSize: req.GetPageSize(),
PageToken: req.GetPageToken(),
Expand All @@ -230,7 +241,7 @@ func (s *Server) GetMessages(ctx context.Context, req *chatv1.GetMessagesRequest
return nil, mapThreadsError(err)
}

unreadCount, err := s.countUnread(ctx, id.IdentityID, req.GetChatId())
unreadCount, err := s.countUnread(threadsCtx, id.IdentityID, req.GetChatId())
if err != nil {
return nil, mapThreadsError(err)
}
Expand Down Expand Up @@ -260,7 +271,8 @@ func (s *Server) SendMessage(ctx context.Context, req *chatv1.SendMessageRequest
return nil, status.Error(codes.InvalidArgument, "body or file_ids must be provided")
}

resp, err := s.threads.SendMessage(ctx, &threadsv1.SendMessageRequest{
threadsCtx := identity.AppendToOutgoingContext(ctx, id)
resp, err := s.threads.SendMessage(threadsCtx, &threadsv1.SendMessageRequest{
ThreadId: req.GetChatId(),
SenderId: id.IdentityID,
Body: req.GetBody(),
Expand Down Expand Up @@ -288,7 +300,8 @@ func (s *Server) MarkAsRead(ctx context.Context, req *chatv1.MarkAsReadRequest)
return nil, status.Error(codes.InvalidArgument, "message_ids must not be empty")
}

resp, err := s.threads.AckMessages(ctx, &threadsv1.AckMessagesRequest{
threadsCtx := identity.AppendToOutgoingContext(ctx, id)
resp, err := s.threads.AckMessages(threadsCtx, &threadsv1.AckMessagesRequest{
ParticipantId: id.IdentityID,
MessageIds: req.GetMessageIds(),
})
Expand All @@ -305,23 +318,20 @@ func (s *Server) countUnread(ctx context.Context, participantID, chatID string)
var count int32
var pageToken string

// TODO: Threads.GetUnackedMessages lacks a thread-scoped filter, so we scan
// all unacked messages across chats; a thread filter upstream would avoid
// this full scan.
threadID := chatID
for {
resp, err := s.threads.GetUnackedMessages(ctx, &threadsv1.GetUnackedMessagesRequest{
ParticipantId: participantID,
ThreadId: &threadID,
PageSize: unackedPageSize,
PageToken: pageToken,
})
if err != nil {
return 0, err
}

for _, msg := range resp.GetMessages() {
if msg.GetThreadId() == chatID {
count++
}
for range resp.GetMessages() {
count++
}

if resp.GetNextPageToken() == "" {
Expand Down
99 changes: 86 additions & 13 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ import (
)

type mockThreadsClient struct {
createThreadFunc func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error)
archiveThreadFunc func(ctx context.Context, req *threadsv1.ArchiveThreadRequest, opts ...grpc.CallOption) (*threadsv1.ArchiveThreadResponse, error)
addParticipantFunc func(ctx context.Context, req *threadsv1.AddParticipantRequest, opts ...grpc.CallOption) (*threadsv1.AddParticipantResponse, error)
sendMessageFunc func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error)
getThreadsFunc func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error)
getMessagesFunc func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error)
getUnackedMessagesFunc func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error)
ackMessagesFunc func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error)
createThreadFunc func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error)
archiveThreadFunc func(ctx context.Context, req *threadsv1.ArchiveThreadRequest, opts ...grpc.CallOption) (*threadsv1.ArchiveThreadResponse, error)
addParticipantFunc func(ctx context.Context, req *threadsv1.AddParticipantRequest, opts ...grpc.CallOption) (*threadsv1.AddParticipantResponse, error)
sendMessageFunc func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error)
getThreadsFunc func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error)
getOrganizationThreadsFunc func(ctx context.Context, req *threadsv1.GetOrganizationThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetOrganizationThreadsResponse, error)
getThreadFunc func(ctx context.Context, req *threadsv1.GetThreadRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadResponse, error)
getMessagesFunc func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error)
getUnackedMessagesFunc func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error)
ackMessagesFunc func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error)
}

func (m *mockThreadsClient) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) {
Expand Down Expand Up @@ -65,6 +67,20 @@ func (m *mockThreadsClient) GetThreads(ctx context.Context, req *threadsv1.GetTh
return m.getThreadsFunc(ctx, req, opts...)
}

func (m *mockThreadsClient) GetOrganizationThreads(ctx context.Context, req *threadsv1.GetOrganizationThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetOrganizationThreadsResponse, error) {
if m.getOrganizationThreadsFunc == nil {
return nil, unexpectedCall("GetOrganizationThreads")
}
return m.getOrganizationThreadsFunc(ctx, req, opts...)
}

func (m *mockThreadsClient) GetThread(ctx context.Context, req *threadsv1.GetThreadRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadResponse, error) {
if m.getThreadFunc == nil {
return nil, unexpectedCall("GetThread")
}
return m.getThreadFunc(ctx, req, opts...)
}

func (m *mockThreadsClient) GetMessages(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) {
if m.getMessagesFunc == nil {
return nil, unexpectedCall("GetMessages")
Expand Down Expand Up @@ -137,6 +153,20 @@ func contextWithIdentity(identityID string) context.Context {
return metadata.NewIncomingContext(context.Background(), md)
}

func requireOutgoingIdentity(t *testing.T, ctx context.Context, identityID, identityType string) {
t.Helper()
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
t.Fatal("expected outgoing metadata")
}
if got := md.Get(identity.MetadataKeyIdentityID); len(got) != 1 || got[0] != identityID {
t.Fatalf("expected outgoing identity id %q, got %v", identityID, got)
}
if got := md.Get(identity.MetadataKeyIdentityType); len(got) != 1 || got[0] != identityType {
t.Fatalf("expected outgoing identity type %q, got %v", identityType, got)
}
}

func requireStatusCode(t *testing.T, err error, code codes.Code) {
t.Helper()
if err == nil {
Expand Down Expand Up @@ -194,6 +224,7 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) {
var storedOrgID uuid.UUID
threads := &mockThreadsClient{
createThreadFunc: func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
gotRequest = req
return &threadsv1.CreateThreadResponse{Thread: thread}, nil
},
Expand All @@ -214,9 +245,28 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) {
if gotRequest == nil {
t.Fatalf("CreateThread was not called")
}
expectedParticipants := []string{"user-1", "user-2", "user-3"}
if !reflect.DeepEqual(gotRequest.ParticipantIds, expectedParticipants) {
t.Fatalf("expected participants %v, got %v", expectedParticipants, gotRequest.ParticipantIds)
expectedParticipants := []string{"user-2", "user-3"}
if len(gotRequest.GetParticipantIds()) != 0 {
t.Fatalf("expected participant_ids to be empty, got %v", gotRequest.GetParticipantIds())
}
if gotRequest.OrganizationId == nil {
t.Fatalf("expected organization_id to be set")
}
if gotRequest.GetOrganizationId() != orgID.String() {
t.Fatalf("expected organization id %q, got %q", orgID, gotRequest.GetOrganizationId())
}
gotParticipants := make([]string, len(gotRequest.GetParticipants()))
Comment thread
noa-lucent marked this conversation as resolved.
for i, participant := range gotRequest.GetParticipants() {
if participant == nil {
t.Fatalf("expected participant %d to be set", i)
}
gotParticipants[i] = participant.GetParticipantId()
if gotParticipants[i] == "" {
t.Fatalf("expected participant %d to have id", i)
}
}
if !reflect.DeepEqual(gotParticipants, expectedParticipants) {
t.Fatalf("expected participants %v, got %v", expectedParticipants, gotParticipants)
}
if resp.GetChat().GetId() != thread.GetId() {
t.Fatalf("expected chat id %q, got %q", thread.GetId(), resp.GetChat().GetId())
Expand All @@ -243,6 +293,19 @@ func TestCreateChatReturnsThreadWhenStoreFails(t *testing.T) {
var storedThreadID uuid.UUID
threads := &mockThreadsClient{
createThreadFunc: func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
if req.OrganizationId == nil {
return nil, status.Error(codes.InvalidArgument, "organization_id missing")
}
if req.GetOrganizationId() != orgID.String() {
return nil, status.Errorf(codes.InvalidArgument, "unexpected organization id %q", req.GetOrganizationId())
}
if len(req.GetParticipantIds()) != 0 {
return nil, status.Errorf(codes.InvalidArgument, "unexpected participant_ids %v", req.GetParticipantIds())
}
if len(req.GetParticipants()) != 1 {
return nil, status.Errorf(codes.InvalidArgument, "unexpected participants %v", req.GetParticipants())
}
return &threadsv1.CreateThreadResponse{Thread: thread}, nil
},
}
Expand Down Expand Up @@ -318,6 +381,7 @@ func TestGetChatsUsesStoreAndThreads(t *testing.T) {

threads := &mockThreadsClient{
getThreadsFunc: func(ctx context.Context, req *threadsv1.GetThreadsRequest, opts ...grpc.CallOption) (*threadsv1.GetThreadsResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
if req.GetParticipantId() != "user-1" {
return nil, status.Errorf(codes.InvalidArgument, "unexpected participant %q", req.GetParticipantId())
}
Expand Down Expand Up @@ -750,13 +814,21 @@ func TestGetMessagesAggregatesUnread(t *testing.T) {
var gotPageTokens []string
threads := &mockThreadsClient{
getMessagesFunc: func(ctx context.Context, req *threadsv1.GetMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetMessagesResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
gotMessagesReq = req
return &threadsv1.GetMessagesResponse{Messages: threadsMessages, NextPageToken: "next-token"}, nil
},
getUnackedMessagesFunc: func(ctx context.Context, req *threadsv1.GetUnackedMessagesRequest, opts ...grpc.CallOption) (*threadsv1.GetUnackedMessagesResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
if req.GetParticipantId() != "user-1" {
t.Fatalf("expected participant user-1, got %q", req.GetParticipantId())
}
if req.ThreadId == nil {
t.Fatal("expected thread_id to be set")
}
if req.GetThreadId() != chatID {
t.Fatalf("expected thread id %q, got %q", chatID, req.GetThreadId())
}
if req.GetPageSize() != unackedPageSize {
t.Fatalf("expected page size %d, got %d", unackedPageSize, req.GetPageSize())
}
Expand All @@ -766,15 +838,14 @@ func TestGetMessagesAggregatesUnread(t *testing.T) {
return &threadsv1.GetUnackedMessagesResponse{
Messages: []*threadsv1.Message{
{Id: "u1", ThreadId: chatID},
{Id: "u2", ThreadId: "chat-2"},
{Id: "u2", ThreadId: chatID},
},
NextPageToken: "page-2",
}, nil
case 2:
return &threadsv1.GetUnackedMessagesResponse{
Messages: []*threadsv1.Message{
{Id: "u3", ThreadId: chatID},
{Id: "u4", ThreadId: chatID},
},
}, nil
default:
Expand Down Expand Up @@ -834,6 +905,7 @@ func TestSendMessageDelegates(t *testing.T) {
msgTime := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC)
threads := &mockThreadsClient{
sendMessageFunc: func(ctx context.Context, req *threadsv1.SendMessageRequest, opts ...grpc.CallOption) (*threadsv1.SendMessageResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
if req.GetThreadId() != "chat-1" {
return nil, status.Errorf(codes.InvalidArgument, "unexpected thread id %q", req.GetThreadId())
}
Expand Down Expand Up @@ -886,6 +958,7 @@ func TestMarkAsReadDelegates(t *testing.T) {
ctx := contextWithIdentity("user-1")
threads := &mockThreadsClient{
ackMessagesFunc: func(ctx context.Context, req *threadsv1.AckMessagesRequest, opts ...grpc.CallOption) (*threadsv1.AckMessagesResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
if req.GetParticipantId() != "user-1" {
return nil, status.Errorf(codes.InvalidArgument, "unexpected participant %q", req.GetParticipantId())
}
Expand Down
Loading