From a02e5b3d5d0ca9b2a47732231a2f33d2e9750d6f Mon Sep 17 00:00:00 2001 From: Edward J Date: Tue, 21 Apr 2026 11:28:26 -0700 Subject: [PATCH 1/4] Use fabric dns client for dns resolves --- src/pkg/cli/client/client.go | 3 + src/pkg/cli/client/grpc.go | 12 ++++ src/pkg/cli/connect.go | 5 +- src/pkg/dns/fabric_test.go | 132 +++++++++++++++++++++++++++++++++++ src/pkg/dns/resolver.go | 88 +++++++++++++++++++++++ 5 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 src/pkg/dns/fabric_test.go diff --git a/src/pkg/cli/client/client.go b/src/pkg/cli/client/client.go index 728af7610..f025e869d 100644 --- a/src/pkg/cli/client/client.go +++ b/src/pkg/cli/client/client.go @@ -34,6 +34,9 @@ type FabricClient interface { Preview(context.Context, *defangv1.PreviewRequest) (*defangv1.PreviewResponse, error) PutDeployment(context.Context, *defangv1.PutDeploymentRequest) error PutStack(context.Context, *defangv1.PutStackRequest) error + ResolveCNAME(context.Context, *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) + ResolveIPAddr(context.Context, *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) + ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) RevokeToken(context.Context) error SetOptions(context.Context, *defangv1.SetOptionsRequest) error Token(context.Context, *defangv1.TokenRequest) (*defangv1.TokenResponse, error) diff --git a/src/pkg/cli/client/grpc.go b/src/pkg/cli/client/grpc.go index 7c34af198..392d634db 100644 --- a/src/pkg/cli/client/grpc.go +++ b/src/pkg/cli/client/grpc.go @@ -206,3 +206,15 @@ func (g GrpcClient) GenerateCompose(ctx context.Context, req *defangv1.GenerateC func (g GrpcClient) GetDefaultStack(ctx context.Context, req *defangv1.GetDefaultStackRequest) (*defangv1.GetStackResponse, error) { return getMsg(g.client.GetDefaultStack(ctx, connect.NewRequest(req))) } + +func (g GrpcClient) ResolveIPAddr(ctx context.Context, req *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) { + return getMsg(g.client.ResolveIPAddr(ctx, connect.NewRequest(req))) +} + +func (g GrpcClient) ResolveCNAME(ctx context.Context, req *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) { + return getMsg(g.client.ResolveCNAME(ctx, connect.NewRequest(req))) +} + +func (g GrpcClient) ResolveNS(ctx context.Context, req *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) { + return getMsg(g.client.ResolveNS(ctx, connect.NewRequest(req))) +} diff --git a/src/pkg/cli/connect.go b/src/pkg/cli/connect.go index e46ebcf77..6d9b2cb32 100644 --- a/src/pkg/cli/connect.go +++ b/src/pkg/cli/connect.go @@ -7,6 +7,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/do" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/gcp" + "github.com/DefangLabs/defang/src/pkg/dns" "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/types" ) @@ -17,7 +18,9 @@ func Connect(fabricAddr string, requestedTenant types.TenantNameOrID) *client.Gr term.Debugf("Using tenant %q for cluster %q", requestedTenant, host) accessToken := client.GetExistingToken(host) - return client.NewGrpcClient(host, accessToken, requestedTenant) + grpcClient := client.NewGrpcClient(host, accessToken, requestedTenant) + dns.UseFabricResolver(grpcClient) + return grpcClient } func ConnectWithTenant(ctx context.Context, fabricAddr string, requestedTenant types.TenantNameOrID) (*client.GrpcClient, error) { diff --git a/src/pkg/dns/fabric_test.go b/src/pkg/dns/fabric_test.go new file mode 100644 index 000000000..e0707f802 --- /dev/null +++ b/src/pkg/dns/fabric_test.go @@ -0,0 +1,132 @@ +package dns + +import ( + "context" + "errors" + "testing" + + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" +) + +type mockFabricClient struct { + ipResp *defangv1.ResolveIPAddrResponse + ipErr error + cnameResp *defangv1.ResolveCNAMEResponse + cnameErr error + nsResp *defangv1.ResolveNSResponse + nsErr error + + lastIPReq *defangv1.ResolveIPAddrRequest + lastCNAMEReq *defangv1.ResolveCNAMERequest + lastNSReq *defangv1.ResolveNSRequest +} + +func (m *mockFabricClient) ResolveIPAddr(_ context.Context, req *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) { + m.lastIPReq = req + return m.ipResp, m.ipErr +} + +func (m *mockFabricClient) ResolveCNAME(_ context.Context, req *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) { + m.lastCNAMEReq = req + return m.cnameResp, m.cnameErr +} + +func (m *mockFabricClient) ResolveNS(_ context.Context, req *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) { + m.lastNSReq = req + return m.nsResp, m.nsErr +} + +func TestFabricResolverLookupIPAddr(t *testing.T) { + t.Run("returns parsed IPs and forwards NSServer", func(t *testing.T) { + m := &mockFabricClient{ + ipResp: &defangv1.ResolveIPAddrResponse{IpAddrs: []string{"1.2.3.4", "::1", "not-an-ip"}}, + } + r := FabricResolver{Client: m, NSServer: "ns.example.com"} + ips, err := r.LookupIPAddr(t.Context(), "example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ips) != 2 { + t.Fatalf("expected 2 valid IPs, got %v", ips) + } + if m.lastIPReq.Domain != "example.com" || m.lastIPReq.NsServer != "ns.example.com" { + t.Errorf("request mismatch: %+v", m.lastIPReq) + } + }) + + t.Run("empty IPs returns ErrNoSuchHost", func(t *testing.T) { + m := &mockFabricClient{ipResp: &defangv1.ResolveIPAddrResponse{}} + r := FabricResolver{Client: m} + if _, err := r.LookupIPAddr(t.Context(), "nx.example.com"); !errors.Is(err, ErrNoSuchHost) { + t.Errorf("expected ErrNoSuchHost, got %v", err) + } + }) + + t.Run("propagates RPC error", func(t *testing.T) { + boom := errors.New("rpc boom") + m := &mockFabricClient{ipErr: boom} + r := FabricResolver{Client: m} + if _, err := r.LookupIPAddr(t.Context(), "example.com"); err != boom { + t.Errorf("expected rpc error, got %v", err) + } + }) +} + +func TestFabricResolverLookupCNAME(t *testing.T) { + t.Run("returns cname", func(t *testing.T) { + m := &mockFabricClient{cnameResp: &defangv1.ResolveCNAMEResponse{Cname: "alb.example.com"}} + r := FabricResolver{Client: m} + cname, err := r.LookupCNAME(t.Context(), "api.example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cname != "alb.example.com" { + t.Errorf("got %q", cname) + } + }) + + t.Run("empty cname returns ErrNoSuchHost", func(t *testing.T) { + m := &mockFabricClient{cnameResp: &defangv1.ResolveCNAMEResponse{}} + r := FabricResolver{Client: m} + if _, err := r.LookupCNAME(t.Context(), "api.example.com"); !errors.Is(err, ErrNoSuchHost) { + t.Errorf("expected ErrNoSuchHost, got %v", err) + } + }) +} + +func TestFabricResolverLookupNS(t *testing.T) { + m := &mockFabricClient{nsResp: &defangv1.ResolveNSResponse{Hosts: []string{"ns1.example.com.", "ns2.example.com."}}} + r := FabricResolver{Client: m} + ns, err := r.LookupNS(t.Context(), "example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ns) != 2 || ns[0].Host != "ns1.example.com." { + t.Errorf("unexpected NS result: %+v", ns) + } +} + +func TestUseFabricResolver(t *testing.T) { + t.Cleanup(func() { + fabricClient = nil + ResolverAt = DirectResolverAt + }) + + m := &mockFabricClient{ipResp: &defangv1.ResolveIPAddrResponse{IpAddrs: []string{"9.9.9.9"}}} + UseFabricResolver(m) + + // RootResolver should now delegate to FabricResolver. + ips, err := RootResolver{}.LookupIPAddr(t.Context(), "example.com") + if err != nil { + t.Fatalf("RootResolver.LookupIPAddr: %v", err) + } + if len(ips) != 1 || ips[0].IP.String() != "9.9.9.9" { + t.Errorf("unexpected IPs: %v", ips) + } + + // ResolverAt should return a FabricResolver bound to the NS. + r := ResolverAt("ns1.example.com") + if fr, ok := r.(FabricResolver); !ok || fr.NSServer != "ns1.example.com" { + t.Errorf("ResolverAt did not return FabricResolver: %T %+v", r, r) + } +} diff --git a/src/pkg/dns/resolver.go b/src/pkg/dns/resolver.go index 0bfc85cf8..71baf44a1 100644 --- a/src/pkg/dns/resolver.go +++ b/src/pkg/dns/resolver.go @@ -9,6 +9,7 @@ import ( "sort" "github.com/DefangLabs/defang/src/pkg" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/miekg/dns" ) @@ -18,6 +19,84 @@ type Resolver interface { LookupNS(ctx context.Context, domain string) ([]*net.NS, error) } +// FabricResolverClient is the subset of the fabric gRPC API used to resolve DNS +// records remotely. +type FabricResolverClient interface { + ResolveIPAddr(context.Context, *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) + ResolveCNAME(context.Context, *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) + ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) +} + +// fabricClient is set by UseFabricResolver. When non-nil, RootResolver and +// ResolverAt route DNS lookups through the fabric gRPC API. +var fabricClient FabricResolverClient + +// UseFabricResolver wires DNS lookups through the fabric gRPC API. After it is +// called, RootResolver{} and ResolverAt(nsServer) both issue remote RPCs +// instead of performing direct UDP DNS queries. +func UseFabricResolver(c FabricResolverClient) { + fabricClient = c + ResolverAt = func(nsServer string) Resolver { + return FabricResolver{Client: c, NSServer: nsServer} + } +} + +// FabricResolver performs DNS lookups via the fabric gRPC API. An empty +// NSServer lets the server perform recursive resolution from the root. +type FabricResolver struct { + Client FabricResolverClient + NSServer string +} + +func (r FabricResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { + resp, err := r.Client.ResolveIPAddr(ctx, &defangv1.ResolveIPAddrRequest{ + Domain: domain, + NsServer: r.NSServer, + }) + if err != nil { + return nil, err + } + ips := make([]net.IPAddr, 0, len(resp.IpAddrs)) + for _, s := range resp.IpAddrs { + if ip := net.ParseIP(s); ip != nil { + ips = append(ips, net.IPAddr{IP: ip}) + } + } + if len(ips) == 0 { + return nil, ErrNoSuchHost + } + return ips, nil +} + +func (r FabricResolver) LookupCNAME(ctx context.Context, domain string) (string, error) { + resp, err := r.Client.ResolveCNAME(ctx, &defangv1.ResolveCNAMERequest{ + Domain: domain, + NsServer: r.NSServer, + }) + if err != nil { + return "", err + } + if resp.Cname == "" { + return "", ErrNoSuchHost + } + return resp.Cname, nil +} + +func (r FabricResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { + resp, err := r.Client.ResolveNS(ctx, &defangv1.ResolveNSRequest{ + Domain: domain, + NsServer: r.NSServer, + }) + if err != nil { + return nil, err + } + nss := make([]*net.NS, 0, len(resp.Hosts)) + for _, h := range resp.Hosts { + nss = append(nss, &net.NS{Host: h}) + } + return nss, nil +} + type RootResolver struct{} // https://en.wikipedia.org/wiki/Root_name_server @@ -38,6 +117,9 @@ var rootServers = []*net.NS{ } func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { + if fabricClient != nil { + return FabricResolver{Client: fabricClient}.LookupIPAddr(ctx, domain) + } for range 10 { ips, err := r.getResolver(ctx, domain).LookupIPAddr(ctx, domain) if err != nil { @@ -54,10 +136,16 @@ func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IP } func (r RootResolver) LookupCNAME(ctx context.Context, domain string) (string, error) { + if fabricClient != nil { + return FabricResolver{Client: fabricClient}.LookupCNAME(ctx, domain) + } return r.getResolver(ctx, domain).LookupCNAME(ctx, domain) } func (r RootResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { + if fabricClient != nil { + return FabricResolver{Client: fabricClient}.LookupNS(ctx, domain) + } return r.getResolver(ctx, domain).LookupNS(ctx, domain) } From 1dbacadd9f13abd968a2e883cdf7d0a89054e8f8 Mon Sep 17 00:00:00 2001 From: Edward J Date: Tue, 21 Apr 2026 12:06:19 -0700 Subject: [PATCH 2/4] Fix fabric dns client race condition --- src/pkg/dns/resolver.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/pkg/dns/resolver.go b/src/pkg/dns/resolver.go index 71baf44a1..5525264bb 100644 --- a/src/pkg/dns/resolver.go +++ b/src/pkg/dns/resolver.go @@ -7,6 +7,7 @@ import ( "net" "slices" "sort" + "sync" "github.com/DefangLabs/defang/src/pkg" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" @@ -27,6 +28,10 @@ type FabricResolverClient interface { ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) } +// fabricMu guards concurrent access to fabricClient and the ResolverAt +// assignment inside UseFabricResolver. +var fabricMu sync.RWMutex + // fabricClient is set by UseFabricResolver. When non-nil, RootResolver and // ResolverAt route DNS lookups through the fabric gRPC API. var fabricClient FabricResolverClient @@ -35,12 +40,20 @@ var fabricClient FabricResolverClient // called, RootResolver{} and ResolverAt(nsServer) both issue remote RPCs // instead of performing direct UDP DNS queries. func UseFabricResolver(c FabricResolverClient) { + fabricMu.Lock() + defer fabricMu.Unlock() fabricClient = c ResolverAt = func(nsServer string) Resolver { return FabricResolver{Client: c, NSServer: nsServer} } } +func getFabricClient() FabricResolverClient { + fabricMu.RLock() + defer fabricMu.RUnlock() + return fabricClient +} + // FabricResolver performs DNS lookups via the fabric gRPC API. An empty // NSServer lets the server perform recursive resolution from the root. type FabricResolver struct { @@ -117,8 +130,8 @@ var rootServers = []*net.NS{ } func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { - if fabricClient != nil { - return FabricResolver{Client: fabricClient}.LookupIPAddr(ctx, domain) + if c := getFabricClient(); c != nil { + return FabricResolver{Client: c}.LookupIPAddr(ctx, domain) } for range 10 { ips, err := r.getResolver(ctx, domain).LookupIPAddr(ctx, domain) @@ -136,15 +149,15 @@ func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IP } func (r RootResolver) LookupCNAME(ctx context.Context, domain string) (string, error) { - if fabricClient != nil { - return FabricResolver{Client: fabricClient}.LookupCNAME(ctx, domain) + if c := getFabricClient(); c != nil { + return FabricResolver{Client: c}.LookupCNAME(ctx, domain) } return r.getResolver(ctx, domain).LookupCNAME(ctx, domain) } func (r RootResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { - if fabricClient != nil { - return FabricResolver{Client: fabricClient}.LookupNS(ctx, domain) + if c := getFabricClient(); c != nil { + return FabricResolver{Client: c}.LookupNS(ctx, domain) } return r.getResolver(ctx, domain).LookupNS(ctx, domain) } From b4ce07878f5413ef227d29b2f3263fdc4f16c1a3 Mon Sep 17 00:00:00 2001 From: Edward J Date: Thu, 23 Apr 2026 12:45:16 -0700 Subject: [PATCH 3/4] Fix resolverAt race condition --- src/pkg/dns/check_test.go | 32 +++++++++++++++--------------- src/pkg/dns/fabric_test.go | 38 +++++++++++++++++++++++++++++++++++- src/pkg/dns/resolver.go | 20 +++++++++++++++---- src/pkg/dns/resolver_test.go | 6 +++--- 4 files changed, 72 insertions(+), 24 deletions(-) diff --git a/src/pkg/dns/check_test.go b/src/pkg/dns/check_test.go index 904b0ac4d..46595a971 100644 --- a/src/pkg/dns/check_test.go +++ b/src/pkg/dns/check_test.go @@ -13,7 +13,7 @@ var notFound = errors.New("not found") func TestGetCNAMEInSync(t *testing.T) { t.Cleanup(func() { - ResolverAt = DirectResolverAt + resolverAt = DirectResolverAt }) notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{ @@ -27,7 +27,7 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is not found t.Run("domain not found", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return notFoundResolver } + resolverAt = func(_ string) Resolver { return notFoundResolver } _, err := getCNAMEInSync(t.Context(), "web.test.com") if err != notFound { t.Errorf("Expected NotFound error, got %v", err) @@ -36,7 +36,7 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is found but the DNS servers are not in sync t.Run("DNS servers not in sync", func(t *testing.T) { - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return foundResolver } else { @@ -51,7 +51,7 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is found and the DNS servers are in sync t.Run("DNS servers in sync", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return foundResolver } + resolverAt = func(_ string) Resolver { return foundResolver } cname, err := getCNAMEInSync(t.Context(), "web.test.com") if err != nil { t.Errorf("Expected no error, got %v", err) @@ -65,7 +65,7 @@ func TestGetCNAMEInSync(t *testing.T) { func TestGetIPInSync(t *testing.T) { t.Cleanup(func() { - ResolverAt = DirectResolverAt + resolverAt = DirectResolverAt }) notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{ @@ -83,7 +83,7 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is not found t.Run("domain not found", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return notFoundResolver } + resolverAt = func(_ string) Resolver { return notFoundResolver } _, err := getIPInSync(t.Context(), "test.com") if err != notFound { t.Errorf("Expected NotFound error, got %v", err) @@ -92,7 +92,7 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is found but the DNS servers are not in sync t.Run("DNS servers not in sync", func(t *testing.T) { - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return foundResolver } else { @@ -107,7 +107,7 @@ func TestGetIPInSync(t *testing.T) { // 2nd not in sync scenario t.Run("DNS servers not in sync with partial results", func(t *testing.T) { - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return partialFoundResolver } else { @@ -122,7 +122,7 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is found and the DNS servers are in sync t.Run("DNS servers in sync", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return foundResolver } + resolverAt = func(_ string) Resolver { return foundResolver } ips, err := getIPInSync(t.Context(), "test.com") if err != nil { t.Errorf("Expected no error, got %v", err) @@ -153,42 +153,42 @@ func TestCheckDomainDNSReady(t *testing.T) { }} resolver = hasARecordResolver - oldResolver, oldDebug := ResolverAt, term.DoDebug() + oldResolver, oldDebug := resolverAt, term.DoDebug() t.Cleanup(func() { - ResolverAt = oldResolver + resolverAt = oldResolver term.SetDebug(oldDebug) }) term.SetDebug(true) t.Run("CNAME and A records not found", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return emptyResolver } + resolverAt = func(_ string) Resolver { return emptyResolver } if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false { t.Errorf("Expected false when both CNAME and A records are missing, got true") } }) t.Run("CNAME setup correctly", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return hasCNAMEResolver } + resolverAt = func(_ string) Resolver { return hasCNAMEResolver } if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true { t.Errorf("Expected true when CNAME is setup correctly, got false") } }) t.Run("CNAME setup incorrectly", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return hasCNAMEResolver } + resolverAt = func(_ string) Resolver { return hasCNAMEResolver } if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-other-alb.domain.com"}) != false { t.Errorf("Expected false when CNAME is setup incorrectly, got true") } }) t.Run("A record setup correctly", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return hasARecordResolver } + resolverAt = func(_ string) Resolver { return hasARecordResolver } if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true { t.Errorf("Expected true when A record is setup correctly, got false") } }) t.Run("A record setup incorrectly", func(t *testing.T) { - ResolverAt = func(_ string) Resolver { return hasWrongARecordResolver } + resolverAt = func(_ string) Resolver { return hasWrongARecordResolver } if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false { t.Errorf("Expected false when A record is setup incorrectly, got true") } diff --git a/src/pkg/dns/fabric_test.go b/src/pkg/dns/fabric_test.go index e0707f802..a90b12932 100644 --- a/src/pkg/dns/fabric_test.go +++ b/src/pkg/dns/fabric_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "errors" + "sync" "testing" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" @@ -109,7 +110,7 @@ func TestFabricResolverLookupNS(t *testing.T) { func TestUseFabricResolver(t *testing.T) { t.Cleanup(func() { fabricClient = nil - ResolverAt = DirectResolverAt + resolverAt = DirectResolverAt }) m := &mockFabricClient{ipResp: &defangv1.ResolveIPAddrResponse{IpAddrs: []string{"9.9.9.9"}}} @@ -130,3 +131,38 @@ func TestUseFabricResolver(t *testing.T) { t.Errorf("ResolverAt did not return FabricResolver: %T %+v", r, r) } } + +// TestResolverAtConcurrentWithUseFabricResolver exercises the synchronization +// between UseFabricResolver (which swaps resolverAt) and ResolverAt callers. +// Run with `go test -race` — prior to the mutex-guarded ResolverAt, concurrent +// writes and reads on the package-level variable were a data race. +func TestResolverAtConcurrentWithUseFabricResolver(t *testing.T) { + t.Cleanup(func() { + fabricClient = nil + resolverAt = DirectResolverAt + }) + + m := &mockFabricClient{nsResp: &defangv1.ResolveNSResponse{Hosts: []string{"ns1.example.com."}}} + + var wg sync.WaitGroup + stop := make(chan struct{}) + for range 4 { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + _ = ResolverAt("ns.example.com") + } + } + }() + } + for range 200 { + UseFabricResolver(m) + } + close(stop) + wg.Wait() +} diff --git a/src/pkg/dns/resolver.go b/src/pkg/dns/resolver.go index 5525264bb..d9f853444 100644 --- a/src/pkg/dns/resolver.go +++ b/src/pkg/dns/resolver.go @@ -28,8 +28,7 @@ type FabricResolverClient interface { ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) } -// fabricMu guards concurrent access to fabricClient and the ResolverAt -// assignment inside UseFabricResolver. +// fabricMu guards concurrent access to fabricClient and resolverAt. var fabricMu sync.RWMutex // fabricClient is set by UseFabricResolver. When non-nil, RootResolver and @@ -43,7 +42,7 @@ func UseFabricResolver(c FabricResolverClient) { fabricMu.Lock() defer fabricMu.Unlock() fabricClient = c - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { return FabricResolver{Client: c, NSServer: nsServer} } } @@ -195,7 +194,20 @@ func DirectResolverAt(nsServer string) Resolver { return DirectResolver{NSServer: nsServer} } -var ResolverAt = DirectResolverAt +// resolverAt is the package-private function that produces a Resolver bound to +// a given nameserver. It is swapped out by UseFabricResolver. All reads must go +// through ResolverAt so they're synchronized with that write. +var resolverAt = DirectResolverAt + +// ResolverAt returns a Resolver bound to nsServer. When UseFabricResolver has +// wired in a fabric client, the returned Resolver issues remote RPCs; +// otherwise it performs direct UDP DNS queries. +func ResolverAt(nsServer string) Resolver { + fabricMu.RLock() + fn := resolverAt + fabricMu.RUnlock() + return fn(nsServer) +} var ErrNoSuchHost = &net.DNSError{Err: "no such host", IsNotFound: true} diff --git a/src/pkg/dns/resolver_test.go b/src/pkg/dns/resolver_test.go index 06eb5e7ac..243c370b5 100644 --- a/src/pkg/dns/resolver_test.go +++ b/src/pkg/dns/resolver_test.go @@ -7,11 +7,11 @@ import ( func TestFindNSServer(t *testing.T) { t.Cleanup(func() { - ResolverAt = DirectResolverAt + resolverAt = DirectResolverAt }) t.Run("NS server not exist on domain", func(t *testing.T) { - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { if strings.Contains(nsServer, "root-servers.net") { return MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil}, @@ -42,7 +42,7 @@ func TestFindNSServer(t *testing.T) { }) t.Run("NS server exist on domain (delegarted apex domain)", func(t *testing.T) { - ResolverAt = func(nsServer string) Resolver { + resolverAt = func(nsServer string) Resolver { if strings.Contains(nsServer, "root-servers.net") { return MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil}, From b6b6d9af3a244af860d87ea08207622c83c28810 Mon Sep 17 00:00:00 2001 From: Jordan Stephens Date: Fri, 1 May 2026 11:29:35 -0700 Subject: [PATCH 4/4] refactor(dns): dependency-inject fabric resolver; fix lookup regressions (#2078) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(aws): carry FabricResolverClient in ByocAws for explicit DNS routing PrepareDomainDelegation now builds a FabricResolver directly from the stored fabricClient instead of reading the global dns.ResolverAt. This makes the NS conflict check's DNS path explicit and traceable without relying on UseFabricResolver having been called first. Co-Authored-By: Claude Sonnet 4.6 * refactor(dns): thread resolverAt explicitly through check.go and FindNSServers CheckDomainDNSReady, getCNAMEInSync, getIPInSync, and FindNSServers now accept an explicit resolverAt parameter instead of reading the global. cert.go's waitForCNAME builds a FabricResolver-backed factory from its FabricClient argument and passes it through. MockFabricClient gains stub implementations for the three Resolve RPCs to prevent nil-interface panics when used in tests. Co-Authored-By: Claude Sonnet 4.6 * refactor(dns): remove global resolver state; add ResolverAt field to RootResolver UseFabricResolver, getFabricClient, fabricMu, fabricClient, and the package-level resolverAt var are all deleted. RootResolver gains an explicit ResolverAt field (nil falls back to DirectResolverAt), and FindNSServers takes a resolverAt parameter so callers control which resolver is used when walking the NS delegation chain. connect.go no longer calls UseFabricResolver. fabric_test.go drops the tests for the now-deleted functions. Co-Authored-By: Claude Sonnet 4.6 * fix(dns): restore fabric DNS routing for all lookup paths ALB/CNAME IP lookups in CheckDomainDNSReady, the HTTP client used for cert generation, and CheckTLSCert were still using direct DNS after the DI refactor. Thread the fabric resolver through all three paths so no lookup silently falls back to direct UDP DNS. - DirectResolverAt("") now returns RootResolver{} as the "root" fallback, enabling resolverAt("") as a uniform way to request recursive resolution regardless of whether fabric or direct DNS is in use - CheckDomainDNSReady uses resolverAt("") for ALB/CNAME IP lookups instead of a package-level RootResolver{} - cert.CheckTLSCert accepts a dns.Resolver parameter - cli/cert: replace global resolver/httpClient vars with newCertHTTPClient(r dns.Resolver); thread resolver through generateCert → triggerCertGeneration / waitForTLS / CheckTLSCert Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- src/cmd/cli/command/cd.go | 2 +- src/pkg/cert/check.go | 3 +- src/pkg/cli/cert.go | 44 ++++++++------ src/pkg/cli/cert_test.go | 44 ++++---------- src/pkg/cli/client/byoc/aws/byoc.go | 13 +++-- src/pkg/cli/client/byoc/aws/byoc_test.go | 2 +- src/pkg/cli/client/mock.go | 12 ++++ src/pkg/cli/connect.go | 7 +-- src/pkg/dns/check.go | 25 ++++---- src/pkg/dns/check_test.go | 52 +++++------------ src/pkg/dns/fabric_test.go | 61 ------------------- src/pkg/dns/resolver.go | 74 +++++++----------------- src/pkg/dns/resolver_test.go | 12 ++-- 13 files changed, 115 insertions(+), 236 deletions(-) diff --git a/src/cmd/cli/command/cd.go b/src/cmd/cli/command/cd.go index 6edde9cb7..d3a0f0109 100644 --- a/src/cmd/cli/command/cd.go +++ b/src/cmd/cli/command/cd.go @@ -211,7 +211,7 @@ var cdCloudformationCmd = &cobra.Command{ Args: cobra.NoArgs, Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { - provider := aws.NewByocProvider(cmd.Context(), global.Client.GetTenantName(), global.Stack.Name) + provider := aws.NewByocProvider(cmd.Context(), global.Client.GetTenantName(), global.Stack.Name, global.Client) if err := canIUseProvider(cmd.Context(), provider, "", 0, false); err != nil { return err diff --git a/src/pkg/cert/check.go b/src/pkg/cert/check.go index c00e79b4e..2b28e0a22 100644 --- a/src/pkg/cert/check.go +++ b/src/pkg/cert/check.go @@ -9,8 +9,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/dns" ) -func CheckTLSCert(ctx context.Context, domain string) error { - resolver := dns.RootResolver{} +func CheckTLSCert(ctx context.Context, domain string, resolver dns.Resolver) error { ips, err := resolver.LookupIPAddr(ctx, domain) if err != nil { return err diff --git a/src/pkg/cli/cert.go b/src/pkg/cli/cert.go index 954958663..5b072f1ac 100644 --- a/src/pkg/cli/cert.go +++ b/src/pkg/cli/cert.go @@ -32,10 +32,14 @@ type DNSResult struct { } var ( - resolver dns.Resolver = dns.RootResolver{} - dnsCache = make(map[string]DNSResult) - dnsCacheDuration = 1 * time.Minute - httpClient HTTPClient = &http.Client{ + dnsCache = make(map[string]DNSResult) + dnsCacheDuration = 1 * time.Minute + + httpRetryDelayBase = 5 * time.Second +) + +func newCertHTTPClient(r dns.Resolver) HTTPClient { + return &http.Client{ // Based on the default transport: https://pkg.go.dev/net/http#RoundTripper Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -49,7 +53,7 @@ var ( if ok && cached.Expiry.After(time.Now()) { ips = cached.IPs } else { - ips, err = resolver.LookupIPAddr(ctx, host) + ips, err = r.LookupIPAddr(ctx, host) if err != nil { return nil, err } @@ -73,8 +77,7 @@ var ( return nil }, } - httpRetryDelayBase = 5 * time.Second -) +} func GenerateLetsEncryptCert(ctx context.Context, project *compose.Project, client client.FabricClient, provider client.Provider) error { term.Debugf("Generating TLS cert for project %q", project.Name) @@ -105,7 +108,7 @@ func GenerateLetsEncryptCert(ctx context.Context, project *compose.Project, clie } term.Debugf("Found service %v with domains %v and targets %v", service.Name, domains, targets) for _, domain := range domains { - generateCert(ctx, domain, targets, client) + generateCert(ctx, domain, targets, client, dns.FabricResolver{Client: client}) } } } @@ -131,7 +134,7 @@ func getDomainTargets(serviceInfo *defangv1.ServiceInfo, service compose.Service } } -func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient) { +func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient, r dns.Resolver) { term.Infof("Checking DNS setup for %v", domain) if err := waitForCNAME(ctx, domain, targets, client); err != nil { term.Errorf("Error waiting for CNAME: %v", err) @@ -139,7 +142,7 @@ func generateCert(ctx context.Context, domain string, targets []string, client c } term.Infof("%v DNS is properly configured!", domain) - if err := cert.CheckTLSCert(ctx, domain); err == nil { + if err := cert.CheckTLSCert(ctx, domain, r); err == nil { term.Infof("TLS cert for %v is already ready", domain) return } @@ -148,13 +151,13 @@ func generateCert(ctx context.Context, domain string, targets []string, client c return } term.Infof("Triggering cert generation for %v", domain) - if err := triggerCertGeneration(ctx, domain); err != nil { + if err := triggerCertGeneration(ctx, domain, r); err != nil { term.Errorf("Error triggering cert generation, please try again") return } term.Infof("Waiting for TLS cert to be online for %v, this could take a few minutes", domain) - if err := waitForTLS(ctx, domain); err != nil { + if err := waitForTLS(ctx, domain, r); err != nil { term.Errorf("Error waiting for TLS to be online: %v", err) // FIXME: Add more info on how to debug, possibly provided by the server side to avoid client type detection here return @@ -163,7 +166,7 @@ func generateCert(ctx context.Context, domain string, targets []string, client c term.Infof("TLS cert for %v is ready\n", domain) } -func triggerCertGeneration(ctx context.Context, domain string) error { +func triggerCertGeneration(ctx context.Context, domain string, r dns.Resolver) error { doSpinner := term.StdoutCanColor() && term.IsTerminal() if doSpinner { term.HideCursor() @@ -174,7 +177,7 @@ func triggerCertGeneration(ctx context.Context, domain string) error { defer cancelSpinner() } // Our own retry logic uses the root resolver to prevent cached DNS and retry on all non-200 errors - if err := getWithRetries(ctx, fmt.Sprintf("http://%v", domain), 5); err != nil { // Retry incase of DNS error + if err := getWithRetries(ctx, fmt.Sprintf("http://%v", domain), 5, newCertHTTPClient(r)); err != nil { // Retry incase of DNS error // Ignore possible tls error as cert attachment may take time term.Debugf("Error triggering cert generation: %v", err) return err @@ -182,7 +185,7 @@ func triggerCertGeneration(ctx context.Context, domain string) error { return nil } -func waitForTLS(ctx context.Context, domain string) error { +func waitForTLS(ctx context.Context, domain string, r dns.Resolver) error { ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() timeout, cancel := context.WithTimeout(ctx, 10*time.Minute) @@ -202,7 +205,7 @@ func waitForTLS(ctx context.Context, domain string) error { case <-timeout.Done(): return timeout.Err() case <-ticker.C: - if err := cert.CheckTLSCert(timeout, domain); err == nil { + if err := cert.CheckTLSCert(timeout, domain, r); err == nil { return nil } else { term.Debugf("Error checking TLS cert for %v: %v", domain, err) @@ -249,7 +252,10 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c } } if serverSideVerified || serverVerifyRpcFailure >= 3 { - locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets) + fabricResolverAt := func(nsServer string) dns.Resolver { + return dns.FabricResolver{Client: client, NSServer: nsServer} + } + locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets, fabricResolverAt) if serverSideVerified && !locallyVerified { term.Warnf("DNS settings for %v are verified, but changes may take a few minutes to propagate due to caching.", domain) return nil @@ -280,14 +286,14 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c } } -func getWithRetries(ctx context.Context, url string, tries int) error { +func getWithRetries(ctx context.Context, url string, tries int, c HTTPClient) error { var errs []error for i := range make([]struct{}, tries) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return err // No point retrying if we can't even create the request } - resp, err := httpClient.Do(req) + resp, err := c.Do(req) if err == nil { defer resp.Body.Close() _, err = io.ReadAll(resp.Body) // Read the body to ensure the request is not swallowed by alb diff --git a/src/pkg/cli/cert_test.go b/src/pkg/cli/cert_test.go index c513d5966..917141bae 100644 --- a/src/pkg/cli/cert_test.go +++ b/src/pkg/cli/cert_test.go @@ -59,10 +59,7 @@ func TestGetWithRetries(t *testing.T) { tc := &testClient{tries: []tryResult{ {result: &http.Response{StatusCode: 200, Body: mockBody("")}, err: nil}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -76,10 +73,7 @@ func TestGetWithRetries(t *testing.T) { {result: nil, err: errors.New("error")}, {result: &http.Response{StatusCode: 200, Body: mockBody("")}, err: nil}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -93,10 +87,7 @@ func TestGetWithRetries(t *testing.T) { {result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil}, {result: nil, err: errors.New("error")}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err == nil { t.Errorf("Expected error, got %v", err) } else if !strings.Contains(err.Error(), "HTTP: 503") { @@ -111,10 +102,7 @@ func TestGetWithRetries(t *testing.T) { tc := &testClient{tries: []tryResult{ {result: &http.Response{StatusCode: 503, Request: &http.Request{URL: redirectURL}, Body: mockBody("Random Error")}, err: nil}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -126,10 +114,7 @@ func TestGetWithRetries(t *testing.T) { tc := &testClient{tries: []tryResult{ {result: nil, err: &tls.CertificateVerificationError{Err: errors.New("error")}}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -143,10 +128,7 @@ func TestGetWithRetries(t *testing.T) { {result: &http.Response{StatusCode: 502, Body: mockBody("Random Error")}, err: nil}, {result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 3) + err := getWithRetries(t.Context(), "http://example.com", 3, tc) if err == nil { t.Errorf("Expected error, got %v", err) } else if !strings.Contains(err.Error(), "HTTP: 404") || !strings.Contains(err.Error(), "HTTP: 502") || !strings.Contains(err.Error(), "HTTP: 503") { @@ -162,10 +144,7 @@ func TestGetWithRetries(t *testing.T) { {result: &http.Response{StatusCode: 502, Body: mockBody("Random Error")}, err: nil}, {result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil}, }} - originalClient := httpClient - t.Cleanup(func() { httpClient = originalClient }) - httpClient = tc - err := getWithRetries(t.Context(), "http://example.com", 1) + err := getWithRetries(t.Context(), "http://example.com", 1, tc) if err == nil { t.Errorf("Expected error, got %v", err) } @@ -198,9 +177,10 @@ func TestHttpClient(t *testing.T) { })) defer ts.Close() var mr MockResolver - resolver = &mr dnsCacheDuration = 50 * time.Millisecond + tc := newCertHTTPClient(&mr) + tsu, err := url.Parse(ts.URL) if err != nil { t.Fatalf("could not parse test server url '%v': %v", ts.URL, err) @@ -214,7 +194,7 @@ func TestHttpClient(t *testing.T) { t.Fatalf("failed to create request: %v", err) } - resp, err := httpClient.Do(req) + resp, err := tc.Do(req) if err != nil { t.Fatalf("failed to make http call: %v", err) } @@ -224,7 +204,7 @@ func TestHttpClient(t *testing.T) { t.Fatalf("expected 1 dns lookup, but got %v", mr.calls) } - resp, err = httpClient.Do(req) + resp, err = tc.Do(req) if err != nil { t.Fatalf("failed to make http call: %v", err) } @@ -234,7 +214,7 @@ func TestHttpClient(t *testing.T) { } time.Sleep(80 * time.Millisecond) - resp, err = httpClient.Do(req) + resp, err = tc.Do(req) if err != nil { t.Fatalf("failed to make http call: %v", err) } diff --git a/src/pkg/cli/client/byoc/aws/byoc.go b/src/pkg/cli/client/byoc/aws/byoc.go index babf9950b..138559962 100644 --- a/src/pkg/cli/client/byoc/aws/byoc.go +++ b/src/pkg/cli/client/byoc/aws/byoc.go @@ -52,7 +52,8 @@ type Config = awssdk.Config type ByocAws struct { *byoc.ByocBaseClient - driver *cfn.AwsCfn // TODO: ecs is stateful, contains the output of the cd cfn stack after SetUpCD + driver *cfn.AwsCfn // TODO: ecs is stateful, contains the output of the cd cfn stack after SetUpCD + fabricClient dns.FabricResolverClient cdEtag types.ETag cdStart time.Time @@ -114,7 +115,7 @@ func AnnotateAwsError(err error) error { return err } -func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack string) *ByocAws { +func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack string, fabricClient dns.FabricResolverClient) *ByocAws { if awsProfileName := os.Getenv("AWS_PROFILE"); awsProfileName != "" { AWSAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID") AWSSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY") @@ -129,7 +130,8 @@ func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack st } b := &ByocAws{ - driver: cfn.New(byoc.CdTaskPrefix, aws.Region("")), // default region + driver: cfn.New(byoc.CdTaskPrefix, aws.Region("")), // default region + fabricClient: fabricClient, } b.ByocBaseClient = byoc.NewByocBaseClient(tenantName, b, stack) @@ -425,7 +427,10 @@ func (b *ByocAws) PrepareDomainDelegation(ctx context.Context, req client.Prepar r53Client := route53.NewFromConfig(cfg) projectDomain := req.DelegateDomain - nsServers, delegationSetId, err := prepareDomainDelegation(ctx, projectDomain, req.Project, b.PulumiStack, r53Client, dns.ResolverAt) + resolverAt := func(nsServer string) dns.Resolver { + return dns.FabricResolver{Client: b.fabricClient, NSServer: nsServer} + } + nsServers, delegationSetId, err := prepareDomainDelegation(ctx, projectDomain, req.Project, b.PulumiStack, r53Client, resolverAt) if err != nil { return nil, AnnotateAwsError(err) } diff --git a/src/pkg/cli/client/byoc/aws/byoc_test.go b/src/pkg/cli/client/byoc/aws/byoc_test.go index 11a3c7e6e..0c436f5d4 100644 --- a/src/pkg/cli/client/byoc/aws/byoc_test.go +++ b/src/pkg/cli/client/byoc/aws/byoc_test.go @@ -379,7 +379,7 @@ aws_secret_access_key = wJalrXUtnFEMI/KDEFANG/bPxRfiCYEXAMPLEKEY ctx := t.Context() // Create ByocAws instance - warning is printed here in NewByocProvider - b := NewByocProvider(ctx, "tenant1", "exampleStack") + b := NewByocProvider(ctx, "tenant1", "exampleStack", nil) _, err := b.AccountInfo(ctx) if tt.expectedError { diff --git a/src/pkg/cli/client/mock.go b/src/pkg/cli/client/mock.go index 048e1862f..2591b511d 100644 --- a/src/pkg/cli/client/mock.go +++ b/src/pkg/cli/client/mock.go @@ -272,6 +272,18 @@ func (m MockFabricClient) GetDefaultStack(context.Context, *defangv1.GetDefaultS }, nil } +func (m MockFabricClient) ResolveIPAddr(_ context.Context, _ *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) { + return &defangv1.ResolveIPAddrResponse{}, nil +} + +func (m MockFabricClient) ResolveCNAME(_ context.Context, _ *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) { + return &defangv1.ResolveCNAMEResponse{}, nil +} + +func (m MockFabricClient) ResolveNS(_ context.Context, _ *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) { + return &defangv1.ResolveNSResponse{}, nil +} + type MockLoader struct { Project composeTypes.Project Error error diff --git a/src/pkg/cli/connect.go b/src/pkg/cli/connect.go index 6d9b2cb32..e5e7fb898 100644 --- a/src/pkg/cli/connect.go +++ b/src/pkg/cli/connect.go @@ -7,7 +7,6 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/do" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/gcp" - "github.com/DefangLabs/defang/src/pkg/dns" "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/types" ) @@ -18,9 +17,7 @@ func Connect(fabricAddr string, requestedTenant types.TenantNameOrID) *client.Gr term.Debugf("Using tenant %q for cluster %q", requestedTenant, host) accessToken := client.GetExistingToken(host) - grpcClient := client.NewGrpcClient(host, accessToken, requestedTenant) - dns.UseFabricResolver(grpcClient) - return grpcClient + return client.NewGrpcClient(host, accessToken, requestedTenant) } func ConnectWithTenant(ctx context.Context, fabricAddr string, requestedTenant types.TenantNameOrID) (*client.GrpcClient, error) { @@ -41,7 +38,7 @@ func NewProvider(ctx context.Context, providerID client.ProviderID, fabricClient term.Debugf("Creating %s provider", providerID) switch providerID { case client.ProviderAWS: - provider = aws.NewByocProvider(ctx, fabricClient.GetTenantName(), stack) + provider = aws.NewByocProvider(ctx, fabricClient.GetTenantName(), stack, fabricClient) case client.ProviderDO: provider = do.NewByocProvider(ctx, fabricClient.GetTenantName(), stack) case client.ProviderGCP: diff --git a/src/pkg/dns/check.go b/src/pkg/dns/check.go index 839e26eb6..13c42ecc2 100644 --- a/src/pkg/dns/check.go +++ b/src/pkg/dns/check.go @@ -17,19 +17,18 @@ type logger interface { } var ( - Logger logger = term.DefaultTerm - resolver Resolver = RootResolver{} + Logger logger = term.DefaultTerm errDNSNotInSync = errors.New("DNS not in sync") ) // The DNS is considered ready if the CNAME of the domain is pointing to the ALB domain and in sync // OR if the A record of the domain is pointing to the same IP addresses of the ALB domain and in sync -func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []string) bool { +func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []string, resolverAt func(string) Resolver) bool { for i, validCNAME := range validCNAMEs { validCNAMEs[i] = Normalize(validCNAME) } - cname, err := getCNAMEInSync(ctx, domain) + cname, err := getCNAMEInSync(ctx, domain, resolverAt) Logger.Debugf("CNAME for %v is: '%v', err: %v", domain, cname, err) // Ignore other types of DNS errors if err == errDNSNotInSync { @@ -42,7 +41,7 @@ func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []strin return true } - albIPAddrs, err := resolver.LookupIPAddr(ctx, validCNAMEs[0]) + albIPAddrs, err := resolverAt("").LookupIPAddr(ctx, validCNAMEs[0]) if err != nil { Logger.Debugf("Could not resolve A/AAAA record for load balancer %v: %v", validCNAMEs[0], err) return false @@ -52,7 +51,7 @@ func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []strin // In sync CNAME may be pointing to the same IP addresses of the load balancer, considered as valid Logger.Debugf("Checking CNAME %v", cname) if cname != "" { - cnameIPAddrs, err := resolver.LookupIPAddr(ctx, cname) + cnameIPAddrs, err := resolverAt("").LookupIPAddr(ctx, cname) if err != nil { Logger.Debugf("Could not resolve A/AAAA record for %v: %v", cname, err) } else { @@ -66,7 +65,7 @@ func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []strin } // Check if an valid A record has been set - ips, err := getIPInSync(ctx, domain) + ips, err := getIPInSync(ctx, domain, resolverAt) if err != nil { Logger.Debugf("IP for %v not in sync: %v", domain, err) return false @@ -78,8 +77,8 @@ func CheckDomainDNSReady(ctx context.Context, domain string, validCNAMEs []strin return false } -func getCNAMEInSync(ctx context.Context, domain string) (string, error) { - ns, err := FindNSServers(ctx, domain) +func getCNAMEInSync(ctx context.Context, domain string, resolverAt func(string) Resolver) (string, error) { + ns, err := FindNSServers(ctx, domain, resolverAt) if err != nil { return "", err } @@ -88,7 +87,7 @@ func getCNAMEInSync(ctx context.Context, domain string) (string, error) { var cname string var lookupErr error for _, n := range ns { - cname, err = ResolverAt(n.Host).LookupCNAME(ctx, domain) + cname, err = resolverAt(n.Host).LookupCNAME(ctx, domain) if err != nil { Logger.Debugf("Error looking up CNAME for %v at %v: %v", domain, n, err) lookupErr = err @@ -102,8 +101,8 @@ func getCNAMEInSync(ctx context.Context, domain string) (string, error) { return cname, lookupErr } -func getIPInSync(ctx context.Context, domain string) ([]net.IP, error) { - ns, err := FindNSServers(ctx, domain) +func getIPInSync(ctx context.Context, domain string, resolverAt func(string) Resolver) ([]net.IP, error) { + ns, err := FindNSServers(ctx, domain, resolverAt) if err != nil { return nil, err } @@ -112,7 +111,7 @@ func getIPInSync(ctx context.Context, domain string) ([]net.IP, error) { var lookupErr error for i, n := range ns { var ipAddrs []net.IPAddr - ipAddrs, err = ResolverAt(n.Host).LookupIPAddr(ctx, domain) + ipAddrs, err = resolverAt(n.Host).LookupIPAddr(ctx, domain) if err != nil { Logger.Debugf("Error looking up IP for %v at %v: %v", domain, n, err) lookupErr = err diff --git a/src/pkg/dns/check_test.go b/src/pkg/dns/check_test.go index 46595a971..c308290e7 100644 --- a/src/pkg/dns/check_test.go +++ b/src/pkg/dns/check_test.go @@ -12,10 +12,6 @@ import ( var notFound = errors.New("not found") func TestGetCNAMEInSync(t *testing.T) { - t.Cleanup(func() { - resolverAt = DirectResolverAt - }) - notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "web.test.com"}: {Records: []string{"ns1.example.com", "ns2.example.com"}, Error: nil}, {Type: "CNAME", Domain: "web.test.com"}: {Records: nil, Error: notFound}, @@ -27,8 +23,7 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is not found t.Run("domain not found", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return notFoundResolver } - _, err := getCNAMEInSync(t.Context(), "web.test.com") + _, err := getCNAMEInSync(t.Context(), "web.test.com", func(_ string) Resolver { return notFoundResolver }) if err != notFound { t.Errorf("Expected NotFound error, got %v", err) } @@ -36,14 +31,14 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is found but the DNS servers are not in sync t.Run("DNS servers not in sync", func(t *testing.T) { - resolverAt = func(nsServer string) Resolver { + resolverAt := func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return foundResolver } else { return notFoundResolver } } - _, err := getCNAMEInSync(t.Context(), "web.test.com") + _, err := getCNAMEInSync(t.Context(), "web.test.com", resolverAt) if err != errDNSNotInSync { t.Errorf("Expected NotInSync error, got %v", err) } @@ -51,8 +46,7 @@ func TestGetCNAMEInSync(t *testing.T) { // Test when the domain is found and the DNS servers are in sync t.Run("DNS servers in sync", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return foundResolver } - cname, err := getCNAMEInSync(t.Context(), "web.test.com") + cname, err := getCNAMEInSync(t.Context(), "web.test.com", func(_ string) Resolver { return foundResolver }) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -64,10 +58,6 @@ func TestGetCNAMEInSync(t *testing.T) { } func TestGetIPInSync(t *testing.T) { - t.Cleanup(func() { - resolverAt = DirectResolverAt - }) - notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "test.com"}: {Records: []string{"ns1.example.com", "ns2.example.com"}, Error: nil}, {Type: "A", Domain: "test.com"}: {Records: nil, Error: notFound}, @@ -83,8 +73,7 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is not found t.Run("domain not found", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return notFoundResolver } - _, err := getIPInSync(t.Context(), "test.com") + _, err := getIPInSync(t.Context(), "test.com", func(_ string) Resolver { return notFoundResolver }) if err != notFound { t.Errorf("Expected NotFound error, got %v", err) } @@ -92,14 +81,14 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is found but the DNS servers are not in sync t.Run("DNS servers not in sync", func(t *testing.T) { - resolverAt = func(nsServer string) Resolver { + resolverAt := func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return foundResolver } else { return notFoundResolver } } - _, err := getIPInSync(t.Context(), "test.com") + _, err := getIPInSync(t.Context(), "test.com", resolverAt) if err != errDNSNotInSync { t.Errorf("Expected NotInSyncError error, got %v", err) } @@ -107,14 +96,14 @@ func TestGetIPInSync(t *testing.T) { // 2nd not in sync scenario t.Run("DNS servers not in sync with partial results", func(t *testing.T) { - resolverAt = func(nsServer string) Resolver { + resolverAt := func(nsServer string) Resolver { if nsServer == "ns1.example.com" { return partialFoundResolver } else { return foundResolver } } - _, err := getIPInSync(t.Context(), "test.com") + _, err := getIPInSync(t.Context(), "test.com", resolverAt) if err != errDNSNotInSync { t.Errorf("Expected NotInSyncError error, got %v", err) } @@ -122,8 +111,7 @@ func TestGetIPInSync(t *testing.T) { // Test when the domain is found and the DNS servers are in sync t.Run("DNS servers in sync", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return foundResolver } - ips, err := getIPInSync(t.Context(), "test.com") + ips, err := getIPInSync(t.Context(), "test.com", func(_ string) Resolver { return foundResolver }) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -151,45 +139,37 @@ func TestCheckDomainDNSReady(t *testing.T) { {Type: "A", Domain: "some-alb.domain.com"}: {Records: []string{"1.2.3.4", "5,6,7,8"}, Error: nil}, {Type: "CNAME", Domain: "api.test.com"}: {Records: []string{"some-alb.domain.com"}, Error: nil}, }} - resolver = hasARecordResolver - - oldResolver, oldDebug := resolverAt, term.DoDebug() + oldDebug := term.DoDebug() t.Cleanup(func() { - resolverAt = oldResolver term.SetDebug(oldDebug) }) term.SetDebug(true) t.Run("CNAME and A records not found", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return emptyResolver } - if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false { + if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}, func(_ string) Resolver { return emptyResolver }) != false { t.Errorf("Expected false when both CNAME and A records are missing, got true") } }) t.Run("CNAME setup correctly", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return hasCNAMEResolver } - if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true { + if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}, func(_ string) Resolver { return hasCNAMEResolver }) != true { t.Errorf("Expected true when CNAME is setup correctly, got false") } }) t.Run("CNAME setup incorrectly", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return hasCNAMEResolver } - if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-other-alb.domain.com"}) != false { + if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-other-alb.domain.com"}, func(_ string) Resolver { return hasCNAMEResolver }) != false { t.Errorf("Expected false when CNAME is setup incorrectly, got true") } }) t.Run("A record setup correctly", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return hasARecordResolver } - if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true { + if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}, func(_ string) Resolver { return hasARecordResolver }) != true { t.Errorf("Expected true when A record is setup correctly, got false") } }) t.Run("A record setup incorrectly", func(t *testing.T) { - resolverAt = func(_ string) Resolver { return hasWrongARecordResolver } - if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false { + if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}, func(_ string) Resolver { return hasWrongARecordResolver }) != false { t.Errorf("Expected false when A record is setup incorrectly, got true") } }) diff --git a/src/pkg/dns/fabric_test.go b/src/pkg/dns/fabric_test.go index a90b12932..2b6b1b62f 100644 --- a/src/pkg/dns/fabric_test.go +++ b/src/pkg/dns/fabric_test.go @@ -3,7 +3,6 @@ package dns import ( "context" "errors" - "sync" "testing" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" @@ -106,63 +105,3 @@ func TestFabricResolverLookupNS(t *testing.T) { t.Errorf("unexpected NS result: %+v", ns) } } - -func TestUseFabricResolver(t *testing.T) { - t.Cleanup(func() { - fabricClient = nil - resolverAt = DirectResolverAt - }) - - m := &mockFabricClient{ipResp: &defangv1.ResolveIPAddrResponse{IpAddrs: []string{"9.9.9.9"}}} - UseFabricResolver(m) - - // RootResolver should now delegate to FabricResolver. - ips, err := RootResolver{}.LookupIPAddr(t.Context(), "example.com") - if err != nil { - t.Fatalf("RootResolver.LookupIPAddr: %v", err) - } - if len(ips) != 1 || ips[0].IP.String() != "9.9.9.9" { - t.Errorf("unexpected IPs: %v", ips) - } - - // ResolverAt should return a FabricResolver bound to the NS. - r := ResolverAt("ns1.example.com") - if fr, ok := r.(FabricResolver); !ok || fr.NSServer != "ns1.example.com" { - t.Errorf("ResolverAt did not return FabricResolver: %T %+v", r, r) - } -} - -// TestResolverAtConcurrentWithUseFabricResolver exercises the synchronization -// between UseFabricResolver (which swaps resolverAt) and ResolverAt callers. -// Run with `go test -race` — prior to the mutex-guarded ResolverAt, concurrent -// writes and reads on the package-level variable were a data race. -func TestResolverAtConcurrentWithUseFabricResolver(t *testing.T) { - t.Cleanup(func() { - fabricClient = nil - resolverAt = DirectResolverAt - }) - - m := &mockFabricClient{nsResp: &defangv1.ResolveNSResponse{Hosts: []string{"ns1.example.com."}}} - - var wg sync.WaitGroup - stop := make(chan struct{}) - for range 4 { - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-stop: - return - default: - _ = ResolverAt("ns.example.com") - } - } - }() - } - for range 200 { - UseFabricResolver(m) - } - close(stop) - wg.Wait() -} diff --git a/src/pkg/dns/resolver.go b/src/pkg/dns/resolver.go index d9f853444..297e44168 100644 --- a/src/pkg/dns/resolver.go +++ b/src/pkg/dns/resolver.go @@ -7,7 +7,6 @@ import ( "net" "slices" "sort" - "sync" "github.com/DefangLabs/defang/src/pkg" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" @@ -28,31 +27,6 @@ type FabricResolverClient interface { ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) } -// fabricMu guards concurrent access to fabricClient and resolverAt. -var fabricMu sync.RWMutex - -// fabricClient is set by UseFabricResolver. When non-nil, RootResolver and -// ResolverAt route DNS lookups through the fabric gRPC API. -var fabricClient FabricResolverClient - -// UseFabricResolver wires DNS lookups through the fabric gRPC API. After it is -// called, RootResolver{} and ResolverAt(nsServer) both issue remote RPCs -// instead of performing direct UDP DNS queries. -func UseFabricResolver(c FabricResolverClient) { - fabricMu.Lock() - defer fabricMu.Unlock() - fabricClient = c - resolverAt = func(nsServer string) Resolver { - return FabricResolver{Client: c, NSServer: nsServer} - } -} - -func getFabricClient() FabricResolverClient { - fabricMu.RLock() - defer fabricMu.RUnlock() - return fabricClient -} - // FabricResolver performs DNS lookups via the fabric gRPC API. An empty // NSServer lets the server perform recursive resolution from the root. type FabricResolver struct { @@ -109,7 +83,20 @@ func (r FabricResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, return nss, nil } -type RootResolver struct{} +// RootResolver performs recursive DNS resolution starting from the root +// nameservers. Set ResolverAt to override how individual nameservers are +// queried (e.g. to route through the Fabric gRPC API). A nil ResolverAt +// falls back to DirectResolverAt. +type RootResolver struct { + ResolverAt func(string) Resolver +} + +func (r RootResolver) resolverFn() func(string) Resolver { + if r.ResolverAt != nil { + return r.ResolverAt + } + return DirectResolverAt +} // https://en.wikipedia.org/wiki/Root_name_server var rootServers = []*net.NS{ @@ -129,9 +116,6 @@ var rootServers = []*net.NS{ } func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { - if c := getFabricClient(); c != nil { - return FabricResolver{Client: c}.LookupIPAddr(ctx, domain) - } for range 10 { ips, err := r.getResolver(ctx, domain).LookupIPAddr(ctx, domain) if err != nil { @@ -148,34 +132,28 @@ func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IP } func (r RootResolver) LookupCNAME(ctx context.Context, domain string) (string, error) { - if c := getFabricClient(); c != nil { - return FabricResolver{Client: c}.LookupCNAME(ctx, domain) - } return r.getResolver(ctx, domain).LookupCNAME(ctx, domain) } func (r RootResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { - if c := getFabricClient(); c != nil { - return FabricResolver{Client: c}.LookupNS(ctx, domain) - } return r.getResolver(ctx, domain).LookupNS(ctx, domain) } func (r RootResolver) getResolver(ctx context.Context, domain string) Resolver { - ns, err := FindNSServers(ctx, domain) + ns, err := FindNSServers(ctx, domain, r.resolverFn()) if err != nil { return DirectResolver{} } return DirectResolver{NSServer: ns[pkg.RandomIndex(len(ns))].Host} } -func FindNSServers(ctx context.Context, domain string) ([]*net.NS, error) { +func FindNSServers(ctx context.Context, domain string, resolverAt func(string) Resolver) ([]*net.NS, error) { nsServers := rootServers retries := 3 for { index := pkg.RandomIndex(len(nsServers)) nsServer := nsServers[index].Host - ns, err := ResolverAt(nsServer).LookupNS(ctx, domain) + ns, err := resolverAt(nsServer).LookupNS(ctx, domain) sort.Slice(ns, func(i, j int) bool { return ns[i].Host < ns[j].Host }) if err != nil { if retries--; retries > 0 { @@ -191,24 +169,12 @@ func FindNSServers(ctx context.Context, domain string) ([]*net.NS, error) { } func DirectResolverAt(nsServer string) Resolver { + if nsServer == "" { + return RootResolver{} + } return DirectResolver{NSServer: nsServer} } -// resolverAt is the package-private function that produces a Resolver bound to -// a given nameserver. It is swapped out by UseFabricResolver. All reads must go -// through ResolverAt so they're synchronized with that write. -var resolverAt = DirectResolverAt - -// ResolverAt returns a Resolver bound to nsServer. When UseFabricResolver has -// wired in a fabric client, the returned Resolver issues remote RPCs; -// otherwise it performs direct UDP DNS queries. -func ResolverAt(nsServer string) Resolver { - fabricMu.RLock() - fn := resolverAt - fabricMu.RUnlock() - return fn(nsServer) -} - var ErrNoSuchHost = &net.DNSError{Err: "no such host", IsNotFound: true} type ErrCNAMEFound string diff --git a/src/pkg/dns/resolver_test.go b/src/pkg/dns/resolver_test.go index 243c370b5..0dd4a3c20 100644 --- a/src/pkg/dns/resolver_test.go +++ b/src/pkg/dns/resolver_test.go @@ -6,12 +6,8 @@ import ( ) func TestFindNSServer(t *testing.T) { - t.Cleanup(func() { - resolverAt = DirectResolverAt - }) - t.Run("NS server not exist on domain", func(t *testing.T) { - resolverAt = func(nsServer string) Resolver { + resolverAt := func(nsServer string) Resolver { if strings.Contains(nsServer, "root-servers.net") { return MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil}, @@ -29,7 +25,7 @@ func TestFindNSServer(t *testing.T) { return nil } - ns, err := FindNSServers(t.Context(), "a.b.c.d") + ns, err := FindNSServers(t.Context(), "a.b.c.d", resolverAt) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -42,7 +38,7 @@ func TestFindNSServer(t *testing.T) { }) t.Run("NS server exist on domain (delegarted apex domain)", func(t *testing.T) { - resolverAt = func(nsServer string) Resolver { + resolverAt := func(nsServer string) Resolver { if strings.Contains(nsServer, "root-servers.net") { return MockResolver{Records: map[DNSRequest]DNSResponse{ {Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil}, @@ -64,7 +60,7 @@ func TestFindNSServer(t *testing.T) { return nil } - ns, err := FindNSServers(t.Context(), "a.b.c.d") + ns, err := FindNSServers(t.Context(), "a.b.c.d", resolverAt) if err != nil { t.Errorf("Expected no error, got %v", err) }