Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/temporalio/s2s-proxy/config"
"github.com/temporalio/s2s-proxy/logging"
"github.com/temporalio/s2s-proxy/metrics"
"github.com/temporalio/s2s-proxy/proto/compat"
"github.com/temporalio/s2s-proxy/proxy"
)
Expand Down Expand Up @@ -74,9 +75,13 @@ func startProxy(c *cli.Context) error {
}),
logging.Module,
config.Module,
metrics.Module,
proxy.Module,
fx.Populate(&proxyParams),
fx.Populate(compat.GetCodec().CodecParams),
fx.Invoke(func(reg *metrics.Registry) {
compat.GetCodec().SetRegistry(reg)
}),
)

if err := app.Err(); err != nil {
Expand Down
12 changes: 11 additions & 1 deletion endtoendtest/echo_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/gogo/status"
"github.com/prometheus/client_golang/prometheus"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/api/adminservice/v1"
replicationpb "go.temporal.io/server/api/replication/v1"
Expand Down Expand Up @@ -114,9 +115,16 @@ func NewEchoServer(
}

configProvider := config.NewMockConfigProvider(*localClusterInfo.S2sProxyConfig)
promReg := prometheus.NewRegistry()
reg, err := metrics.NewRegistry(promReg, promReg)
if err != nil {
panic(err)
}

proxy = s2sproxy.NewProxy(
configProvider,
logging.NewLoggerProvider(logger, configProvider),
reg,
)

clientConfig = config.ProxyClientConfig{
Expand Down Expand Up @@ -268,7 +276,8 @@ func SendRecv(stream adminservice.AdminService_StreamWorkflowReplicationMessages
InclusiveLowWatermarkTime: timestamppb.New(highWatermarkInfo.Timestamp),
},
},
}}
},
}

