Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cmd/cli/command/cd.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ var cdCloudformationCmd = &cobra.Command{
Args: cobra.NoArgs,
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
provider := aws.NewByocProvider(cmd.Context(), global.Client.GetTenantName(), global.Stack.Name)
provider := aws.NewByocProvider(cmd.Context(), global.Client.GetTenantName(), global.Stack.Name, global.Client)

if err := canIUseProvider(cmd.Context(), provider, "", 0, false); err != nil {
return err
Expand Down
3 changes: 1 addition & 2 deletions src/pkg/cert/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (
"github.com/DefangLabs/defang/src/pkg/dns"
)

func CheckTLSCert(ctx context.Context, domain string) error {
resolver := dns.RootResolver{}
func CheckTLSCert(ctx context.Context, domain string, resolver dns.Resolver) error {
ips, err := resolver.LookupIPAddr(ctx, domain)
if err != nil {
return err
Expand Down
44 changes: 25 additions & 19 deletions src/pkg/cli/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ type DNSResult struct {
}

var (
resolver dns.Resolver = dns.RootResolver{}
dnsCache = make(map[string]DNSResult)
dnsCacheDuration = 1 * time.Minute
httpClient HTTPClient = &http.Client{
dnsCache = make(map[string]DNSResult)
dnsCacheDuration = 1 * time.Minute

httpRetryDelayBase = 5 * time.Second
)

func newCertHTTPClient(r dns.Resolver) HTTPClient {
return &http.Client{
// Based on the default transport: https://pkg.go.dev/net/http#RoundTripper
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Expand All @@ -49,7 +53,7 @@ var (
if ok && cached.Expiry.After(time.Now()) {
ips = cached.IPs
} else {
ips, err = resolver.LookupIPAddr(ctx, host)
ips, err = r.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
Expand All @@ -73,8 +77,7 @@ var (
return nil
},
}
httpRetryDelayBase = 5 * time.Second
)
}

func GenerateLetsEncryptCert(ctx context.Context, project *compose.Project, client client.FabricClient, provider client.Provider) error {
term.Debugf("Generating TLS cert for project %q", project.Name)
Expand Down Expand Up @@ -105,7 +108,7 @@ func GenerateLetsEncryptCert(ctx context.Context, project *compose.Project, clie
}
term.Debugf("Found service %v with domains %v and targets %v", service.Name, domains, targets)
for _, domain := range domains {
generateCert(ctx, domain, targets, client)
generateCert(ctx, domain, targets, client, dns.FabricResolver{Client: client})
}
}
}
Expand All @@ -131,15 +134,15 @@ func getDomainTargets(serviceInfo *defangv1.ServiceInfo, service compose.Service
}
}

func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient) {
func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient, r dns.Resolver) {
term.Infof("Checking DNS setup for %v", domain)
if err := waitForCNAME(ctx, domain, targets, client); err != nil {
term.Errorf("Error waiting for CNAME: %v", err)
return
}

term.Infof("%v DNS is properly configured!", domain)
if err := cert.CheckTLSCert(ctx, domain); err == nil {
if err := cert.CheckTLSCert(ctx, domain, r); err == nil {
term.Infof("TLS cert for %v is already ready", domain)
return
}
Expand All @@ -148,13 +151,13 @@ func generateCert(ctx context.Context, domain string, targets []string, client c
return
}
term.Infof("Triggering cert generation for %v", domain)
if err := triggerCertGeneration(ctx, domain); err != nil {
if err := triggerCertGeneration(ctx, domain, r); err != nil {
term.Errorf("Error triggering cert generation, please try again")
return
}

term.Infof("Waiting for TLS cert to be online for %v, this could take a few minutes", domain)
if err := waitForTLS(ctx, domain); err != nil {
if err := waitForTLS(ctx, domain, r); err != nil {
term.Errorf("Error waiting for TLS to be online: %v", err)
// FIXME: Add more info on how to debug, possibly provided by the server side to avoid client type detection here
return
Expand All @@ -163,7 +166,7 @@ func generateCert(ctx context.Context, domain string, targets []string, client c
term.Infof("TLS cert for %v is ready\n", domain)
}

func triggerCertGeneration(ctx context.Context, domain string) error {
func triggerCertGeneration(ctx context.Context, domain string, r dns.Resolver) error {
doSpinner := term.StdoutCanColor() && term.IsTerminal()
if doSpinner {
term.HideCursor()
Expand All @@ -174,15 +177,15 @@ func triggerCertGeneration(ctx context.Context, domain string) error {
defer cancelSpinner()
}
// Our own retry logic uses the root resolver to prevent cached DNS and retry on all non-200 errors
if err := getWithRetries(ctx, fmt.Sprintf("http://%v", domain), 5); err != nil { // Retry incase of DNS error
if err := getWithRetries(ctx, fmt.Sprintf("http://%v", domain), 5, newCertHTTPClient(r)); err != nil { // Retry incase of DNS error
// Ignore possible tls error as cert attachment may take time
term.Debugf("Error triggering cert generation: %v", err)
return err
}
return nil
}

func waitForTLS(ctx context.Context, domain string) error {
func waitForTLS(ctx context.Context, domain string, r dns.Resolver) error {
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
timeout, cancel := context.WithTimeout(ctx, 10*time.Minute)
Expand All @@ -202,7 +205,7 @@ func waitForTLS(ctx context.Context, domain string) error {
case <-timeout.Done():
return timeout.Err()
case <-ticker.C:
if err := cert.CheckTLSCert(timeout, domain); err == nil {
if err := cert.CheckTLSCert(timeout, domain, r); err == nil {
return nil
} else {
term.Debugf("Error checking TLS cert for %v: %v", domain, err)
Expand Down Expand Up @@ -249,7 +252,10 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c
}
}
if serverSideVerified || serverVerifyRpcFailure >= 3 {
locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets)
fabricResolverAt := func(nsServer string) dns.Resolver {
return dns.FabricResolver{Client: client, NSServer: nsServer}
}
locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets, fabricResolverAt)
if serverSideVerified && !locallyVerified {
term.Warnf("DNS settings for %v are verified, but changes may take a few minutes to propagate due to caching.", domain)
return nil
Expand Down Expand Up @@ -280,14 +286,14 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c
}
}

func getWithRetries(ctx context.Context, url string, tries int) error {
func getWithRetries(ctx context.Context, url string, tries int, c HTTPClient) error {
var errs []error
for i := range make([]struct{}, tries) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err // No point retrying if we can't even create the request
}
resp, err := httpClient.Do(req)
resp, err := c.Do(req)
if err == nil {
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body) // Read the body to ensure the request is not swallowed by alb
Expand Down
44 changes: 12 additions & 32 deletions src/pkg/cli/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ func TestGetWithRetries(t *testing.T) {
tc := &testClient{tries: []tryResult{
{result: &http.Response{StatusCode: 200, Body: mockBody("")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -76,10 +73,7 @@ func TestGetWithRetries(t *testing.T) {
{result: nil, err: errors.New("error")},
{result: &http.Response{StatusCode: 200, Body: mockBody("")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -93,10 +87,7 @@ func TestGetWithRetries(t *testing.T) {
{result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil},
{result: nil, err: errors.New("error")},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err == nil {
t.Errorf("Expected error, got %v", err)
} else if !strings.Contains(err.Error(), "HTTP: 503") {
Expand All @@ -111,10 +102,7 @@ func TestGetWithRetries(t *testing.T) {
tc := &testClient{tries: []tryResult{
{result: &http.Response{StatusCode: 503, Request: &http.Request{URL: redirectURL}, Body: mockBody("Random Error")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -126,10 +114,7 @@ func TestGetWithRetries(t *testing.T) {
tc := &testClient{tries: []tryResult{
{result: nil, err: &tls.CertificateVerificationError{Err: errors.New("error")}},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -143,10 +128,7 @@ func TestGetWithRetries(t *testing.T) {
{result: &http.Response{StatusCode: 502, Body: mockBody("Random Error")}, err: nil},
{result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 3)
err := getWithRetries(t.Context(), "http://example.com", 3, tc)
if err == nil {
t.Errorf("Expected error, got %v", err)
} else if !strings.Contains(err.Error(), "HTTP: 404") || !strings.Contains(err.Error(), "HTTP: 502") || !strings.Contains(err.Error(), "HTTP: 503") {
Expand All @@ -162,10 +144,7 @@ func TestGetWithRetries(t *testing.T) {
{result: &http.Response{StatusCode: 502, Body: mockBody("Random Error")}, err: nil},
{result: &http.Response{StatusCode: 503, Body: mockBody("Random Error")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(t.Context(), "http://example.com", 1)
err := getWithRetries(t.Context(), "http://example.com", 1, tc)
if err == nil {
t.Errorf("Expected error, got %v", err)
}
Expand Down Expand Up @@ -198,9 +177,10 @@ func TestHttpClient(t *testing.T) {
}))
defer ts.Close()
var mr MockResolver
resolver = &mr
dnsCacheDuration = 50 * time.Millisecond

tc := newCertHTTPClient(&mr)

tsu, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("could not parse test server url '%v': %v", ts.URL, err)
Expand All @@ -214,7 +194,7 @@ func TestHttpClient(t *testing.T) {
t.Fatalf("failed to create request: %v", err)
}

resp, err := httpClient.Do(req)
resp, err := tc.Do(req)
if err != nil {
t.Fatalf("failed to make http call: %v", err)
}
Expand All @@ -224,7 +204,7 @@ func TestHttpClient(t *testing.T) {
t.Fatalf("expected 1 dns lookup, but got %v", mr.calls)
}

resp, err = httpClient.Do(req)
resp, err = tc.Do(req)
if err != nil {
t.Fatalf("failed to make http call: %v", err)
}
Expand All @@ -234,7 +214,7 @@ func TestHttpClient(t *testing.T) {
}

time.Sleep(80 * time.Millisecond)
resp, err = httpClient.Do(req)
resp, err = tc.Do(req)
if err != nil {
t.Fatalf("failed to make http call: %v", err)
}
Expand Down
13 changes: 9 additions & 4 deletions src/pkg/cli/client/byoc/aws/byoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ type Config = awssdk.Config
type ByocAws struct {
*byoc.ByocBaseClient

driver *cfn.AwsCfn // TODO: ecs is stateful, contains the output of the cd cfn stack after SetUpCD
driver *cfn.AwsCfn // TODO: ecs is stateful, contains the output of the cd cfn stack after SetUpCD
fabricClient dns.FabricResolverClient

cdEtag types.ETag
cdStart time.Time
Expand Down Expand Up @@ -114,7 +115,7 @@ func AnnotateAwsError(err error) error {
return err
}

func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack string) *ByocAws {
func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack string, fabricClient dns.FabricResolverClient) *ByocAws {
if awsProfileName := os.Getenv("AWS_PROFILE"); awsProfileName != "" {
AWSAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID")
AWSSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
Expand All @@ -129,7 +130,8 @@ func NewByocProvider(ctx context.Context, tenantName types.TenantLabel, stack st
}

b := &ByocAws{
driver: cfn.New(byoc.CdTaskPrefix, aws.Region("")), // default region
driver: cfn.New(byoc.CdTaskPrefix, aws.Region("")), // default region
fabricClient: fabricClient,
}
b.ByocBaseClient = byoc.NewByocBaseClient(tenantName, b, stack)

Expand Down Expand Up @@ -425,7 +427,10 @@ func (b *ByocAws) PrepareDomainDelegation(ctx context.Context, req client.Prepar
r53Client := route53.NewFromConfig(cfg)

projectDomain := req.DelegateDomain
nsServers, delegationSetId, err := prepareDomainDelegation(ctx, projectDomain, req.Project, b.PulumiStack, r53Client, dns.ResolverAt)
resolverAt := func(nsServer string) dns.Resolver {
return dns.FabricResolver{Client: b.fabricClient, NSServer: nsServer}
}
nsServers, delegationSetId, err := prepareDomainDelegation(ctx, projectDomain, req.Project, b.PulumiStack, r53Client, resolverAt)
if err != nil {
return nil, AnnotateAwsError(err)
}
Expand Down
2 changes: 1 addition & 1 deletion src/pkg/cli/client/byoc/aws/byoc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ aws_secret_access_key = wJalrXUtnFEMI/KDEFANG/bPxRfiCYEXAMPLEKEY
ctx := t.Context()

// Create ByocAws instance - warning is printed here in NewByocProvider
b := NewByocProvider(ctx, "tenant1", "exampleStack")
b := NewByocProvider(ctx, "tenant1", "exampleStack", nil)
_, err := b.AccountInfo(ctx)

if tt.expectedError {
Expand Down
3 changes: 3 additions & 0 deletions src/pkg/cli/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type FabricClient interface {
Preview(context.Context, *defangv1.PreviewRequest) (*defangv1.PreviewResponse, error)
PutDeployment(context.Context, *defangv1.PutDeploymentRequest) error
PutStack(context.Context, *defangv1.PutStackRequest) error
ResolveCNAME(context.Context, *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error)
ResolveIPAddr(context.Context, *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error)
ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error)
RevokeToken(context.Context) error
SetOptions(context.Context, *defangv1.SetOptionsRequest) error
Token(context.Context, *defangv1.TokenRequest) (*defangv1.TokenResponse, error)
Expand Down
12 changes: 12 additions & 0 deletions src/pkg/cli/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,15 @@ func (g GrpcClient) GenerateCompose(ctx context.Context, req *defangv1.GenerateC
func (g GrpcClient) GetDefaultStack(ctx context.Context, req *defangv1.GetDefaultStackRequest) (*defangv1.GetStackResponse, error) {
return getMsg(g.client.GetDefaultStack(ctx, connect.NewRequest(req)))
}

func (g GrpcClient) ResolveIPAddr(ctx context.Context, req *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) {
return getMsg(g.client.ResolveIPAddr(ctx, connect.NewRequest(req)))
}

func (g GrpcClient) ResolveCNAME(ctx context.Context, req *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) {
return getMsg(g.client.ResolveCNAME(ctx, connect.NewRequest(req)))
}

func (g GrpcClient) ResolveNS(ctx context.Context, req *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) {
return getMsg(g.client.ResolveNS(ctx, connect.NewRequest(req)))
}
12 changes: 12 additions & 0 deletions src/pkg/cli/client/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ func (m MockFabricClient) GetDefaultStack(context.Context, *defangv1.GetDefaultS
}, nil
}

func (m MockFabricClient) ResolveIPAddr(_ context.Context, _ *defangv1.ResolveIPAddrRequest) (*defangv1.ResolveIPAddrResponse, error) {
return &defangv1.ResolveIPAddrResponse{}, nil
}

func (m MockFabricClient) ResolveCNAME(_ context.Context, _ *defangv1.ResolveCNAMERequest) (*defangv1.ResolveCNAMEResponse, error) {
return &defangv1.ResolveCNAMEResponse{}, nil
}

func (m MockFabricClient) ResolveNS(_ context.Context, _ *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error) {
return &defangv1.ResolveNSResponse{}, nil
}

type MockLoader struct {
Project composeTypes.Project
Error error
Expand Down
2 changes: 1 addition & 1 deletion src/pkg/cli/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewProvider(ctx context.Context, providerID client.ProviderID, fabricClient
term.Debugf("Creating %s provider", providerID)
switch providerID {
case client.ProviderAWS:
provider = aws.NewByocProvider(ctx, fabricClient.GetTenantName(), stack)
provider = aws.NewByocProvider(ctx, fabricClient.GetTenantName(), stack, fabricClient)
case client.ProviderDO:
provider = do.NewByocProvider(ctx, fabricClient.GetTenantName(), stack)
case client.ProviderGCP:
Expand Down
Loading
Loading