Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion internal/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
const (
MetadataKeyIdentityID = "x-identity-id"
MetadataKeyIdentityType = "x-identity-type"
MetadataKeyWorkloadID = "x-workload-id"
)

type Identity struct {
IdentityID string
IdentityType string
WorkloadID *string
}

func FromContext(ctx context.Context) (Identity, error) {
Expand All @@ -33,20 +35,27 @@ func FromContext(ctx context.Context) (Identity, error) {
return Identity{}, err
}

workloadID := optionalValue(md, MetadataKeyWorkloadID)

return Identity{
IdentityID: identityID,
IdentityType: identityType,
WorkloadID: workloadID,
}, nil
}

func AppendToOutgoingContext(ctx context.Context, identity Identity) context.Context {
return metadata.AppendToOutgoingContext(
ctx = metadata.AppendToOutgoingContext(
ctx,
MetadataKeyIdentityID,
identity.IdentityID,
MetadataKeyIdentityType,
identity.IdentityType,
)
if identity.WorkloadID == nil {
return ctx
}
return metadata.AppendToOutgoingContext(ctx, MetadataKeyWorkloadID, *identity.WorkloadID)
}

func singleValue(md metadata.MD, key string) (string, error) {
Expand All @@ -56,3 +65,12 @@ func singleValue(md metadata.MD, key string) (string, error) {
}
return values[0], nil
}

func optionalValue(md metadata.MD, key string) *string {
values := md.Get(key)
if len(values) == 0 {
return nil
}
value := values[0]
return &value
}
20 changes: 16 additions & 4 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,18 @@ func unexpectedStoreCall(method string) error {
return fmt.Errorf("unexpected store %s call", method)
}

func contextWithIdentity(identityID string) context.Context {
func contextWithIdentity(identityID string, workloadID ...string) context.Context {
md := metadata.New(map[string]string{
identity.MetadataKeyIdentityID: identityID,
identity.MetadataKeyIdentityType: "user",
})
if len(workloadID) > 0 {
md.Set(identity.MetadataKeyWorkloadID, workloadID[0])
}
return metadata.NewIncomingContext(context.Background(), md)
}

func requireOutgoingIdentity(t *testing.T, ctx context.Context, identityID, identityType string) {
func requireOutgoingIdentity(t *testing.T, ctx context.Context, identityID, identityType string, workloadID ...string) {
t.Helper()
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
Expand All @@ -165,6 +168,15 @@ func requireOutgoingIdentity(t *testing.T, ctx context.Context, identityID, iden
if got := md.Get(identity.MetadataKeyIdentityType); len(got) != 1 || got[0] != identityType {
t.Fatalf("expected outgoing identity type %q, got %v", identityType, got)
}
if len(workloadID) > 0 {
if got := md.Get(identity.MetadataKeyWorkloadID); len(got) != 1 || got[0] != workloadID[0] {
t.Fatalf("expected outgoing workload id %q, got %v", workloadID[0], got)
}
return
}
if got := md.Get(identity.MetadataKeyWorkloadID); len(got) != 0 {
t.Fatalf("expected no outgoing workload id, got %v", got)
}
}

func requireStatusCode(t *testing.T, err error, code codes.Code) {
Expand Down Expand Up @@ -215,7 +227,7 @@ func TestCreateChatRejectsInvalidOrganizationID(t *testing.T) {
}

func TestCreateChatDeduplicatesParticipants(t *testing.T) {
ctx := contextWithIdentity("user-1")
ctx := contextWithIdentity("user-1", "workload-1")
orgID := uuid.New()
var gotRequest *threadsv1.CreateThreadRequest
threadID := uuid.New()
Expand All @@ -224,7 +236,7 @@ func TestCreateChatDeduplicatesParticipants(t *testing.T) {
var storedOrgID uuid.UUID
threads := &mockThreadsClient{
createThreadFunc: func(ctx context.Context, req *threadsv1.CreateThreadRequest, opts ...grpc.CallOption) (*threadsv1.CreateThreadResponse, error) {
requireOutgoingIdentity(t, ctx, "user-1", "user")
requireOutgoingIdentity(t, ctx, "user-1", "user", "workload-1")
gotRequest = req
return &threadsv1.CreateThreadResponse{Thread: thread}, nil
},
Expand Down
Loading
Loading