if err = stream.Send(req); err != nil {
return echoed, err
Expand Down Expand Up @@ -322,6 +331,7 @@ func (s *EchoServer) PollActivityTaskQueue(req *workflowservice.PollActivityTask
defer cancel()
return workflowservice.NewWorkflowServiceClient(s.RemoteClient).PollActivityTaskQueue(timeout, req)
}

func (s *EchoServer) Describe() string {
proxyDescription := "no proxy"
if s.Proxy != nil {
Expand Down
1 change: 1 addition & 0 deletions interceptor/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type (
adminServiceAccess *auth.AccessControl
namespaceAccess *auth.AccessControl
}

ACLConfig interface {
AllowedNamespaces() []string
AdminServiceAllowedMethods() []string
Expand Down
40 changes: 28 additions & 12 deletions interceptor/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ type visitor func(logger log.Logger, obj any, match stringMatcher) (bool, error)
// in the given object. When it finds namespace string fields, it invokes
// the provided match function.
func visitNamespace(logger log.Logger, obj any, match stringMatcher) (bool, error) {
return visitNamespaceWithReg(logger, nil, obj, match)
}

func visitNamespaceWithReg(logger log.Logger, reg *metrics.Registry, obj any, match stringMatcher) (bool, error) {
if isSkippableForNamespaceTranslation(obj) {
return false, nil
}
Expand Down Expand Up @@ -166,15 +170,17 @@ func visitNamespace(logger log.Logger, obj any, match stringMatcher) (bool, erro
} else if hist, ok := vwp.Interface().(*history.History); ok && hist != nil {
for _, evt := range hist.GetEvents() {
// Do the recursive call here so that we check `isSkippableForNamespaceTranslation`.
m, err := visitNamespace(logger, evt, match)
m, err := visitNamespaceWithReg(logger, reg, evt, match)
matched = matched || m
if err != nil {
return visit.Stop, err
}
}
return visit.Skip, nil
} else if dataBlobFieldNames[fieldType.Name] {
changed, err := visitDataBlobs(logger, vwp, match, visitNamespace)
changed, err := visitDataBlobs(logger, reg, vwp, match, func(l log.Logger, o any, m stringMatcher) (bool, error) {
return visitNamespaceWithReg(l, reg, o, m)
})
matched = matched || changed
if err != nil {
return visit.Stop, err
Expand Down Expand Up @@ -202,9 +208,13 @@ func visitNamespace(logger log.Logger, obj any, match stringMatcher) (bool, erro
}

// visitSearchAttributes uses reflection to recursively visit all fields
// in the given object. When it finds namespace string fields, it invokes
// in the given object. When it finds search attribute fields, it invokes
// the provided match function.
func visitSearchAttributes(logger log.Logger, obj any, match stringMatcher) (bool, error) {
return visitSearchAttributesWithReg(logger, nil, obj, match)
}

func visitSearchAttributesWithReg(logger log.Logger, reg *metrics.Registry, obj any, match stringMatcher) (bool, error) {
var matched bool

// The visitor function can return Skip, Stop, or Continue to control recursion.
Expand All @@ -219,7 +229,9 @@ func visitSearchAttributes(logger log.Logger, obj any, match stringMatcher) (boo
return action, nil
}
if dataBlobFieldNames[fieldType.Name] {
changed, err := visitDataBlobs(logger, vwp, match, visitSearchAttributes)
changed, err := visitDataBlobs(logger, reg, vwp, match, func(l log.Logger, o any, m stringMatcher) (bool, error) {
return visitSearchAttributesWithReg(l, reg, o, m)
})
matched = matched || changed
if err != nil {
return visit.Stop, err
Expand Down Expand Up @@ -281,10 +293,10 @@ func getParentFieldType(vwp visit.ValueWithParent) (result reflect.StructField,
return fieldType, action
}

func visitDataBlobs(logger log.Logger, vwp visit.ValueWithParent, match stringMatcher, visitor visitor) (bool, error) {
func visitDataBlobs(logger log.Logger, reg *metrics.Registry, vwp visit.ValueWithParent, match stringMatcher, visitor visitor) (bool, error) {
switch evt := vwp.Interface().(type) {
case []*common.DataBlob:
newEvts, matched, changed, err := translateDataBlobs(logger, match, visitor, evt...)
newEvts, matched, changed, err := translateDataBlobs(logger, reg, match, visitor, evt...)
if err != nil {
return matched, err
}
Expand All @@ -295,7 +307,7 @@ func visitDataBlobs(logger log.Logger, vwp visit.ValueWithParent, match stringMa
}
return matched, nil
case *common.DataBlob:
newEvt, matched, changed, err := translateOneDataBlob(logger, match, visitor, evt)
newEvt, matched, changed, err := translateOneDataBlob(logger, reg, match, visitor, evt)
if err != nil {
return matched, err
}
Expand All @@ -310,9 +322,9 @@ func visitDataBlobs(logger log.Logger, vwp visit.ValueWithParent, match stringMa
}
}

func translateDataBlobs(logger log.Logger, match stringMatcher, visitor visitor, blobs ...*common.DataBlob) (result []*common.DataBlob, anyMatched, anyChanged bool, retErr error) {
func translateDataBlobs(logger log.Logger, reg *metrics.Registry, match stringMatcher, visitor visitor, blobs ...*common.DataBlob) (result []*common.DataBlob, anyMatched, anyChanged bool, retErr error) {
for i, blob := range blobs {
newBlob, matched, changed, err := translateOneDataBlob(logger, match, visitor, blob)
newBlob, matched, changed, err := translateOneDataBlob(logger, reg, match, visitor, blob)
anyChanged = anyChanged || changed
anyMatched = anyMatched || matched
if err != nil {
Expand All @@ -323,7 +335,7 @@ func translateDataBlobs(logger log.Logger, match stringMatcher, visitor visitor,
return blobs, anyMatched, anyChanged, nil
}

func translateOneDataBlob(logger log.Logger, match stringMatcher, visitor visitor, blob *common.DataBlob) (result *common.DataBlob, matched, changed bool, retErr error) {
func translateOneDataBlob(logger log.Logger, reg *metrics.Registry, match stringMatcher, visitor visitor, blob *common.DataBlob) (result *common.DataBlob, matched, changed bool, retErr error) {
if blob == nil || len(blob.Data) == 0 {
return blob, matched, changed, nil
}
Expand All @@ -341,11 +353,15 @@ func translateOneDataBlob(logger log.Logger, match stringMatcher, visitor visito
changed = changed || c
if err != nil {
logger.Error("failed to repair invalid utf-8 in history event blob", tag.Error(err))
metrics.TranslationErrors.WithLabelValues(metrics.UTF8RepairTranslationKind, metrics.HistoryBlobMessageType).Inc()
if reg != nil {
reg.TranslationErrors.WithLabelValues(metrics.UTF8RepairTranslationKind, metrics.HistoryBlobMessageType).Inc()
}
return blob, matched, changed, err
} else if changed {
logger.Debug("repaired invalid utf-8 in history event blob")
metrics.TranslationCount.WithLabelValues(metrics.UTF8RepairTranslationKind, metrics.HistoryBlobMessageType).Inc()
if reg != nil {
reg.TranslationCount.WithLabelValues(metrics.UTF8RepairTranslationKind, metrics.HistoryBlobMessageType).Inc()
}
events = repairedEvents
}
}
Expand Down
8 changes: 5 additions & 3 deletions interceptor/search_attribute_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ import (
type (
saTranslator struct {
logger log.Logger
reg *metrics.Registry
matchMethod func(string) bool
reqMap map[string]stringMatcher
respMap map[string]stringMatcher
}
)

func NewSearchAttributeTranslator(logger log.Logger, reqMap, respMap map[string]map[string]string) Translator {
func NewSearchAttributeTranslator(logger log.Logger, reg *metrics.Registry, reqMap, respMap map[string]map[string]string) Translator {
return &saTranslator{
logger: logger,
reg: reg,
matchMethod: func(method string) bool {
// In workflowservice APIs, responses only contain the search attribute alias.
// We should never translate these responses to the search attribute's indexed field.
Expand All @@ -40,11 +42,11 @@ func (s *saTranslator) MatchMethod(m string) bool {
}

func (s *saTranslator) TranslateRequest(req any) (bool, error) {
return visitSearchAttributes(s.logger, req, s.getNamespaceReqMatcher(""))
return visitSearchAttributesWithReg(s.logger, s.reg, req, s.getNamespaceReqMatcher(""))
}

func (s *saTranslator) TranslateResponse(resp any) (bool, error) {
return visitSearchAttributes(s.logger, resp, s.getNamespaceRespMatcher(""))
return visitSearchAttributesWithReg(s.logger, s.reg, resp, s.getNamespaceRespMatcher(""))
}

func (s *saTranslator) getNamespaceReqMatcher(namespaceId string) stringMatcher {
Expand Down
24 changes: 15 additions & 9 deletions interceptor/translation_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ type (
TranslationInterceptor struct {
logger log.Logger
translators []Translator
reg *metrics.Registry
}
)

func NewTranslationInterceptor(
logger log.Logger,
translators []Translator,
reg *metrics.Registry,
) *TranslationInterceptor {
return &TranslationInterceptor{
logger: logger,
translators: translators,
reg: reg,
}
}

Expand All @@ -50,7 +53,7 @@ func (i *TranslationInterceptor) Intercept(
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateRequest(req)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start))
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start), i.reg)
}
}

Expand All @@ -60,7 +63,7 @@ func (i *TranslationInterceptor) Intercept(
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateResponse(resp)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start))
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start), i.reg)
}
}

Expand All @@ -86,20 +89,21 @@ func (i *TranslationInterceptor) InterceptStream(
}
return err
}
return handler(srv, newStreamTranslator(ss, i.logger, i.translators))
return handler(srv, newStreamTranslator(ss, i.logger, i.translators, i.reg))
}

type streamTranslator struct {
grpc.ServerStream
logger log.Logger
translators []Translator
reg *metrics.Registry
}

func (w *streamTranslator) RecvMsg(m any) error {
for _, tr := range w.translators {
start := time.Now()
changed, trErr := tr.TranslateRequest(m)
logTranslateResult(tr, w.logger, changed, trErr, "RecvMsg", m, time.Since(start))
logTranslateResult(tr, w.logger, changed, trErr, "RecvMsg", m, time.Since(start), w.reg)
}
return w.ServerStream.RecvMsg(m)
}
Expand All @@ -108,7 +112,7 @@ func (w *streamTranslator) SendMsg(m any) error {
for _, tr := range w.translators {
start := time.Now()
changed, trErr := tr.TranslateResponse(m)
logTranslateResult(tr, w.logger, changed, trErr, "SendMsg", m, time.Since(start))
logTranslateResult(tr, w.logger, changed, trErr, "SendMsg", m, time.Since(start), w.reg)
}
return w.ServerStream.SendMsg(m)
}
Expand All @@ -117,25 +121,27 @@ func newStreamTranslator(
s grpc.ServerStream,
logger log.Logger,
translators []Translator,
reg *metrics.Registry,
) grpc.ServerStream {
return &streamTranslator{
ServerStream: s,
logger: logger,
translators: translators,
reg: reg,
}
}

func logTranslateResult(tr Translator, logger log.Logger, changed bool, err error, methodName string, obj any, duration time.Duration) {
func logTranslateResult(tr Translator, logger log.Logger, changed bool, err error, methodName string, obj any, duration time.Duration, reg *metrics.Registry) {
msgType := metrics.SanitizedTypeName(obj)
metrics.TranslationLatency.WithLabelValues(tr.Kind(), msgType).Observe(duration.Seconds())
reg.TranslationLatency.WithLabelValues(tr.Kind(), msgType).Observe(duration.Seconds())

methodTag := tag.NewStringTag("method", methodName)
if err != nil {
logger.Error("translation error", methodTag, tag.Error(err), tag.NewStringTag("type", msgType))
metrics.TranslationErrors.WithLabelValues(tr.Kind(), msgType).Inc()
reg.TranslationErrors.WithLabelValues(tr.Kind(), msgType).Inc()
} else if changed {
logger.Debug("translation applied", methodTag, tag.NewAnyTag("obj", obj))
metrics.TranslationCount.WithLabelValues(tr.Kind(), msgType).Inc()
reg.TranslationCount.WithLabelValues(tr.Kind(), msgType).Inc()
} else {
logger.Debug("translation not applied", methodTag, tag.NewAnyTag("obj", obj))
}
Expand Down
10 changes: 5 additions & 5 deletions interceptor/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@ type (

translatorImpl struct {
logger log.Logger
reg *metrics.Registry
matchMethod func(string) bool
matchReq stringMatcher
matchResp stringMatcher
visitor visitor
kind string
}
)

func NewNamespaceNameTranslator(logger log.Logger, reqMap, respMap map[string]string) Translator {
func NewNamespaceNameTranslator(logger log.Logger, reg *metrics.Registry, reqMap, respMap map[string]string) Translator {
return &translatorImpl{
logger: logger,
reg: reg,
matchMethod: func(string) bool { return true },
matchReq: createStringMatcher(reqMap),
matchResp: createStringMatcher(respMap),
visitor: visitNamespace,
kind: metrics.NamespaceTranslationKind,
}
}
Expand All @@ -44,11 +44,11 @@ func (n *translatorImpl) MatchMethod(m string) bool {
}

func (n *translatorImpl) TranslateRequest(req any) (bool, error) {
return n.visitor(n.logger, req, n.matchReq)
return visitNamespaceWithReg(n.logger, n.reg, req, n.matchReq)
}

func (n *translatorImpl) TranslateResponse(resp any) (bool, error) {
return n.visitor(n.logger, resp, n.matchResp)
return visitNamespaceWithReg(n.logger, n.reg, resp, n.matchResp)
}

func createStringMatcher(mapping map[string]string) stringMatcher {
Expand Down
11 changes: 11 additions & 0 deletions metrics/fx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package metrics

import (
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/fx"
)

var Module = fx.Provide(func() (*Registry, error) {
reg := prometheus.NewRegistry()
return NewRegistry(reg, reg)
})
Loading
Loading