diff --git a/flake.nix b/flake.nix index 61faaf9d3..bced565f9 100644 --- a/flake.nix +++ b/flake.nix @@ -17,18 +17,20 @@ with pkgs; mkShell { buildInputs = [ + azure-cli actionlint bashInteractive # full bash with readline/completion so prompts render correctly crane gh git gnumake - less gnused # force Linux `sed` everywhere go_1_24 # must match GO_VERSION in Dockerfile - gopls golangci-lint + google-cloud-sdk + gopls goreleaser + less nixfmt-rfc-style nodejs_24 # for Pulumi, must match values in package.json and npm-build/action.yml openssh @@ -37,8 +39,8 @@ protoc-gen-go protolint pulumi + pulumiPackages.pulumi-go pulumiPackages.pulumi-nodejs - google-cloud-sdk vim ]; shellHook = '' diff --git a/pkgs/defang/cli.nix b/pkgs/defang/cli.nix index d9586a5ca..caabc4249 100644 --- a/pkgs/defang/cli.nix +++ b/pkgs/defang/cli.nix @@ -7,7 +7,7 @@ buildGo124Module { pname = "defang-cli"; version = "git"; src = lib.cleanSource ../../src; - vendorHash = "sha256-kuJPVIvKcffccWk6aU7slKXgGzVJgdu/NMokiW5Lpc8="; + vendorHash = "sha256-RDLJgsMv0iRbIiNWENOoV4JDcgjzD+4Hbi0vJiUxTzU="; subPackages = [ "cmd/cli" ]; diff --git a/src/cmd/cli/command/estimate.go b/src/cmd/cli/command/estimate.go index 6f109b6b6..2f50ec44b 100644 --- a/src/cmd/cli/command/estimate.go +++ b/src/cmd/cli/command/estimate.go @@ -85,11 +85,17 @@ func interactiveSelectProvider(providers []client.ProviderID) (client.ProviderID } // Default to the provider in the environment if available var defaultOption any // not string! 
- if pkg.AwsInEnv() != "" { + switch { + case pkg.AwsInEnv() != "": defaultOption = client.ProviderAWS.String() - } else if pkg.GcpInEnv() != "" { + case pkg.AzureInEnv() != "": + defaultOption = client.ProviderAzure.String() + case pkg.DoInEnv() != "": + defaultOption = client.ProviderDO.String() + case pkg.GcpInEnv() != "": defaultOption = client.ProviderGCP.String() } + var optionValue string if err := survey.AskOne(&survey.Select{ Default: defaultOption, diff --git a/src/go.mod b/src/go.mod index 1afad74fe..72103a14e 100644 --- a/src/go.mod +++ b/src/go.mod @@ -20,6 +20,20 @@ require ( cloud.google.com/go/storage v1.50.0 connectrpc.com/connect v1.19.1 github.com/AlecAivazis/survey/v2 v2.3.7 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers v1.1.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3 v3.1.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2 v2.1.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v2 v2.0.0 + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 github.com/DefangLabs/secret-detector v0.0.0-20250811234530-d4b4214cd679 github.com/andreyvit/diff 
v0.0.0-20170406064948-c7f18ee00883 github.com/aws/aws-sdk-go-v2 v1.41.5 @@ -41,7 +55,7 @@ require ( github.com/digitalocean/godo v1.131.1 github.com/docker/cli v29.2.0+incompatible github.com/firebase/genkit/go v1.2.0 - github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/googleapis/gax-go/v2 v2.14.2 github.com/gorilla/websocket v1.5.3 @@ -82,6 +96,8 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/longrunning v0.6.7 // indirect cloud.google.com/go/monitoring v1.24.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.52.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.52.0 // indirect @@ -111,6 +127,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect diff --git a/src/go.sum b/src/go.sum index f90411670..e0b99f544 100644 --- a/src/go.sum +++ b/src/go.sum @@ -34,6 +34,52 @@ connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 
h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers v1.1.0 h1:fdAOz6TFldGDoEcRa975i5L5QvWU8ptut+SJAIfuWUY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers v1.1.0/go.mod h1:qV+BWew22CAalRTwJEAHs+aSLP49k/csNlspqhMIDRU= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3 v3.1.0 h1:ilMZ576u8sm975EqV+AKEtD4u9TLwqEo2XY9csPXBRo= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3 v3.1.0/go.mod h1:LGhzy+pg9AKr1Z7ZRyTC1qr1xNyVqLsqydvLdY+2iQk= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 h1:Hp+EScFOu9HeCbeW8WU2yQPJd4gGwhMgKxWe+G6jNzw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0/go.mod h1:/pz8dyNQe+Ey3yBp/XuYz7oqX8YDNWVpPB0hH3XWfbc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 h1:DWlwvVV5r/Wy1561nZ3wrpI1/vDIBRY/Wd1HWaRBZWA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0/go.mod 
h1:E7ltexgRDmeJ0fJWv0D/HLwY2xbDdN+uv+X2uZtOx3w= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 h1:nnQ9vXH039UrEFxi08pPuZBE7VfqSJt343uJLw0rhWI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0/go.mod h1:4YIVtzMFVsPwBvitCDX7J9sqthSj43QD1sP6fYc1egc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.2.0 h1:akP6VpxJGgQRpDR1P462piz/8OhYLRCreDj48AyNabc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.2.0/go.mod h1:8wzvopPfyZYPaQUoKW87Zfdul7jmJMDfp/k7YY3oJyA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0 h1:4FlNvfcPu7tTvOgOzXxIbZLvwvmZq1OdhQUdIa9g2N4= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0/go.mod h1:A4nzEXwVd5pAyneR6KOvUAo72svUc5rmCzRHhAbP6lA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2 v2.1.0 h1:seyVIpxalxYmfjoo8MB4rRzWaobMG+KJ2+MAUrEvDGU= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2 v2.1.0/go.mod h1:M3QD7IyKZBaC4uAKjitTOSOXdcPC6JS1A9oOW3hYjbQ= 
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 h1:wxQx2Bt4xzPIKvW59WQf1tJNx/ZZKPfN+EhPX3Z6CYY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0/go.mod h1:TpiwjwnW/khS0LKs4vW5UmmT9OWcxaveS8U7+tlknzo= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v2 v2.0.0 h1:+vh02EiRx2UmL9NDoA36U18Bgwl9luxs6ia0GAI9Rzg= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v2 v2.0.0/go.mod h1:iKOtU3WyuNvNc4L1Z4IxHaoO0dGq5tg+uhLix/KRmzE= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0 h1:/g8S6wk65vfC6m3FIxJ+i5QDyN9JWwXI8Hb0Img10hU= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0/go.mod h1:gpl+q95AzZlKVI3xSoseF9QPrypk0hQqBiJYeB/cR/I= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod 
h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/DefangLabs/cobra v1.8.0-defang h1:rTzAg1XbEk3yXUmQPumcwkLgi8iNCby5CjyG3sCwzKk= github.com/DefangLabs/cobra v1.8.0-defang/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/DefangLabs/secret-detector v0.0.0-20250811234530-d4b4214cd679 h1:qNT7R4qrN+5u5ajSbqSW1opHP4LA8lzA+ASyw5MQZjs= @@ -176,8 +222,8 @@ github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 h1:okN800+zMJOGHLJCgry+OGzhhtH6YrjQh1rluHmOacE= @@ -228,6 +274,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/kisielk/errcheck 
v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -237,6 +285,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= diff --git a/src/pkg/agent/tools/estimate_test.go b/src/pkg/agent/tools/estimate_test.go index f1d4e4b6a..65807058e 100644 --- a/src/pkg/agent/tools/estimate_test.go +++ b/src/pkg/agent/tools/estimate_test.go @@ -132,7 +132,7 @@ func TestHandleEstimateTool(t *testing.T) { setupMock: func(m *MockEstimateCLI) { m.Project = &compose.Project{Name: "test-project"} }, - expectedError: "invalid provider: \"invalid-provider\", not one of [auto defang aws digitalocean gcp]", + expectedError: "invalid provider: \"invalid-provider\", not one of [auto defang aws digitalocean gcp azure]", }, { name: "run_estimate_error", diff --git a/src/pkg/cli/client/byoc/aws/byoc.go b/src/pkg/cli/client/byoc/aws/byoc.go index babf9950b..31e796bec 100644 --- a/src/pkg/cli/client/byoc/aws/byoc.go +++ b/src/pkg/cli/client/byoc/aws/byoc.go @@ -476,6 +476,8 @@ func (b *ByocAws) bucketName() string { func (b *ByocAws) environment(projectName string) (map[string]string, error) { region := 
b.driver.Region // TODO: this should be the destination region, not the CD region; make customizable + + // From https://www.pulumi.com/docs/iac/concepts/state-and-backends/#aws-s3 defangStateUrl := fmt.Sprintf(`s3://%s?region=%s&awssdk=v2`, b.bucketName(), region) pulumiBackendKey, pulumiBackendValue, err := byoc.GetPulumiBackend(defangStateUrl) if err != nil { diff --git a/src/pkg/cli/client/byoc/aws/byoc_integration_test.go b/src/pkg/cli/client/byoc/aws/byoc_integration_test.go index c70abd1e0..b45f52580 100644 --- a/src/pkg/cli/client/byoc/aws/byoc_integration_test.go +++ b/src/pkg/cli/client/byoc/aws/byoc_integration_test.go @@ -7,13 +7,13 @@ import ( "strings" "testing" + "connectrpc.com/connect" "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" "github.com/DefangLabs/defang/src/pkg/cli/compose" "github.com/DefangLabs/defang/src/pkg/clouds/aws" "github.com/DefangLabs/defang/src/pkg/clouds/aws/codebuild/cfn" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" - "connectrpc.com/connect" ) func TestDeploy(t *testing.T) { diff --git a/src/pkg/cli/client/byoc/azure/byoc.go b/src/pkg/cli/client/byoc/azure/byoc.go new file mode 100644 index 000000000..3e979b5a6 --- /dev/null +++ b/src/pkg/cli/client/byoc/azure/byoc.go @@ -0,0 +1,802 @@ +package azure + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "iter" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "connectrpc.com/connect" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/DefangLabs/defang/src/pkg" + "github.com/DefangLabs/defang/src/pkg/cli/client" + "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" + "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/state" + "github.com/DefangLabs/defang/src/pkg/cli/compose" + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" + "github.com/DefangLabs/defang/src/pkg/clouds/azure/aca" + 
"github.com/DefangLabs/defang/src/pkg/clouds/azure/acr" + "github.com/DefangLabs/defang/src/pkg/clouds/azure/cd" + "github.com/DefangLabs/defang/src/pkg/clouds/azure/keyvault" + defanghttp "github.com/DefangLabs/defang/src/pkg/http" + "github.com/DefangLabs/defang/src/pkg/term" + "github.com/DefangLabs/defang/src/pkg/tokenstore" + "github.com/DefangLabs/defang/src/pkg/types" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type ByocAzure struct { + *byoc.ByocBaseClient + + driver *cd.Driver + job *aca.Job + kv *keyvault.KeyVault + cdRunID string + cdEtag string + setUpDone bool // true once full setUp has completed; prevents redundant API calls +} + +var _ client.Provider = (*ByocAzure)(nil) + +func NewByocProvider(ctx context.Context, tenantLabel types.TenantLabel, stack string) *ByocAzure { + b := &ByocAzure{ + driver: cd.New("defang-cd", ""), // default location => from AZURE_LOCATION env var + job: &aca.Job{}, + } + b.ByocBaseClient = byoc.NewByocBaseClient(tenantLabel, b, stack) + b.driver.TokenStore = &tokenstore.LocalDirTokenStore{Dir: filepath.Join(client.StateDir, "providers", "azure")} + return b +} + +func (b *ByocAzure) Driver() string { + return "azure" +} + +// SetUpCD implements client.Provider. +func (b *ByocAzure) SetUpCD(context.Context, bool) error { + term.Debugf("SetUpCD: no-op for Azure; CD environment will be set up on demand during Deploy") + return nil +} + +// CdCommand implements byoc.ProjectBackend. 
+func (b *ByocAzure) CdCommand(ctx context.Context, req client.CdCommandRequest) (*client.CdCommandResponse, error) { + if err := b.setUpForConfig(ctx, req.Project); err != nil { + return nil, err + } + if err := b.setUp(ctx); err != nil { + return nil, err + } + envMap, err := b.buildCdEnv(req.Project) + if err != nil { + return nil, err + } + if err := b.setUpJob(ctx, envMap); err != nil { + return nil, err + } + etag := pkg.RandomID() + execName, err := b.job.StartJobExecution(ctx, aca.JobRequest{ + Image: b.CDImage, + Command: []string{"/app/cd", string(req.Command)}, + Envs: envMap, + Timeout: 30 * time.Minute, + }) + if err != nil { + return nil, err + } + b.cdRunID = execName + b.cdEtag = etag + return &client.CdCommandResponse{ + CdId: execName, + CdType: defangv1.CdType_CD_TYPE_AZURE_ACI_JOBID, + ETag: etag, + }, nil +} + +// CdList implements byoc.ProjectBackend. +func (b *ByocAzure) CdList(ctx context.Context, _ bool) (iter.Seq[state.Info], error) { + if err := b.setUp(ctx); err != nil { + return nil, err + } + + blobs, err := b.driver.IterateBlobsInContainer(ctx, cd.PulumiContainerName, ".pulumi/stacks/") + if err != nil { + return nil, err + } + + return func(yield func(state.Info) bool) { + for item, err := range blobs { + if err != nil { + term.Debugf("Error iterating blobs: %v", err) + return + } + st, err := state.ParsePulumiStateFile(ctx, item, cd.PulumiContainerName, func(ctx context.Context, container, blobName string) ([]byte, error) { + return b.driver.DownloadBlobFromContainer(ctx, container, blobName) + }) + if err != nil { + term.Debugf("Skipping %q: %v", item.Name(), err) + continue + } + if st == nil { + continue + } + if !yield(state.Info{ + Project: st.Project, + Stack: st.Name, + Workspace: string(st.Workspace), + CdRegion: b.driver.Location.String(), + }) { + return + } + } + }, nil +} + +// AccountInfo implements client.Provider. 
+func (b *ByocAzure) AccountInfo(context.Context) (*client.AccountInfo, error) { + return &client.AccountInfo{ + AccountID: b.driver.SubscriptionID, + Provider: client.ProviderAzure, + Region: b.driver.Location.String(), + }, nil +} + +// CreateUploadURL implements client.Provider. +func (b *ByocAzure) CreateUploadURL(ctx context.Context, req *defangv1.UploadURLRequest) (*defangv1.UploadURLResponse, error) { + if err := b.setUp(ctx); err != nil { + return nil, err + } + + url, err := b.driver.CreateUploadURL(ctx, req.Digest) + if err != nil { + return nil, err + } + + return &defangv1.UploadURLResponse{ + Url: url, + }, nil +} + +// Delete implements client.Provider. +func (b *ByocAzure) Delete(context.Context, *defangv1.DeleteRequest) (*defangv1.DeleteResponse, error) { + return nil, fmt.Errorf("Delete: %w", errors.ErrUnsupported) +} + +// DeleteConfig implements client.Provider. Read-only for discovery: if the +// Key Vault doesn't exist yet, there's nothing to delete — return success +// instead of provisioning it just to tear down. +func (b *ByocAzure) DeleteConfig(ctx context.Context, secrets *defangv1.Secrets) error { + found, err := b.findForConfig(ctx, secrets.Project) + if err != nil { + return err + } + if !found { + return nil // nothing configured yet, nothing to delete + } + for _, name := range secrets.Names { + key := b.StackDir(secrets.Project, name) + secretName := keyvault.ToSecretName(key) + term.Debugf("Deleting Key Vault secret %q", secretName) + if err := b.kv.DeleteSecret(ctx, secretName); err != nil { + return fmt.Errorf("failed to delete Key Vault secret %q: %w", name, err) + } + } + return nil +} + +// setUpLocation lazily resolves AZURE_LOCATION and AZURE_SUBSCRIPTION_ID from the environment +// and syncs the values to the job. It makes no API calls. 
+func (b *ByocAzure) setUpLocation() error { + if b.driver.Location == "" { + loc := cloudazure.Location(os.Getenv("AZURE_LOCATION")) + if loc == "" { + return errors.New("AZURE_LOCATION is not set; please ensure your stack includes the Azure region") + } + b.driver.SetLocation(loc) + } + if b.driver.SubscriptionID == "" { + b.driver.SubscriptionID = os.Getenv("AZURE_SUBSCRIPTION_ID") + } + b.job.Azure = b.driver.Azure + b.job.ResourceGroup = b.driver.ResourceGroupName() + return nil +} + +// projectResourceGroupName returns the resource group name for project-specific resources +// (App Configuration store and deployed services). +// Format: defang-{project}-{stack}-{location}, e.g. "defang-myapp-test-westus2". +// This group is owned by one project+stack and is separate from the shared CD resource group. +func (b *ByocAzure) projectResourceGroupName(projectName string) string { + return "defang-" + projectName + "-" + b.PulumiStack +} + +// findForConfig binds to a pre-existing project Key Vault without creating +// anything new. Returns (true, nil) when the vault exists, (false, nil) when +// it or its resource group doesn't — which callers like ListConfig / +// DeleteConfig treat as "nothing configured yet". +// +// On a successful Find we also self-grant "Key Vault Secrets Officer" to the +// current caller (idempotent — RoleAssignmentExists is a no-op). This +// onboards new teammates onto a shared stack whose vault was created by +// someone else; without it, the read paths would 403 forever for them. 
+func (b *ByocAzure) findForConfig(ctx context.Context, projectName string) (bool, error) { + if err := b.setUpLocation(); err != nil { + return false, err + } + if b.kv != nil { + return true, nil + } + rgName := b.projectResourceGroupName(projectName) + kv := keyvault.New(rgName, b.driver.Azure) + found, err := kv.Find(ctx) + if err != nil { + return false, err + } + if !found { + return false, nil + } + if err := kv.EnsureSecretsOfficer(ctx); err != nil { + return false, err + } + b.kv = kv + return true, nil +} + +// setUpForConfig creates the project-specific Key Vault (and the resource +// group that holds it) on first use. Idempotent. b.kv is only cached on +// successful SetUp, so a failed attempt doesn't mask the root cause for +// subsequent config operations within the same process. +func (b *ByocAzure) setUpForConfig(ctx context.Context, projectName string) error { + if err := b.setUpLocation(); err != nil { + return err + } + rgName := b.projectResourceGroupName(projectName) + if err := b.driver.CreateResourceGroup(ctx, rgName); err != nil { + return err + } + if b.kv == nil { + kv := keyvault.New(rgName, b.driver.Azure) + if err := kv.SetUp(ctx); err != nil { + return err + } + b.kv = kv + } + return nil +} + +// setUp sets up the shared CD infrastructure: resource group, blob storage, the Container +// Apps environment, and the job's managed identity. It does NOT create the CD job itself +// (SetUpJob must be called separately with env vars baked in) and does NOT set up +// project-specific resources (use setUpForConfig for App Configuration). +func (b *ByocAzure) setUp(ctx context.Context) error { + if err := b.setUpLocation(); err != nil { + return err + } + + if b.setUpDone { + return nil + } + + // Create the shared CD resource group (defang-cd-{location}). 
+ if err := b.driver.SetUpResourceGroup(ctx); err != nil { + return err + } + + if _, err := b.driver.SetUpStorageAccount(ctx); err != nil { + return fmt.Errorf("failed to set up storage account: %w", err) + } + + if err := b.job.SetUpEnvironment(ctx); err != nil { + return fmt.Errorf("failed to set up container apps environment: %w", err) + } + + b.setUpDone = true + return nil +} + +// setUpJob creates/updates the CD job with the given env vars baked into its template, +// and grants the job's managed identity read access to the CD storage account. The job +// must already have SetUpEnvironment called on it (via setUp). The CD image is pulled +// anonymously — its registry must allow anonymous pull. +func (b *ByocAzure) setUpJob(ctx context.Context, envMap map[string]string) error { + if b.CDImage == "" { + return errors.New("CD image is not set; please set the DEFANG_CD_IMAGE environment variable") + } + if err := b.job.SetUpJob(ctx, b.CDImage, envMap); err != nil { + return fmt.Errorf("failed to set up CD job: %w", err) + } + if err := b.job.SetUpManagedIdentity(ctx, b.driver.StorageAccount); err != nil { + return fmt.Errorf("failed to set up managed identity: %w", err) + } + return nil +} + +// buildCdEnv returns the environment map that every CD container run needs. +func (b *ByocAzure) buildCdEnv(projectName string) (map[string]string, error) { + // Pulumi state lives in its own container (`pulumi`), separate from the + // `uploads` container (etag payloads, tarballs) and the `projects` + // container (project.pb audit blobs written by the CD task). 
+ defangStateUrl := fmt.Sprintf(`azblob://%s?storage_account=%s`, cd.PulumiContainerName, b.driver.StorageAccount) + pulumiBackendKey, pulumiBackendValue, err := byoc.GetPulumiBackend(defangStateUrl) + if err != nil { + return nil, err + } + // AZURE_RESOURCE_GROUP and AZURE_KEY_VAULT_NAME are intentionally omitted: + // the Pulumi Azure provider now derives both deterministically from + // {project, stack, location, subscription} using the same formulas the + // CLI uses (projectResourceGroupName, keyvault.VaultName), so passing + // them as env vars would just be another spot for the two sides to + // drift out of sync. + env := map[string]string{ + "AZURE_LOCATION": b.driver.Location.String(), + "AZURE_SUBSCRIPTION_ID": b.driver.SubscriptionID, + "DEFANG_DEBUG": os.Getenv("DEFANG_DEBUG"), + "DEFANG_JSON": os.Getenv("DEFANG_JSON"), + "DEFANG_ORG": string(b.TenantLabel), + "DEFANG_PREFIX": b.Prefix, + "DEFANG_PULUMI_DEBUG": os.Getenv("DEFANG_PULUMI_DEBUG"), + "DEFANG_PULUMI_DIFF": os.Getenv("DEFANG_PULUMI_DIFF"), + "DEFANG_STATE_URL": defangStateUrl, + "HOME": "/root", // TODO: should be in Dockerfile + "NPM_CONFIG_UPDATE_NOTIFIER": "false", + "PROJECT": projectName, + pulumiBackendKey: pulumiBackendValue, // TODO: make secret + "PULUMI_AUTOMATION_API_SKIP_VERSION_CHECK": "true", + "PULUMI_CONFIG_PASSPHRASE": byoc.PulumiConfigPassphrase, // TODO: make secret + "PULUMI_COPILOT": "false", + // "PULUMI_DIY_BACKEND_DISABLE_CHECKPOINT_BACKUPS": "true", TODO: use versioned bucket + "PULUMI_SKIP_UPDATE_CHECK": "true", + "STACK": b.PulumiStack, + "USER": "root", // TODO: should be in Dockerfile + } + if targets := os.Getenv("DEFANG_PULUMI_TARGETS"); targets != "" { + env["DEFANG_PULUMI_TARGETS"] = targets + } + if !term.StdoutCanColor() { + env["NO_COLOR"] = "1" + } + return env, nil +} + +// Deploy implements client.Provider. 
+func (b *ByocAzure) Deploy(ctx context.Context, req *client.DeployRequest) (*client.DeployResponse, error) { + return b.deploy(ctx, req, "up") +} + +func (b *ByocAzure) deploy(ctx context.Context, req *client.DeployRequest, verb string) (*client.DeployResponse, error) { + if b.CDImage == "" { + return nil, errors.New("CD image is not set; please set the DEFANG_CD_IMAGE environment variable") + } + + // If multiple Compose files were provided, req.Compose is the merged representation of all the files + project, err := compose.LoadFromContent(ctx, req.Compose, "") + if err != nil { + return nil, err + } + + if err := b.setUpForConfig(ctx, project.Name); err != nil { + return nil, err + } + if err := b.setUp(ctx); err != nil { + return nil, err + } + + etag := pkg.RandomID() + serviceInfos, err := b.GetServiceInfos(ctx, project.Name, req.DelegateDomain, etag, project.Services) + if err != nil { + return nil, err + } + + data, err := proto.Marshal(&defangv1.ProjectUpdate{ + CdVersion: b.CDImage, + Compose: req.Compose, + Services: serviceInfos, + }) + if err != nil { + return nil, err + } + + envMap, err := b.buildCdEnv(project.Name) + if err != nil { + return nil, err + } + if err := b.setUpJob(ctx, envMap); err != nil { + return nil, err + } + + var payload string + if len(data) < 1000 { + payload = base64.StdEncoding.EncodeToString(data) + } else { + uploadURL, err := b.driver.CreateUploadURL(ctx, etag) + if err != nil { + return nil, err + } + resp, err := defanghttp.PutWithHeader(ctx, uploadURL, http.Header{ + "Content-Type": []string{"application/protobuf"}, + "x-ms-blob-type": []string{"BlockBlob"}, + }, bytes.NewReader(data)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 && resp.StatusCode != 201 { + return nil, fmt.Errorf("unexpected status code during upload: %s", resp.Status) + } + payload = defanghttp.RemoveQueryParam(uploadURL) // managed identity provides blob read access + } + + execName, err := 
b.job.StartJobExecution(ctx, aca.JobRequest{ + Image: b.CDImage, + Command: []string{"/app/cd", verb, payload}, + Envs: envMap, + Timeout: 30 * time.Minute, + }) + if err != nil { + return nil, err + } + b.cdRunID = execName + b.cdEtag = etag + return &client.DeployResponse{ + CdId: execName, + CdType: defangv1.CdType_CD_TYPE_AZURE_ACI_JOBID, + DeployResponse: &defangv1.DeployResponse{ + Etag: etag, Services: serviceInfos, + }, + }, nil +} + +// GetDeploymentStatus implements client.Provider. CD container output is streamed +// live via QueryLogs (follow=true) rather than drained from Log Analytics here, +// so this method can return as soon as the job reaches a terminal state. +func (b *ByocAzure) GetDeploymentStatus(ctx context.Context) (bool, error) { + if b.cdRunID == "" { + return false, nil + } + status, err := b.job.GetJobExecutionStatus(ctx, b.cdRunID) + if err != nil { + // Return the raw error so WaitForCdTaskExit's isTransientError can + // retry on flaky failures (e.g. AzureCLICredential subprocess timeouts). + // Wrapping as ErrDeploymentFailed here would mask transient errors as + // permanent deployment failures. + return false, err + } + if !status.IsTerminal() { + return false, nil + } + if !status.IsSuccess() { + msg := string(status.Status) + if status.ErrorMessage != "" { + msg += ": " + status.ErrorMessage + } + return true, client.ErrDeploymentFailed{Message: fmt.Sprintf("CD job %s: %s", b.cdRunID, msg)} + } + return true, nil +} + +// GetPrivateDomain implements byoc.ProjectBackend. +func (b *ByocAzure) GetPrivateDomain(projectName string) string { + return b.GetProjectLabel(projectName) + ".internal" +} + +// GetProjectUpdate implements byoc.ProjectBackend. It is read-only — it does +// not create the CD resource group, storage account, container apps +// environment, or any other provisioning side effect. 
On a subscription +// where defang has never been deployed the storage account lookup returns +// nothing and we report client.ErrNotExist immediately. +// +// The blob lives in the dedicated `projects` container (populated by the CD +// task before each deploy) at key `{project}/{stack}/project.pb`. +func (b *ByocAzure) GetProjectUpdate(ctx context.Context, projectName string) (*defangv1.ProjectUpdate, error) { + if projectName == "" { + return nil, client.ErrNotExist + } + if err := b.setUpLocation(); err != nil { + return nil, err + } + storageAccount, err := b.driver.FindStorageAccount(ctx) + if err != nil { + return nil, err + } + if storageAccount == "" { + // CD storage account hasn't been provisioned yet. + return nil, client.ErrNotExist + } + + // GetProjectUpdatePath returns "projects/{project}/{stack}/project.pb". + // The `projects` container already provides the top-level namespace, so + // strip the leading "projects/" when addressing the blob. + key := strings.TrimPrefix(b.GetProjectUpdatePath(projectName), "projects/") + term.Debug("Getting project update from blob:", cd.ProjectsContainerName, key) + pbBytes, err := b.driver.DownloadBlobFromContainer(ctx, cd.ProjectsContainerName, key) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && (respErr.StatusCode == 404 || respErr.ErrorCode == "ContainerNotFound" || respErr.ErrorCode == "BlobNotFound") { + return nil, client.ErrNotExist // no services yet + } + return nil, err + } + + var projUpdate defangv1.ProjectUpdate + if err := proto.Unmarshal(pbBytes, &projUpdate); err != nil { + return nil, err + } + return &projUpdate, nil +} + +// GetService implements client.Provider by fetching GetServices and filtering +// to the requested name — same pattern as the AWS and GCP providers. 
+func (b *ByocAzure) GetService(ctx context.Context, req *defangv1.GetRequest) (*defangv1.ServiceInfo, error) { + all, err := b.GetServices(ctx, &defangv1.GetServicesRequest{Project: req.Project}) + if err != nil { + return nil, err + } + for _, service := range all.Services { + if service.Service.Name == req.Name { + return service, nil + } + } + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("service %q not found", req.Name)) +} + +// GetServices implements client.Provider by reading the ProjectUpdate blob +// that the CD task uploads during Deploy — same pattern as the AWS and GCP +// providers. +func (b *ByocAzure) GetServices(ctx context.Context, req *defangv1.GetServicesRequest) (*defangv1.GetServicesResponse, error) { + projUpdate, err := b.GetProjectUpdate(ctx, req.Project) + if err != nil { + if errors.Is(err, client.ErrNotExist) { + return &defangv1.GetServicesResponse{}, nil + } + return nil, err + } + return &defangv1.GetServicesResponse{ + Services: projUpdate.Services, + Project: projUpdate.Project, + }, nil +} + +// ListConfig implements client.Provider. Read-only: when the project's +// App Configuration store or Key Vault hasn't been provisioned yet, returns +// an empty list instead of creating them. 
+func (b *ByocAzure) ListConfig(ctx context.Context, req *defangv1.ListConfigsRequest) (*defangv1.Secrets, error) { + found, err := b.findForConfig(ctx, req.Project) + if err != nil { + return nil, err + } + if !found { + return &defangv1.Secrets{}, nil // nothing configured yet + } + prefix := b.StackDir(req.Project, "") + secretPrefix := keyvault.ToSecretName(prefix) + term.Debugf("Listing Key Vault secrets with prefix %q (sanitized: %q)", prefix, secretPrefix) + entries, err := b.kv.ListSecrets(ctx, secretPrefix) + if err != nil { + return nil, err + } + names := make([]string, 0, len(entries)) + for _, e := range entries { + if e.OriginalKey == "" || !strings.HasPrefix(e.OriginalKey, prefix) { + continue + } + names = append(names, strings.TrimPrefix(e.OriginalKey, prefix)) + } + return &defangv1.Secrets{Names: names}, nil +} + +// PrepareDomainDelegation implements client.Provider. +func (b *ByocAzure) PrepareDomainDelegation(context.Context, client.PrepareDomainDelegationRequest) (*client.PrepareDomainDelegationResponse, error) { + return nil, nil // TODO: implement domain delegation for Azure +} + +// Preview implements client.Provider. +func (b *ByocAzure) Preview(ctx context.Context, req *client.DeployRequest) (*client.DeployResponse, error) { + return b.deploy(ctx, req, "preview") +} + +// PutConfig implements client.Provider. +func (b *ByocAzure) PutConfig(ctx context.Context, req *defangv1.PutConfigRequest) error { + if err := b.setUpForConfig(ctx, req.Project); err != nil { + return err + } + key := b.StackDir(req.Project, req.Name) + secretName := keyvault.ToSecretName(key) + term.Debugf("Putting Key Vault secret %q (original key %q)", secretName, key) + if err := b.kv.PutSecret(ctx, secretName, req.Value, key); err != nil { + return fmt.Errorf("failed to put Key Vault secret: %w", err) + } + return nil +} + +// QueryLogs implements client.Provider. +// Only CD container logs are supported; service logs are not yet implemented. 
+func (b *ByocAzure) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) {
+	// Match the request etag to the stored CD etag so we tail the correct run.
+	if b.cdRunID == "" || (req.Etag != "" && req.Etag != b.cdEtag) {
+		return nil, fmt.Errorf("QueryLogs: no matching CD deployment for etag %q", req.Etag)
+	}
+
+	const cdServiceName = "defang-cd"
+	etag := b.cdEtag
+
+	if req.Follow {
+		logIter, err := b.job.TailJobLogs(ctx, b.cdRunID)
+		if err != nil {
+			return nil, err
+		}
+
+		// Run the CD job (ACA) log iterator in a goroutine so we can select over it and ACA logs.
+		type cdLogEntry struct {
+			line string
+			err  error
+		}
+		cdCh := make(chan cdLogEntry)
+		go func() {
+			defer close(cdCh)
+			for line, err := range logIter {
+				select {
+				case cdCh <- cdLogEntry{line: line, err: err}:
+				case <-ctx.Done():
+					return
+				}
+			}
+		}()
+
+		projectRG := b.projectResourceGroupName(req.Project)
+
+		// Watch Container App logs from the PROJECT resource group.
+		acaClient := &aca.ContainerApp{
+			Azure:         b.driver.Azure,
+			ResourceGroup: projectRG,
+		}
+		acaCh := acaClient.WatchLogs(ctx)
+
+		// Watch ACR build logs from the PROJECT resource group.
+ buildWatcher := &acr.BuildLogWatcher{ + Azure: b.driver.Azure, + ResourceGroup: projectRG, + } + buildCh := buildWatcher.WatchBuildLogs(ctx) + + return func(yield func(*defangv1.TailResponse, error) bool) { + for { + select { + case entry, ok := <-cdCh: + if !ok { + cdCh = nil + continue + } + if entry.err != nil { + if !yield(nil, entry.err) { + return + } + continue + } + if !yield(&defangv1.TailResponse{ + Entries: []*defangv1.LogEntry{{ + Message: entry.line, + Service: cdServiceName, + Etag: etag, + Timestamp: timestamppb.Now(), + }}, + Service: cdServiceName, + Etag: etag, + }, nil) { + return + } + case svc, ok := <-acaCh: + if !ok { + acaCh = nil + continue + } + if svc.Err != nil { + term.Debugf("Container Apps log error for %q: %v", svc.AppName, svc.Err) + continue + } + if !yield(&defangv1.TailResponse{ + Entries: []*defangv1.LogEntry{{ + Message: svc.Message, + Service: svc.AppName, + Etag: etag, + Timestamp: timestamppb.Now(), + }}, + Service: svc.AppName, + Etag: etag, + }, nil) { + return + } + case build, ok := <-buildCh: + if !ok { + buildCh = nil + continue + } + if build.Err != nil { + term.Debugf("ACR build log error for %q: %v", build.Service, build.Err) + continue + } + if !yield(&defangv1.TailResponse{ + Entries: []*defangv1.LogEntry{{ + Message: build.Line, + Service: build.Service + "-build", + Etag: etag, + Stderr: true, // show build logs even when not verbose + Timestamp: timestamppb.Now(), + }}, + Service: build.Service + "-build", + Etag: etag, + }, nil) { + return + } + case <-ctx.Done(): + return + } + if cdCh == nil && acaCh == nil && buildCh == nil { + return + } + } + }, nil + } + + // Non-follow: return a snapshot of current log content. 
+ content, err := b.job.ReadJobLogs(ctx, b.cdRunID) + if err != nil { + return nil, err + } + return func(yield func(*defangv1.TailResponse, error) bool) { + if content == "" { + return + } + yield(&defangv1.TailResponse{ + Entries: []*defangv1.LogEntry{{ + Message: content, + Service: cdServiceName, + Etag: etag, + Timestamp: timestamppb.Now(), + }}, + Service: cdServiceName, + Etag: etag, + }, nil) + }, nil +} + +// RemoteProjectName implements client.Provider. +// Subtle: this method shadows the method (*ByocBaseClient).RemoteProjectName of ByocAzure.ByocBaseClient. +func (b *ByocAzure) RemoteProjectName(context.Context) (string, error) { + return "", fmt.Errorf("RemoteProjectName: %w", errors.ErrUnsupported) +} + +// ServiceDNS implements client.Provider. +// Subtle: this method shadows the method (*ByocBaseClient).ServiceDNS of ByocAzure.ByocBaseClient. +func (b *ByocAzure) ServiceDNS(host string) string { + return host +} + +// Subscribe implements client.Provider. +func (b *ByocAzure) Subscribe(context.Context, *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { + return func(yield func(*defangv1.SubscribeResponse, error) bool) { + // TODO: Implement subscription to deployment events for Azure + }, nil +} + +// TearDown implements client.Provider. +func (b *ByocAzure) TearDown(ctx context.Context) error { + return b.driver.TearDown(ctx) +} + +// TearDownCD implements client.Provider. +func (b *ByocAzure) TearDownCD(context.Context) error { + return fmt.Errorf("TearDownCD: %w", errors.ErrUnsupported) +} + +// UpdateShardDomain implements client.DNSResolver. 
+func (b *ByocAzure) UpdateShardDomain(context.Context) error { + return fmt.Errorf("UpdateShardDomain: %w", errors.ErrUnsupported) +} diff --git a/src/pkg/cli/client/byoc/azure/byoc_test.go b/src/pkg/cli/client/byoc/azure/byoc_test.go new file mode 100644 index 000000000..08a23d8ca --- /dev/null +++ b/src/pkg/cli/client/byoc/azure/byoc_test.go @@ -0,0 +1,447 @@ +package azure + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/DefangLabs/defang/src/pkg/cli/client" + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" +) + +type fakeCred struct { + tok string + err error +} + +func (f fakeCred) GetToken(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) { + if f.err != nil { + return azcore.AccessToken{}, f.err + } + return azcore.AccessToken{Token: f.tok, ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func useFakeCred(t *testing.T, tok string, gerr error) { + t.Helper() + orig := cloudazure.NewCredsFunc + cloudazure.NewCredsFunc = func(_ cloudazure.Azure) (azcore.TokenCredential, error) { + return fakeCred{tok: tok, err: gerr}, nil + } + t.Cleanup(func() { cloudazure.NewCredsFunc = orig }) +} + +func newTestProvider(t *testing.T, location cloudazure.Location, subID string) *ByocAzure { + t.Helper() + t.Setenv("AZURE_LOCATION", string(location)) + t.Setenv("AZURE_SUBSCRIPTION_ID", subID) + t.Setenv("AZURE_TENANT_ID", "") + t.Setenv("AZURE_CLIENT_ID", "") + b := NewByocProvider(context.Background(), "test-tenant", "test-stack") + if b == nil { + t.Fatal("NewByocProvider returned nil") + } + return b +} + +func TestNewByocProvider(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationWestUS2, "sub-id") + if b.PulumiStack != "test-stack" { + t.Errorf("PulumiStack = %q, want test-stack", b.PulumiStack) + } + if b.TenantLabel != 
"test-tenant" { + t.Errorf("TenantLabel = %q, want test-tenant", b.TenantLabel) + } +} + +func TestDriver(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + if got := b.Driver(); got != "azure" { + t.Errorf("Driver() = %q, want azure", got) + } +} + +func TestServiceDNS(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + host := "my-service.example" + if got := b.ServiceDNS(host); got != host { + t.Errorf("ServiceDNS(%q) = %q, want pass-through", host, got) + } +} + +func TestGetPrivateDomain(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + got := b.GetPrivateDomain("myproject") + if got == "" || got[len(got)-len(".internal"):] != ".internal" { + t.Errorf("GetPrivateDomain = %q, want *.internal", got) + } +} + +func TestProjectResourceGroupName(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationWestUS2, "sub") + if err := b.setUpLocation(); err != nil { + t.Fatalf("setUpLocation: %v", err) + } + got := b.projectResourceGroupName("myapp") + want := "defang-myapp-test-stack" + if got != want { + t.Errorf("projectResourceGroupName = %q, want %q", got, want) + } +} + +func TestSetUpLocationMissing(t *testing.T) { + t.Setenv("AZURE_LOCATION", "") + t.Setenv("AZURE_SUBSCRIPTION_ID", "") + b := NewByocProvider(context.Background(), "t", "s") + if err := b.setUpLocation(); err == nil { + t.Error("expected error when AZURE_LOCATION is unset") + } +} + +func TestSetUpLocationFromEnv(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationWestUS3, "sub-id") + if err := b.setUpLocation(); err != nil { + t.Fatalf("setUpLocation: %v", err) + } + if b.driver.Location != cloudazure.LocationWestUS3 { + t.Errorf("driver.Location = %q", b.driver.Location) + } + if b.driver.SubscriptionID != "sub-id" { + t.Errorf("driver.SubscriptionID = %q", b.driver.SubscriptionID) + } + if b.job.ResourceGroup != "defang-cd-westus3" { + t.Errorf("job.ResourceGroup = %q", b.job.ResourceGroup) + } +} + 
+func TestAccountInfo(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub-1") + if err := b.setUpLocation(); err != nil { + t.Fatalf("setUpLocation: %v", err) + } + info, err := b.AccountInfo(context.Background()) + if err != nil { + t.Fatalf("AccountInfo: %v", err) + } + if info.AccountID != "sub-1" { + t.Errorf("AccountID = %q, want sub-1", info.AccountID) + } + if info.Region != "eastus" { + t.Errorf("Region = %q, want eastus", info.Region) + } + if info.Provider != client.ProviderAzure { + t.Errorf("Provider = %v, want Azure", info.Provider) + } +} + +func TestSetUpCDNoOp(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + if err := b.SetUpCD(context.Background(), false); err != nil { + t.Errorf("SetUpCD should be no-op, got %v", err) + } +} + +func TestUnsupportedOps(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + + if _, err := b.Delete(context.Background(), nil); err == nil { + t.Error("Delete should return unsupported") + } + if _, err := b.RemoteProjectName(context.Background()); err == nil { + t.Error("RemoteProjectName should return unsupported") + } + if err := b.TearDownCD(context.Background()); err == nil { + t.Error("TearDownCD should return unsupported") + } + if err := b.UpdateShardDomain(context.Background()); err == nil { + t.Error("UpdateShardDomain should return unsupported") + } +} + +func TestGetServicesEmptyProjectReturnsEmpty(t *testing.T) { + // Empty project name short-circuits GetProjectUpdate with ErrNotExist, + // and GetServices translates that into an empty response — same contract + // as the AWS/GCP providers. 
+ b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + resp, err := b.GetServices(context.Background(), &defangv1.GetServicesRequest{Project: ""}) + if err != nil { + t.Fatalf("GetServices(empty project): %v", err) + } + if len(resp.Services) != 0 { + t.Errorf("expected empty services, got %d", len(resp.Services)) + } +} + +func TestGetServiceEmptyProjectNotFound(t *testing.T) { + // With no deployments, GetService should surface a NotFound for any name. + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + _, err := b.GetService(context.Background(), &defangv1.GetRequest{Project: "", Name: "app"}) + if err == nil { + t.Error("GetService should fail when the named service doesn't exist") + } +} + +func TestPrepareDomainDelegationNil(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + resp, err := b.PrepareDomainDelegation(context.Background(), client.PrepareDomainDelegationRequest{}) + if err != nil { + t.Errorf("PrepareDomainDelegation err: %v", err) + } + if resp != nil { + t.Errorf("PrepareDomainDelegation response = %v, want nil (TODO)", resp) + } +} + +func TestSubscribe(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + seq, err := b.Subscribe(context.Background(), &defangv1.SubscribeRequest{}) + if err != nil { + t.Fatalf("Subscribe err: %v", err) + } + if seq == nil { + t.Fatal("Subscribe returned nil seq") + } + // The TODO stub simply yields nothing — iterating should finish immediately. 
+ for range seq { + t.Error("Subscribe iterator yielded unexpectedly") + } +} + +func TestGetDeploymentStatusNoRun(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + done, err := b.GetDeploymentStatus(context.Background()) + if err != nil { + t.Errorf("GetDeploymentStatus err: %v", err) + } + if done { + t.Error("GetDeploymentStatus should be not-done when cdRunID is empty") + } +} + +func TestGetDeploymentStatusCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + b.cdRunID = "run-1" + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err := b.GetDeploymentStatus(ctx) + if err == nil { + t.Error("GetDeploymentStatus should surface SDK error") + } +} + +func TestGetProjectUpdateCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := b.GetProjectUpdate(ctx, "proj"); err == nil { + t.Error("GetProjectUpdate should surface credential error") + } +} + +func TestQueryLogsUnknownEtag(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + // cdRunID is empty — QueryLogs should reject the request rather than panic. 
+	_, err := b.QueryLogs(context.Background(), &defangv1.TailRequest{Etag: "some-etag"})
+	if err == nil {
+		t.Error("QueryLogs should reject when cdRunID is empty")
+	}
+	var _ = errors.New // silence unused when build tag trims
+}
+
+func TestQueryLogsEtagMismatch(t *testing.T) {
+	b := newTestProvider(t, cloudazure.LocationEastUS, "sub")
+	b.cdRunID = "run-1"
+	b.cdEtag = "etag-A"
+	_, err := b.QueryLogs(context.Background(), &defangv1.TailRequest{Etag: "etag-B"})
+	if err == nil {
+		t.Error("QueryLogs should reject etag mismatch")
+	}
+}
+
+func TestAuthenticateNonInteractiveFailsWithoutCreds(t *testing.T) {
+	// newTestProvider clears the AZURE_TENANT_ID/AZURE_CLIENT_ID env vars, so the
+	// default credential chain should have nothing usable and Authenticate fails —
+	// NOTE(review): presumably no real Azure call succeeds; confirm on hosts with az CLI creds.
+	b := newTestProvider(t, cloudazure.LocationEastUS, "sub")
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	// interactive=false: no valid creds → error.
+	if err := b.Authenticate(ctx, false); err == nil {
+		t.Error("Authenticate(interactive=false) should fail without working creds")
+	}
+}
+
+func TestDeployMissingCDImage(t *testing.T) {
+	b := newTestProvider(t, cloudazure.LocationEastUS, "sub")
+	// CDImage is empty by default.
+	if _, err := b.Deploy(context.Background(), &client.DeployRequest{}); err == nil {
+		t.Error("Deploy should fail without CDImage")
+	}
+	if _, err := b.Preview(context.Background(), &client.DeployRequest{}); err == nil {
+		t.Error("Preview should fail without CDImage")
+	}
+}
+
+func TestSetUpJobMissingCDImage(t *testing.T) {
+	b := newTestProvider(t, cloudazure.LocationEastUS, "sub")
+	if err := b.setUpJob(context.Background(), nil); err == nil {
+		t.Error("setUpJob should fail without CDImage")
+	}
+}
+
+func TestSetUpMissingLocation(t *testing.T) {
+	// Clear AZURE_LOCATION so setUp's setUpLocation step fails early.
+ t.Setenv("AZURE_LOCATION", "") + t.Setenv("AZURE_SUBSCRIPTION_ID", "") + b := NewByocProvider(context.Background(), "t", "s") + if err := b.setUp(context.Background()); err == nil { + t.Error("setUp should fail without AZURE_LOCATION") + } + // Same for setUpForConfig. + if err := b.setUpForConfig(context.Background(), "proj"); err == nil { + t.Error("setUpForConfig should fail without AZURE_LOCATION") + } + // CreateUploadURL and CdList go through setUp, so they should also fail. + if _, err := b.CreateUploadURL(context.Background(), &defangv1.UploadURLRequest{Digest: "d"}); err == nil { + t.Error("CreateUploadURL should fail without AZURE_LOCATION") + } + if _, err := b.CdList(context.Background(), false); err == nil { + t.Error("CdList should fail without AZURE_LOCATION") + } + // GetProjectUpdate with empty project bails early. + if _, err := b.GetProjectUpdate(context.Background(), ""); err == nil { + t.Error("GetProjectUpdate should fail with empty project name") + } + // DeleteConfig and ListConfig also go through setUpForConfig. 
+ if err := b.DeleteConfig(context.Background(), &defangv1.Secrets{Project: "p"}); err == nil { + t.Error("DeleteConfig should fail without AZURE_LOCATION") + } + if _, err := b.ListConfig(context.Background(), &defangv1.ListConfigsRequest{Project: "p"}); err == nil { + t.Error("ListConfig should fail without AZURE_LOCATION") + } + if err := b.PutConfig(context.Background(), &defangv1.PutConfigRequest{Project: "p", Name: "n", Value: "v"}); err == nil { + t.Error("PutConfig should fail without AZURE_LOCATION") + } +} + +func TestCdCommandCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := b.CdCommand(ctx, client.CdCommandRequest{Project: "p", Command: "up"}); err == nil { + t.Error("CdCommand should fail when ARM calls fail") + } +} + +func TestPutConfigCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := b.PutConfig(ctx, &defangv1.PutConfigRequest{Project: "p", Name: "n", Value: "v"}); err == nil { + t.Error("PutConfig should fail when ARM calls fail") + } +} + +func TestCreateUploadURLSubset(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := b.CreateUploadURL(ctx, &defangv1.UploadURLRequest{Digest: "d"}); err == nil { + t.Error("CreateUploadURL should fail when ARM calls fail") + } +} + +func TestDeployInvalidCompose(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + b.CDImage = "img" + // An invalid compose payload should fail to load. 
+ req := &client.DeployRequest{} + req.Compose = []byte("not valid yaml: [") + if _, err := b.Deploy(context.Background(), req); err == nil { + t.Error("Deploy should fail with invalid compose") + } +} + +func TestTearDownCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + if err := b.TearDown(context.Background()); err == nil { + t.Error("TearDown should surface credential error") + } +} + +func TestQueryLogsNonFollow(t *testing.T) { + useFakeCred(t, "tok", nil) + b := newTestProvider(t, cloudazure.LocationEastUS, "sub") + b.cdRunID = "run-1" + b.cdEtag = "etag" + + // ReadJobLogs calls Log Analytics workspace SDK client, which will fail + // without real Azure access. We just want the non-follow path to return + // an error (not panic). + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err := b.QueryLogs(ctx, &defangv1.TailRequest{Etag: "etag", Follow: false}) + if err == nil { + t.Error("QueryLogs non-follow should fail without real Azure workspace") + } +} + +func TestCdCommandMissingCDImage(t *testing.T) { + // setUpForConfig needs to succeed for CdCommand to reach the CDImage check. + // Easier: exercise the setUpLocation-fails path which happens first. 
+ t.Setenv("AZURE_LOCATION", "") + t.Setenv("AZURE_SUBSCRIPTION_ID", "") + b := NewByocProvider(context.Background(), "t", "s") + if _, err := b.CdCommand(context.Background(), client.CdCommandRequest{Project: "p", Command: "up"}); err == nil { + t.Error("CdCommand should fail without AZURE_LOCATION") + } +} + +func TestBuildCdEnv(t *testing.T) { + b := newTestProvider(t, cloudazure.LocationWestUS2, "sub-1") + if err := b.setUpLocation(); err != nil { + t.Fatalf("setUpLocation: %v", err) + } + b.driver.StorageAccount = "acct" + b.driver.BlobContainerName = "uploads" + + env, err := b.buildCdEnv("myproj") + if err != nil { + t.Fatalf("buildCdEnv: %v", err) + } + if got := env["PROJECT"]; got != "myproj" { + t.Errorf("PROJECT = %q", got) + } + if got := env["AZURE_LOCATION"]; got != "westus2" { + t.Errorf("AZURE_LOCATION = %q", got) + } + if got := env["AZURE_SUBSCRIPTION_ID"]; got != "sub-1" { + t.Errorf("AZURE_SUBSCRIPTION_ID = %q", got) + } + // AZURE_RESOURCE_GROUP / AZURE_KEY_VAULT_NAME should NOT be passed — the + // Pulumi provider derives them deterministically from the same inputs. 
+ if _, ok := env["AZURE_RESOURCE_GROUP"]; ok { + t.Errorf("AZURE_RESOURCE_GROUP should not be passed to CD; provider derives it") + } + if _, ok := env["AZURE_KEY_VAULT_NAME"]; ok { + t.Errorf("AZURE_KEY_VAULT_NAME should not be passed to CD; provider derives it") + } + if got := env["STACK"]; got != "test-stack" { + t.Errorf("STACK = %q", got) + } + if got := env["DEFANG_STATE_URL"]; got != "azblob://pulumi?storage_account=acct" { + t.Errorf("DEFANG_STATE_URL = %q", got) + } + if _, ok := env["PULUMI_CONFIG_PASSPHRASE"]; !ok { + t.Error("PULUMI_CONFIG_PASSPHRASE missing") + } +} diff --git a/src/pkg/cli/client/byoc/azure/login.go b/src/pkg/cli/client/byoc/azure/login.go new file mode 100644 index 000000000..e2f1a2996 --- /dev/null +++ b/src/pkg/cli/client/byoc/azure/login.go @@ -0,0 +1,7 @@ +package azure + +import "context" + +func (b *ByocAzure) Authenticate(ctx context.Context, interactive bool) error { + return b.driver.Authenticate(ctx, interactive) +} diff --git a/src/pkg/cli/client/byoc/gcp/byoc.go b/src/pkg/cli/client/byoc/gcp/byoc.go index 94c0e2c1f..4b3c515b1 100644 --- a/src/pkg/cli/client/byoc/gcp/byoc.go +++ b/src/pkg/cli/client/byoc/gcp/byoc.go @@ -412,6 +412,7 @@ type CloudBuildStep struct { } func (b *ByocGcp) runCdCommand(ctx context.Context, cmd cdCommand) (string, error) { + // From https://www.pulumi.com/docs/iac/concepts/state-and-backends/#google-cloud-storage defangStateUrl := `gs://` + b.bucket pulumiBackendKey, pulumiBackendValue, err := byoc.GetPulumiBackend(defangStateUrl) if err != nil { diff --git a/src/pkg/cli/client/provider_id.go b/src/pkg/cli/client/provider_id.go index 75932da49..add91f20f 100644 --- a/src/pkg/cli/client/provider_id.go +++ b/src/pkg/cli/client/provider_id.go @@ -11,11 +11,11 @@ type ProviderID string const ( ProviderAuto ProviderID = "auto" - ProviderDefang ProviderID = "defang" ProviderAWS ProviderID = "aws" + ProviderAzure ProviderID = "azure" + ProviderDefang ProviderID = "defang" ProviderDO ProviderID = 
"digitalocean" ProviderGCP ProviderID = "gcp" - // ProviderAzure ProviderID = "azure" ) var allProviders = []ProviderID{ @@ -24,7 +24,7 @@ var allProviders = []ProviderID{ ProviderAWS, ProviderDO, ProviderGCP, - // ProviderAzure, + ProviderAzure, } func AllProviders() []ProviderID { @@ -39,10 +39,12 @@ func (p ProviderID) Name() string { switch p { case ProviderAuto: return "Auto" - case ProviderDefang: - return "Defang Playground" case ProviderAWS: return "AWS" + case ProviderAzure: + return "Azure" + case ProviderDefang: + return "Defang Playground" case ProviderDO: return "DigitalOcean" case ProviderGCP: @@ -54,10 +56,12 @@ func (p ProviderID) Name() string { func (p ProviderID) Value() defangv1.Provider { switch p { - case ProviderDefang: - return defangv1.Provider_DEFANG case ProviderAWS: return defangv1.Provider_AWS + case ProviderAzure: + return defangv1.Provider_AZURE + case ProviderDefang: + return defangv1.Provider_DEFANG case ProviderDO: return defangv1.Provider_DIGITALOCEAN case ProviderGCP: @@ -80,10 +84,12 @@ func (p *ProviderID) Set(str string) error { func (p *ProviderID) SetValue(val defangv1.Provider) { switch val { - case defangv1.Provider_DEFANG: - *p = ProviderDefang case defangv1.Provider_AWS: *p = ProviderAWS + case defangv1.Provider_AZURE: + *p = ProviderAzure + case defangv1.Provider_DEFANG: + *p = ProviderDefang case defangv1.Provider_DIGITALOCEAN: *p = ProviderDO case defangv1.Provider_GCP: diff --git a/src/pkg/cli/client/region.go b/src/pkg/cli/client/region.go index f2e66d00a..f8152db33 100644 --- a/src/pkg/cli/client/region.go +++ b/src/pkg/cli/client/region.go @@ -5,9 +5,10 @@ import ( ) const ( - RegionDefaultAWS = "us-west-2" - RegionDefaultDO = "nyc3" - RegionDefaultGCP = "us-central1" // Defaults to us-central1 for lower price + RegionDefaultAWS = "us-west-2" + RegionDefaultAzure = "westus" // Default region for Azure + RegionDefaultDO = "nyc3" + RegionDefaultGCP = "us-central1" // Defaults to us-central1 for lower price ) func 
GetRegion(provider ProviderID) string { @@ -15,6 +16,8 @@ func GetRegion(provider ProviderID) string { switch provider { case ProviderAWS: defaultRegion = RegionDefaultAWS + case ProviderAzure: + defaultRegion = RegionDefaultAzure case ProviderGCP: defaultRegion = RegionDefaultGCP case ProviderDO: @@ -33,6 +36,8 @@ func GetRegionVarName(provider ProviderID) string { switch provider { case ProviderAWS: return "AWS_REGION" + case ProviderAzure: + return "AZURE_LOCATION" case ProviderGCP: // Try standard GCP environment variables in order of precedence GCPRegionEnvVar, _ := pkg.GetFirstEnv(pkg.GCPRegionEnvVars...) diff --git a/src/pkg/cli/compose/context.go b/src/pkg/cli/compose/context.go index 40d1106a1..2533f136d 100644 --- a/src/pkg/cli/compose/context.go +++ b/src/pkg/cli/compose/context.go @@ -259,20 +259,22 @@ func uploadArchive(ctx context.Context, provider client.Provider, projectName st } // Do an HTTP PUT to the generated URL - resp, err := http.Put(ctx, res.Url, string(archiveType.MimeType), body) + header := http.Header{"Content-Type": []string{string(archiveType.MimeType)}} + header.Set("X-Ms-Blob-Type", "BlockBlob") // HACK: move to Azure provider + resp, err := http.PutWithHeader(ctx, res.Url, header, body) if err != nil { return "", err } defer resp.Body.Close() - if resp.StatusCode != 200 { + if resp.StatusCode != 200 && resp.StatusCode != 201 { return "", fmt.Errorf("HTTP PUT failed with status code %v", resp.Status) } - url := http.RemoveQueryParam(res.Url) - const gcpPrefix = "https://storage.googleapis.com/" - if strings.HasPrefix(url, gcpPrefix) { - url = "gs://" + url[len(gcpPrefix):] - } + url := http.RemoveQueryParam(res.Url) // remove any access signature + + const gcpPrefix = "https://storage.googleapis.com/" // HACK: move to GCP provider + url = strings.Replace(url, gcpPrefix, "gs://", 1) + return url, nil } diff --git a/src/pkg/cli/connect.go b/src/pkg/cli/connect.go index e46ebcf77..6336f2e90 100644 --- a/src/pkg/cli/connect.go +++ 
b/src/pkg/cli/connect.go @@ -5,6 +5,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" + "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/azure" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/do" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/gcp" "github.com/DefangLabs/defang/src/pkg/term" @@ -43,6 +44,8 @@ func NewProvider(ctx context.Context, providerID client.ProviderID, fabricClient provider = do.NewByocProvider(ctx, fabricClient.GetTenantName(), stack) case client.ProviderGCP: provider = gcp.NewByocProvider(ctx, fabricClient.GetTenantName(), stack) + case client.ProviderAzure: + provider = azure.NewByocProvider(ctx, fabricClient.GetTenantName(), stack) default: provider = client.NewPlaygroundProvider(fabricClient, stack) } diff --git a/src/pkg/cli/tailAndMonitor.go b/src/pkg/cli/tailAndMonitor.go index 183136d8f..d28fe9030 100644 --- a/src/pkg/cli/tailAndMonitor.go +++ b/src/pkg/cli/tailAndMonitor.go @@ -64,7 +64,7 @@ func TailAndMonitor(ctx context.Context, project *compose.Project, provider clie go func() { wg.Wait() - pkg.SleepWithContext(ctx, 2*time.Second) // a delay before cancelling tail to make sure we get last status messages + pkg.SleepWithContext(ctx, 6*time.Second) // a delay before cancelling tail to make sure we get last log messages cancelTail(errMonitoringDone) // cancel the tail when both goroutines are done }() diff --git a/src/pkg/clouds/azure/aca/common.go b/src/pkg/clouds/azure/aca/common.go new file mode 100644 index 000000000..128119e15 --- /dev/null +++ b/src/pkg/clouds/azure/aca/common.go @@ -0,0 +1,150 @@ +package aca + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers" + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +const apiVersion = "2023-05-01" + +type ContainerApp struct { + cloudazure.Azure + 
ResourceGroup string +} + +func (c *ContainerApp) newContainerAppsClient() (*armappcontainers.ContainerAppsClient, error) { + cred, err := c.NewCreds() + if err != nil { + return nil, err + } + return armappcontainers.NewContainerAppsClient(c.SubscriptionID, cred, nil) +} + +func (c *ContainerApp) newReplicasClient() (*armappcontainers.ContainerAppsRevisionReplicasClient, error) { + cred, err := c.NewCreds() + if err != nil { + return nil, err + } + return armappcontainers.NewContainerAppsRevisionReplicasClient(c.SubscriptionID, cred, nil) +} + +// getAuthToken fetches a short-lived token for the Container Apps log-stream endpoint. +// This operation is not yet exposed in the ARM Go SDK, so we call the REST API directly. +func (c *ContainerApp) getAuthToken(ctx context.Context, appName string) (string, error) { + return c.FetchLogStreamAuthToken(ctx, c.ResourceGroup, "Microsoft.App/containerApps/"+appName, apiVersion) +} + +// getEventStreamBase returns the host portion of the container app's eventStreamEndpoint +// (everything before "/subscriptions/"). This is not in SDK v1.1.0, so we call the REST API directly. 
+func (c *ContainerApp) getEventStreamBase(ctx context.Context, appName string) (string, error) { + armTok, err := c.ArmToken(ctx) + if err != nil { + return "", err + } + + url := fmt.Sprintf( + "%s/subscriptions/%s/resourceGroups/%s/providers/Microsoft.App/containerApps/%s?api-version=%s", + cloudazure.ManagementEndpoint, c.SubscriptionID, c.ResourceGroup, appName, apiVersion, + ) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+armTok) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("getContainerApp: HTTP %s", resp.Status) + } + + var result struct { + Properties struct { + EventStreamEndpoint string `json:"eventStreamEndpoint"` + } `json:"properties"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("getContainerApp: decode response: %w", err) + } + endpoint := result.Properties.EventStreamEndpoint + idx := strings.Index(endpoint, "/subscriptions/") + if idx < 0 { + return "", fmt.Errorf("unexpected eventStreamEndpoint format: %q", endpoint) + } + return endpoint[:idx], nil +} + +// ResolveLogTarget resolves the latest active revision, first replica, and first container +// name for the given app. Any of the return values that were already provided as non-empty +// strings are passed through unchanged. 
+func (c *ContainerApp) ResolveLogTarget(ctx context.Context, appName, revision, replica, container string) (string, string, string, error) { + if revision == "" { + appsClient, err := c.newContainerAppsClient() + if err != nil { + return "", "", "", err + } + app, err := appsClient.Get(ctx, c.ResourceGroup, appName, nil) + if err != nil { + return "", "", "", fmt.Errorf("get container app: %w", err) + } + if app.Properties == nil || app.Properties.LatestRevisionName == nil { + return "", "", "", fmt.Errorf("container app %q has no active revision", appName) + } + revision = *app.Properties.LatestRevisionName + + // Opportunistically pick the container name from the app template. + if container == "" && app.Properties.Template != nil { + for _, ctr := range app.Properties.Template.Containers { + if ctr != nil && ctr.Name != nil { + container = *ctr.Name + break + } + } + } + } + + if replica == "" { + replicasClient, err := c.newReplicasClient() + if err != nil { + return "", "", "", err + } + list, err := replicasClient.ListReplicas(ctx, c.ResourceGroup, appName, revision, nil) + if err != nil { + return "", "", "", fmt.Errorf("list replicas: %w", err) + } + if len(list.Value) == 0 || list.Value[0] == nil { + return "", "", "", fmt.Errorf("no replicas found for revision %q", revision) + } + rep := list.Value[0] + if rep.Name == nil { + return "", "", "", errors.New("replica has no name") + } + replica = *rep.Name + + // Opportunistically pick the container from the replica if still unset. 
+ if container == "" && rep.Properties != nil { + for _, ctr := range rep.Properties.Containers { + if ctr != nil && ctr.Name != nil { + container = *ctr.Name + break + } + } + } + } + + if container == "" { + return "", "", "", fmt.Errorf("could not determine container name for app %q", appName) + } + + return revision, replica, container, nil +} diff --git a/src/pkg/clouds/azure/aca/common_test.go b/src/pkg/clouds/azure/aca/common_test.go new file mode 100644 index 000000000..0f5260332 --- /dev/null +++ b/src/pkg/clouds/azure/aca/common_test.go @@ -0,0 +1,127 @@ +package aca + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +func newTestContainerApp() *ContainerApp { + return &ContainerApp{ + Azure: cloudazure.Azure{ + SubscriptionID: "sub", + Location: cloudazure.LocationWestUS2, + }, + ResourceGroup: "rg", + } +} + +func TestContainerAppGetAuthToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "Microsoft.App/containerApps/myapp/getAuthToken") { + t.Errorf("path = %q", r.URL.Path) + } + _, _ = w.Write([]byte(`{"properties":{"token":"ca-tok"}}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + c := newTestContainerApp() + tok, err := c.getAuthToken(context.Background(), "myapp") + if err != nil { + t.Fatalf("getAuthToken: %v", err) + } + if tok != "ca-tok" { + t.Errorf("token = %q", tok) + } +} + +func TestContainerAppGetEventStreamBase(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"properties":{"eventStreamEndpoint":"https://westus2.ms.management.azure.com/subscriptions/x/foo"}}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + c := newTestContainerApp() + base, err := 
c.getEventStreamBase(context.Background(), "myapp") + if err != nil { + t.Fatalf("getEventStreamBase: %v", err) + } + if base != "https://westus2.ms.management.azure.com" { + t.Errorf("base = %q", base) + } +} + +func TestContainerAppGetEventStreamBaseMalformed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"properties":{"eventStreamEndpoint":"no-subscription-path"}}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + c := newTestContainerApp() + if _, err := c.getEventStreamBase(context.Background(), "myapp"); err == nil { + t.Error("expected error for malformed eventStreamEndpoint") + } +} + +func TestContainerAppGetEventStreamBaseHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + c := newTestContainerApp() + if _, err := c.getEventStreamBase(context.Background(), "myapp"); err == nil { + t.Error("expected HTTP error") + } +} + +func TestContainerAppGetEventStreamBaseBadJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`not-json`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + c := newTestContainerApp() + if _, err := c.getEventStreamBase(context.Background(), "myapp"); err == nil { + t.Error("expected decode error") + } +} + +func TestContainerAppGetEventStreamBaseArmError(t *testing.T) { + useFakeCred(t, "", errors.New("no arm token")) + c := newTestContainerApp() + if _, err := c.getEventStreamBase(context.Background(), "myapp"); err == nil { + t.Error("expected ArmToken error") + } +} + +func TestContainerAppNewClients(t *testing.T) { + useFakeCred(t, "tok", nil) + c := 
newTestContainerApp() + if cli, err := c.newContainerAppsClient(); err != nil || cli == nil { + t.Errorf("newContainerAppsClient: %v, client=%v", err, cli) + } + if cli, err := c.newReplicasClient(); err != nil || cli == nil { + t.Errorf("newReplicasClient: %v, client=%v", err, cli) + } +} diff --git a/src/pkg/clouds/azure/aca/job.go b/src/pkg/clouds/azure/aca/job.go new file mode 100644 index 000000000..c1d2239fc --- /dev/null +++ b/src/pkg/clouds/azure/aca/job.go @@ -0,0 +1,893 @@ +package aca + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "iter" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + armappcontainersv3 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" + "github.com/DefangLabs/defang/src/pkg/term" + "github.com/google/uuid" +) + +const ( + cdJobName = "defang-cd" + cdEnvironmentName = "defang-cd" + cdLogWorkspaceName = "defang-cd" + jobLogPollInterval = 3 * time.Second + // jobAPIVersion is required for job getAuthToken / executions/replicas which + // are not yet exposed in the stable SDK or the 2023-05-01 API version. + jobAPIVersion = "2024-02-02-preview" + // cdJobCPU and cdJobMemory size the CD container so Pulumi has enough room to + // run. The Consumption profile (the only profile available in the default + // environment we provision) caps a single replica at 2 vCPU / 4 GiB, so this + // is the largest supported pair. The default (0.25 vCPU / 0.5 GiB) was far + // too small for Pulumi previews and noticeably slowed CD runs. 
+ cdJobCPU = 2.0 + cdJobMemory = "4Gi" +) + +// logAnalyticsEndpoint is the base URL for the Log Analytics query API, overridable for tests. +var logAnalyticsEndpoint = "https://api.loganalytics.io" + +// JobRequest contains parameters for starting a Container Apps Job execution. +type JobRequest struct { + Image string + Command []string + // Envs are plain-text environment variables. + Envs map[string]string + // SecretEnvs are environment variables that should be stored as secrets (not shown in plain text). + SecretEnvs map[string]string + // Timeout is the maximum execution duration. + Timeout time.Duration +} + +// JobStatus represents the status of a Container Apps Job execution. +type JobStatus struct { + ExecutionName string + Status armappcontainersv3.JobExecutionRunningState + ErrorMessage string +} + +// IsTerminal returns true if the execution has reached a final state. +func (s *JobStatus) IsTerminal() bool { + switch s.Status { + case armappcontainersv3.JobExecutionRunningStateSucceeded, + armappcontainersv3.JobExecutionRunningStateFailed, + armappcontainersv3.JobExecutionRunningStateStopped, + armappcontainersv3.JobExecutionRunningStateDegraded: + return true + } + return false +} + +// IsSuccess returns true if the execution completed successfully. +func (s *JobStatus) IsSuccess() bool { + return s.Status == armappcontainersv3.JobExecutionRunningStateSucceeded +} + +// Job manages Container Apps Jobs and the environment they run in. +// It owns the CD job lifecycle: creating the environment, setting up the job, +// running executions, streaming logs, and assigning the job's managed identity roles. 
+type Job struct { + azure.Azure + ResourceGroup string + EnvironmentID string + SystemPrincipalID string + cdJobImage string + identitySetUp bool +} + +func (j *Job) newManagedEnvironmentsClient() (*armappcontainersv3.ManagedEnvironmentsClient, error) { + cred, err := j.NewCreds() + if err != nil { + return nil, err + } + client, err := armappcontainersv3.NewManagedEnvironmentsClient(j.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create managed environments client: %w", err) + } + return client, nil +} + +func (j *Job) newJobsClient() (*armappcontainersv3.JobsClient, error) { + cred, err := j.NewCreds() + if err != nil { + return nil, err + } + client, err := armappcontainersv3.NewJobsClient(j.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create jobs client: %w", err) + } + return client, nil +} + +func (j *Job) newJobsExecutionsClient() (*armappcontainersv3.JobsExecutionsClient, error) { + cred, err := j.NewCreds() + if err != nil { + return nil, err + } + client, err := armappcontainersv3.NewJobsExecutionsClient(j.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create jobs executions client: %w", err) + } + return client, nil +} + +// setUpLogWorkspace creates (or retrieves) a Log Analytics workspace and returns its +// customer ID (workspace GUID) and primary shared key, which are needed to configure +// ACA environment log streaming. +func (j *Job) setUpLogWorkspace(ctx context.Context) (customerID, sharedKey string, err error) { + cred, err := j.NewCreds() + if err != nil { + return "", "", err + } + wsClient, err := armoperationalinsights.NewWorkspacesClient(j.SubscriptionID, cred, nil) + if err != nil { + return "", "", fmt.Errorf("failed to create log analytics workspaces client: %w", err) + } + + // Create or update the workspace (idempotent). 
+ term.Debugf("setUpLogWorkspace: creating/updating workspace %q in %q", cdLogWorkspaceName, j.ResourceGroup) + wsPoller, err := wsClient.BeginCreateOrUpdate(ctx, j.ResourceGroup, cdLogWorkspaceName, armoperationalinsights.Workspace{ + Location: j.Location.Ptr(), + Properties: &armoperationalinsights.WorkspaceProperties{ + SKU: &armoperationalinsights.WorkspaceSKU{ + Name: to.Ptr(armoperationalinsights.WorkspaceSKUNameEnumPerGB2018), + }, + RetentionInDays: to.Ptr(int32(30)), + }, + }, nil) + if err != nil { + return "", "", fmt.Errorf("failed to create log analytics workspace: %w", err) + } + wsResult, err := wsPoller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return "", "", fmt.Errorf("failed to poll workspace creation: %w", err) + } + if wsResult.Properties == nil || wsResult.Properties.CustomerID == nil { + return "", "", errors.New("log analytics workspace did not return a customer ID") + } + customerID = *wsResult.Properties.CustomerID + + // Fetch the shared key (not available on the workspace resource itself). + keysClient, err := armoperationalinsights.NewSharedKeysClient(j.SubscriptionID, cred, nil) + if err != nil { + return "", "", fmt.Errorf("failed to create shared keys client: %w", err) + } + keysResp, err := keysClient.GetSharedKeys(ctx, j.ResourceGroup, cdLogWorkspaceName, nil) + if err != nil { + return "", "", fmt.Errorf("failed to get workspace shared keys: %w", err) + } + if keysResp.PrimarySharedKey == nil { + return "", "", errors.New("log analytics workspace returned no primary shared key") + } + return customerID, *keysResp.PrimarySharedKey, nil +} + +// SetUpEnvironment creates (or retrieves) the Container Apps Environment that hosts the CD job. +// It also creates a Log Analytics workspace and configures the environment to stream logs +// there so they're visible in the Azure portal and via Log Analytics queries. +// The environment resource ID is stored in j.EnvironmentID. 
+func (j *Job) SetUpEnvironment(ctx context.Context) error { + if j.EnvironmentID != "" { + term.Debugf("SetUpEnvironment: already set (%s)", j.EnvironmentID) + return nil + } + + // Set up Log Analytics workspace first so we can wire it into the environment. + customerID, sharedKey, err := j.setUpLogWorkspace(ctx) + if err != nil { + return err + } + + envClient, err := j.newManagedEnvironmentsClient() + if err != nil { + return err + } + + appLogsConfig := &armappcontainersv3.AppLogsConfiguration{ + Destination: to.Ptr("log-analytics"), + LogAnalyticsConfiguration: &armappcontainersv3.LogAnalyticsConfiguration{ + CustomerID: to.Ptr(customerID), + SharedKey: to.Ptr(sharedKey), + }, + } + + term.Debugf("SetUpEnvironment: checking if %q exists in %q", cdEnvironmentName, j.ResourceGroup) + if resp, err := envClient.Get(ctx, j.ResourceGroup, cdEnvironmentName, nil); err == nil { + // Environment exists. Ensure its AppLogsConfiguration points to our workspace + // (idempotent update — safe to run on every call). 
+ term.Debugf("SetUpEnvironment: updating existing environment %s to use Log Analytics", *resp.ID) + updatePoller, err := envClient.BeginCreateOrUpdate(ctx, j.ResourceGroup, cdEnvironmentName, armappcontainersv3.ManagedEnvironment{ + Location: j.Location.Ptr(), + Properties: &armappcontainersv3.ManagedEnvironmentProperties{ + ZoneRedundant: to.Ptr(false), + AppLogsConfiguration: appLogsConfig, + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to update container apps environment: %w", err) + } + result, err := updatePoller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return fmt.Errorf("failed to poll environment update: %w", err) + } + j.EnvironmentID = *result.ID + return nil + } + + term.Infof("Creating Container Apps environment %q in %q", cdEnvironmentName, j.ResourceGroup) + poller, err := envClient.BeginCreateOrUpdate(ctx, j.ResourceGroup, cdEnvironmentName, armappcontainersv3.ManagedEnvironment{ + Location: j.Location.Ptr(), + Properties: &armappcontainersv3.ManagedEnvironmentProperties{ + ZoneRedundant: to.Ptr(false), + AppLogsConfiguration: appLogsConfig, + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create container apps environment: %w", err) + } + result, err := poller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return fmt.Errorf("failed to poll environment creation: %w", err) + } + j.EnvironmentID = *result.ID + term.Infof("Created Container Apps environment %s", j.EnvironmentID) + return nil +} + +// Well-known Azure built-in role definition IDs. +const ( + storageBlobDataContributorRoleID = "ba92f5b4-2d11-453d-a403-e96b0029c9fe" // nolint:gosec + contributorRoleID = "b24988ac-6180-42a0-ab88-20f7382dd24c" // nolint:gosec + userAccessAdministratorRoleID = "18d7d88d-d35e-4fb5-a5c3-7773c20a72d9" // nolint:gosec + keyVaultSecretsUserRoleID = "4633458b-17de-408a-b874-0445c86b69e6" // nolint:gosec +) + +// assignRole assigns a built-in role to the given principal at the given scope. 
+// It silently ignores RoleAssignmentExists errors (idempotent). +func assignRole(ctx context.Context, raClient *armauthorization.RoleAssignmentsClient, subscriptionID, scope, roleDefID, principalID string) error { + fullRoleDefID := fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", subscriptionID, roleDefID) + _, err := raClient.Create(ctx, scope, uuid.NewString(), armauthorization.RoleAssignmentCreateParameters{ + Properties: &armauthorization.RoleAssignmentProperties{ + PrincipalID: to.Ptr(principalID), + RoleDefinitionID: to.Ptr(fullRoleDefID), + PrincipalType: to.Ptr(armauthorization.PrincipalTypeServicePrincipal), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) || respErr.ErrorCode != "RoleAssignmentExists" { + return err + } + } + return nil +} + +// SetUpManagedIdentity assigns the necessary roles to the CD job's system-assigned managed +// identity so it can provision Azure resources and access Pulumi state in storageAccount. +// SetUpJob must be called before this to populate SystemPrincipalID. +func (j *Job) SetUpManagedIdentity(ctx context.Context, storageAccount string) error { + if j.identitySetUp { + return nil + } + if j.SystemPrincipalID == "" { + return errors.New("CD job system-assigned identity principal ID is not set; ensure SetUpJob was called first") + } + + cred, err := j.NewCreds() + if err != nil { + return err + } + raClient, err := armauthorization.NewRoleAssignmentsClient(j.SubscriptionID, cred, nil) + if err != nil { + return fmt.Errorf("failed to create role assignments client: %w", err) + } + + // Contributor + User Access Administrator on the subscription so Pulumi can provision any + // Azure resource and create role assignments (e.g. ACR pull role for Container Apps). 
+ subscriptionScope := "/subscriptions/" + j.SubscriptionID + if err := assignRole(ctx, raClient, j.SubscriptionID, subscriptionScope, contributorRoleID, j.SystemPrincipalID); err != nil { + return fmt.Errorf("failed to assign Contributor role: %w", err) + } + if err := assignRole(ctx, raClient, j.SubscriptionID, subscriptionScope, userAccessAdministratorRoleID, j.SystemPrincipalID); err != nil { + return fmt.Errorf("failed to assign User Access Administrator role: %w", err) + } + // Key Vault Secrets User on the subscription so the CD container can read project + // secrets from the Key Vault (used both by the CD itself and by Pulumi). + if err := assignRole(ctx, raClient, j.SubscriptionID, subscriptionScope, keyVaultSecretsUserRoleID, j.SystemPrincipalID); err != nil { + return fmt.Errorf("failed to assign Key Vault Secrets User role: %w", err) + } + + // Storage Blob Data Contributor on the storage account for Pulumi state and payload access. + storageScope := fmt.Sprintf( + "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Storage/storageAccounts/%s", + j.SubscriptionID, j.ResourceGroup, storageAccount, + ) + if err := assignRole(ctx, raClient, j.SubscriptionID, storageScope, storageBlobDataContributorRoleID, j.SystemPrincipalID); err != nil { + return fmt.Errorf("failed to assign Storage Blob Data Contributor role: %w", err) + } + + j.identitySetUp = true + return nil +} + +// SetUpJob creates (or updates) the Container Apps Job used to run the CD container. +// Environment variables are baked into the job template so they're available to every +// execution (the execution-time override for env vars is unreliable — the job template is +// the authoritative source). +// The CD image is pulled anonymously; the image's registry must allow anonymous pull. +// It enables a system-assigned managed identity on the job and stores the principal ID +// in j.SystemPrincipalID for subsequent role assignments. +// SetUpEnvironment must be called first. 
+func (j *Job) SetUpJob(ctx context.Context, image string, envMap map[string]string) error { + if j.EnvironmentID == "" { + return errors.New("environment ID is not set; ensure SetUpEnvironment was called first") + } + + term.Debugf("SetUpJob: creating/updating job %q with image %q (%d env vars)", cdJobName, image, len(envMap)) + jobsClient, err := j.newJobsClient() + if err != nil { + return err + } + + var envVars []*armappcontainersv3.EnvironmentVar + for k, v := range envMap { + envVars = append(envVars, &armappcontainersv3.EnvironmentVar{ + Name: to.Ptr(k), + Value: to.Ptr(v), + }) + } + + timeout := int32((30 * time.Minute).Seconds()) + const tmpVolumeName = "tmp" + poller, err := jobsClient.BeginCreateOrUpdate(ctx, j.ResourceGroup, cdJobName, armappcontainersv3.Job{ + Location: j.Location.Ptr(), + Identity: &armappcontainersv3.ManagedServiceIdentity{ + Type: to.Ptr(armappcontainersv3.ManagedServiceIdentityTypeSystemAssigned), + }, + Properties: &armappcontainersv3.JobProperties{ + EnvironmentID: to.Ptr(j.EnvironmentID), + Configuration: &armappcontainersv3.JobConfiguration{ + TriggerType: to.Ptr(armappcontainersv3.TriggerTypeManual), + ReplicaTimeout: to.Ptr(timeout), + ReplicaRetryLimit: to.Ptr(int32(0)), + }, + Template: &armappcontainersv3.JobTemplate{ + Volumes: []*armappcontainersv3.Volume{ + { + Name: to.Ptr(tmpVolumeName), + StorageType: to.Ptr(armappcontainersv3.StorageTypeEmptyDir), + }, + }, + Containers: []*armappcontainersv3.Container{ + { + Name: to.Ptr(cdJobName), + Image: to.Ptr(image), + Env: envVars, + Resources: &armappcontainersv3.ContainerResources{ + CPU: to.Ptr(cdJobCPU), + Memory: to.Ptr(cdJobMemory), + }, + VolumeMounts: []*armappcontainersv3.VolumeMount{ + { + VolumeName: to.Ptr(tmpVolumeName), + MountPath: to.Ptr("/tmp"), + }, + }, + }, + }, + }, + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create/update CD job: %w", err) + } + + result, err := poller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + 
return fmt.Errorf("failed to poll CD job creation: %w", err) + } + + if result.Identity != nil && result.Identity.PrincipalID != nil { + j.SystemPrincipalID = *result.Identity.PrincipalID + } + j.cdJobImage = image + return nil +} + +// StartJobExecution starts a new execution of the CD job with the given image, command, +// and environment variables. Returns the execution name. +func (j *Job) StartJobExecution(ctx context.Context, req JobRequest) (string, error) { + jobsClient, err := j.newJobsClient() + if err != nil { + return "", err + } + + // Build environment variable list. Secrets are stored on the job and referenced by name. + var envVars []*armappcontainersv3.EnvironmentVar + var secrets []*armappcontainersv3.Secret + for k, v := range req.Envs { + envVars = append(envVars, &armappcontainersv3.EnvironmentVar{ + Name: to.Ptr(k), + Value: to.Ptr(v), + }) + } + for k, v := range req.SecretEnvs { + secretName := strings.ToLower(strings.ReplaceAll(k, "_", "-")) + secrets = append(secrets, &armappcontainersv3.Secret{ + Name: to.Ptr(secretName), + Value: to.Ptr(v), + }) + envVars = append(envVars, &armappcontainersv3.EnvironmentVar{ + Name: to.Ptr(k), + SecretRef: to.Ptr(secretName), + }) + } + + // Update job secrets if any were added. + if len(secrets) > 0 { + secretsPoller, err := jobsClient.BeginUpdate(ctx, j.ResourceGroup, cdJobName, armappcontainersv3.JobPatchProperties{ + Properties: &armappcontainersv3.JobPatchPropertiesProperties{ + Configuration: &armappcontainersv3.JobConfiguration{ + TriggerType: to.Ptr(armappcontainersv3.TriggerTypeManual), + ReplicaTimeout: to.Ptr(int32((30 * time.Minute).Seconds())), + Secrets: secrets, + }, + }, + }, nil) + if err != nil { + return "", fmt.Errorf("failed to update job secrets: %w", err) + } + if _, err := secretsPoller.PollUntilDone(ctx, azure.PollOptions); err != nil { + return "", fmt.Errorf("failed to poll job secrets update: %w", err) + } + } + + // Build the command args list. 
+ var args []*string + for _, a := range req.Command[1:] { + args = append(args, to.Ptr(a)) + } + var cmd []*string + if len(req.Command) > 0 { + cmd = []*string{to.Ptr(req.Command[0])} + } + + // Resources must be repeated here: ACA replaces matching containers from the + // execution override rather than merging field-by-field, so omitting Resources + // silently falls back to the platform default (0.25 vCPU / 0.5 GiB) regardless + // of what the job template says. + execContainer := &armappcontainersv3.JobExecutionContainer{ + Name: to.Ptr(cdJobName), + Image: to.Ptr(req.Image), + Env: envVars, + Resources: &armappcontainersv3.ContainerResources{ + CPU: to.Ptr(cdJobCPU), + Memory: to.Ptr(cdJobMemory), + }, + } + if len(cmd) > 0 { + execContainer.Command = cmd + execContainer.Args = args + } + + poller, err := jobsClient.BeginStart(ctx, j.ResourceGroup, cdJobName, &armappcontainersv3.JobsClientBeginStartOptions{ + Template: &armappcontainersv3.JobExecutionTemplate{ + Containers: []*armappcontainersv3.JobExecutionContainer{execContainer}, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to start job execution: %w", err) + } + + result, err := poller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return "", fmt.Errorf("failed to poll job start: %w", err) + } + + if result.Name == nil { + return "", errors.New("job execution started but returned no name") + } + return *result.Name, nil +} + +// GetJobExecutionStatus returns the current status of a job execution by listing executions +// and finding the one with the given name. 
+func (j *Job) GetJobExecutionStatus(ctx context.Context, executionName string) (*JobStatus, error) { + execClient, err := j.newJobsExecutionsClient() + if err != nil { + return nil, err + } + + pager := execClient.NewListPager(j.ResourceGroup, cdJobName, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list job executions: %w", err) + } + for _, exec := range page.Value { + if exec.Name != nil && *exec.Name == executionName { + status := &JobStatus{ExecutionName: executionName} + if exec.Properties != nil && exec.Properties.Status != nil { + status.Status = *exec.Properties.Status + } + return status, nil + } + } + } + return nil, fmt.Errorf("execution %q not found", executionName) +} + +// TailJobLogs streams real-time container logs from the running job execution +// by opening the container's logStreamEndpoint (the same one the Azure portal +// uses). This delivers output within seconds, unlike Log Analytics (ReadJobLogs) +// which typically lags by minutes. If the execution fails, the iterator yields +// a terminal error after the stream closes. For historical queries on older +// executions, use ReadJobLogs. +// +// The stream can drop mid-execution (e.g. when the log endpoint reports a +// transient "Kubernetes error" while the replica is starting), so we reconnect +// until the job reaches a terminal state. +func (j *Job) TailJobLogs(ctx context.Context, executionName string) (iter.Seq2[string, error], error) { + return func(yield func(string, error) bool) { + // 300 on the first successful connect catches output emitted during pod + // startup; 0 on reconnects avoids re-printing lines after a transient drop. + const initialBackfill = 300 + connected := false + + for { + if ctx.Err() != nil { + return + } + + status, err := j.GetJobExecutionStatus(ctx, executionName) + if err == nil && status.IsTerminal() { + // Drain any remaining logs once more and then return. 
If we never + // successfully streamed anything, request backfill; otherwise just + // pick up trailing output we haven't seen. + backfill := 0 + if !connected { + backfill = initialBackfill + } + if logCh, err := j.streamJobExecutionLogs(ctx, executionName, backfill); err == nil { + forwardStream(ctx, logCh, yield) + } + if !status.IsSuccess() { + msg := string(status.Status) + if status.ErrorMessage != "" { + msg += ": " + status.ErrorMessage + } + yield("", fmt.Errorf("CD job %s: %s", executionName, msg)) + } + return + } + + backfill := 0 + if !connected { + backfill = initialBackfill + } + logCh, err := j.streamJobExecutionLogs(ctx, executionName, backfill) + if err != nil { + term.Debugf("TailJobLogs: waiting for replica: %v", err) + select { + case <-ctx.Done(): + return + case <-time.After(jobLogPollInterval): + } + continue + } + gotLines, keepGoing := forwardStream(ctx, logCh, yield) + if !keepGoing { + return + } + // Only mark as connected after actually receiving a line. A stream + // that opens and closes immediately (e.g. transient Kubernetes error) + // shouldn't consume our one-shot backfill budget. + if gotLines { + connected = true + } + // Stream closed — loop back to reconnect or detect terminal state. + } + }, nil +} + +// forwardStream forwards all log lines from ch to yield. Returns (gotLines, keepGoing): +// gotLines is true when at least one message was forwarded, keepGoing is false +// when yield signals an early exit (consumer stopped iterating). 
+func forwardStream(ctx context.Context, ch <-chan LogEntry, yield func(string, error) bool) (bool, bool) { + gotLines := false + for entry := range ch { + if ctx.Err() != nil { + return gotLines, false + } + if entry.Err != nil { + if !yield("", entry.Err) { + return gotLines, false + } + continue + } + gotLines = true + if !yield(entry.Message, nil) { + return gotLines, false + } + } + return gotLines, true +} + +// getJobAuthToken fetches a short-lived bearer token accepted by the job's +// logStreamEndpoint. The token differs from the ARM token and is required even +// though the URL is already scoped to the subscription. +func (j *Job) getJobAuthToken(ctx context.Context) (string, error) { + return j.FetchLogStreamAuthToken(ctx, j.ResourceGroup, "Microsoft.App/jobs/"+cdJobName, jobAPIVersion) +} + +// getCDContainerLogStreamURL lists the execution's replicas and returns the +// logstream URL of the main cdJobName container, only once the container has +// reached a runningState where the log endpoint actually serves output +// (Running or Terminated — anything earlier returns a Kubernetes error). +// Returns an empty string (no error) while the replica is still initialising. 
+func (j *Job) getCDContainerLogStreamURL(ctx context.Context, executionName string) (string, error) { + armTok, err := j.ArmToken(ctx) + if err != nil { + return "", err + } + + url := fmt.Sprintf( + "%s/subscriptions/%s/resourceGroups/%s/providers/Microsoft.App/jobs/%s/executions/%s/replicas?api-version=%s", + azure.ManagementEndpoint, j.SubscriptionID, j.ResourceGroup, cdJobName, executionName, jobAPIVersion, + ) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+armTok) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("listReplicas: HTTP %s", resp.Status) + } + + var result struct { + Value []struct { + Properties struct { + Containers []struct { + Name string `json:"name"` + RunningState string `json:"runningState"` + LogStreamEndpoint string `json:"logStreamEndpoint"` + } `json:"containers"` + } `json:"properties"` + } `json:"value"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("listReplicas: decode: %w", err) + } + + for _, r := range result.Value { + for _, c := range r.Properties.Containers { + if c.Name != cdJobName || c.LogStreamEndpoint == "" { + continue + } + switch c.RunningState { + case "Running", "Terminated": + return c.LogStreamEndpoint, nil + } + } + } + return "", nil +} + +// streamJobExecutionLogs opens the job container's logStreamEndpoint and +// returns a channel that emits log lines until the container exits or ctx is +// cancelled. Returns an error when the replica is not yet available so the +// caller can retry. +// +// backfillLines controls how much of the container's existing log buffer is +// replayed on connect (capped at 300 by the API). 
Use a large value on the +// first connect to capture output emitted during pod startup, and 0 on +// reconnects so we don't re-print lines we already streamed. +func (j *Job) streamJobExecutionLogs(ctx context.Context, executionName string, backfillLines int) (<-chan LogEntry, error) { + streamURL, err := j.getCDContainerLogStreamURL(ctx, executionName) + if err != nil { + return nil, err + } + if streamURL == "" { + return nil, errors.New("no replica container with logStreamEndpoint yet") + } + + authToken, err := j.getJobAuthToken(ctx) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, streamURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + q := req.URL.Query() + q.Set("follow", "true") + q.Set("output", "text") + if backfillLines > 0 { + q.Set("tailLines", strconv.Itoa(backfillLines)) + } + req.URL.RawQuery = q.Encode() + + resp, err := http.DefaultClient.Do(req) // nolint: bodyclose // body is closed in the goroutine + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("logstream: HTTP %s", resp.Status) + } + + ch := make(chan LogEntry) + go func() { + defer close(ch) + defer resp.Body.Close() + scanner := bufio.NewScanner(resp.Body) + // Pulumi output lines can be long (especially Diagnostics blocks). + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + select { + case ch <- LogEntry{Message: line}: + case <-ctx.Done(): + return + } + } + if err := scanner.Err(); err != nil && ctx.Err() == nil { + select { + case ch <- LogEntry{Err: err}: + case <-ctx.Done(): + } + } + }() + return ch, nil +} + +// ReadJobLogs returns all log output captured for a job execution from Log Analytics. +// Subject to a short ingestion delay (seconds to a couple of minutes on cold workspaces). 
+func (j *Job) ReadJobLogs(ctx context.Context, executionName string) (string, error) { + return j.fetchLogsFromWorkspace(ctx, executionName) +} + +// getLogAnalyticsToken returns a Bearer token for the Log Analytics query API. +func (j *Job) getLogAnalyticsToken(ctx context.Context) (string, error) { + cred, err := j.NewCreds() + if err != nil { + return "", err + } + tok, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://api.loganalytics.io/.default"}, + }) + if err != nil { + return "", err + } + return tok.Token, nil +} + +// getLogWorkspaceCustomerID returns the customer ID (GUID) of the CD Log Analytics +// workspace. This is what the Log Analytics query API addresses workspaces by. +func (j *Job) getLogWorkspaceCustomerID(ctx context.Context) (string, error) { + cred, err := j.NewCreds() + if err != nil { + return "", err + } + wsClient, err := armoperationalinsights.NewWorkspacesClient(j.SubscriptionID, cred, nil) + if err != nil { + return "", fmt.Errorf("creating log analytics workspaces client: %w", err) + } + resp, err := wsClient.Get(ctx, j.ResourceGroup, cdLogWorkspaceName, nil) + if err != nil { + return "", fmt.Errorf("getting log analytics workspace: %w", err) + } + if resp.Properties == nil || resp.Properties.CustomerID == nil { + return "", errors.New("log analytics workspace has no customer ID") + } + return *resp.Properties.CustomerID, nil +} + +// fetchLogsFromWorkspace queries Log Analytics for all console log entries belonging to +// the given job execution, ordered by time. Returns empty string when the workspace has +// no rows yet (first-time workspaces can take 2–5 minutes to ingest data). 
+func (j *Job) fetchLogsFromWorkspace(ctx context.Context, executionName string) (string, error) { + workspaceID, err := j.getLogWorkspaceCustomerID(ctx) + if err != nil { + return "", err + } + return j.fetchLogsByWorkspaceID(ctx, workspaceID, executionName) +} + +// fetchLogsByWorkspaceID is the lower half of fetchLogsFromWorkspace, kept separate +// so tests can exercise it with a known workspace ID without needing the SDK +// workspaces client to be mocked. +func (j *Job) fetchLogsByWorkspaceID(ctx context.Context, workspaceID, executionName string) (string, error) { + token, err := j.getLogAnalyticsToken(ctx) + if err != nil { + return "", err + } + + // Filter by pod name (ContainerGroupName_s), which has the form + // "{executionName}-{randomsuffix}" — ContainerJobName_s is always just the job name + // ("defang-cd") so it can't disambiguate executions. Execution names are Azure-generated + // alphanumeric + hyphens, so no quoting hazard inlining them into the query. + query := fmt.Sprintf( + `ContainerAppConsoleLogs_CL `+ + `| where ContainerName_s == "%s" and ContainerGroupName_s startswith "%s-" `+ + `| order by TimeGenerated asc `+ + `| project TimeGenerated, Log_s`, + cdJobName, executionName, + ) + body, err := json.Marshal(map[string]string{"query": query}) + if err != nil { + return "", err + } + + url := logAnalyticsEndpoint + "/v1/workspaces/" + workspaceID + "/query" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("log analytics query: HTTP %s", resp.Status) + } + + var result struct { + Tables []struct { + Rows [][]any `json:"rows"` + } `json:"tables"` + } + if err := 
json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("decode log analytics response: %w", err) + } + + var sb strings.Builder + if len(result.Tables) > 0 { + for _, row := range result.Tables[0].Rows { + if len(row) < 2 { + continue + } + ts, _ := row[0].(string) + line, _ := row[1].(string) + sb.WriteString(ts) + sb.WriteByte(' ') + sb.WriteString(line) + if !strings.HasSuffix(line, "\n") { + sb.WriteByte('\n') + } + } + } + return sb.String(), nil +} diff --git a/src/pkg/clouds/azure/aca/job_test.go b/src/pkg/clouds/azure/aca/job_test.go new file mode 100644 index 000000000..a3b2acc15 --- /dev/null +++ b/src/pkg/clouds/azure/aca/job_test.go @@ -0,0 +1,666 @@ +package aca + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + armappcontainersv3 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3" + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +type fakeCredential struct { + token string + err error +} + +func (f fakeCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + if f.err != nil { + return azcore.AccessToken{}, f.err + } + return azcore.AccessToken{Token: f.token, ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func useFakeCred(t *testing.T, tok string, gerr error) { + t.Helper() + orig := cloudazure.NewCredsFunc + cloudazure.NewCredsFunc = func(_ cloudazure.Azure) (azcore.TokenCredential, error) { + return fakeCredential{token: tok, err: gerr}, nil + } + t.Cleanup(func() { cloudazure.NewCredsFunc = orig }) +} + +func useTestEndpoints(t *testing.T, mgmtURL, logAnalyticsURL string) { + t.Helper() + origMgmt := cloudazure.ManagementEndpoint + origLA := logAnalyticsEndpoint + cloudazure.ManagementEndpoint = mgmtURL + if logAnalyticsURL != "" { + logAnalyticsEndpoint = 
logAnalyticsURL + } + t.Cleanup(func() { + cloudazure.ManagementEndpoint = origMgmt + logAnalyticsEndpoint = origLA + }) +} + +func TestJobStatusIsTerminal(t *testing.T) { + tests := []struct { + state armappcontainersv3.JobExecutionRunningState + want bool + }{ + {armappcontainersv3.JobExecutionRunningStateSucceeded, true}, + {armappcontainersv3.JobExecutionRunningStateFailed, true}, + {armappcontainersv3.JobExecutionRunningStateStopped, true}, + {armappcontainersv3.JobExecutionRunningStateDegraded, true}, + {armappcontainersv3.JobExecutionRunningStateRunning, false}, + {armappcontainersv3.JobExecutionRunningStateProcessing, false}, + } + for _, tt := range tests { + s := &JobStatus{Status: tt.state} + if got := s.IsTerminal(); got != tt.want { + t.Errorf("IsTerminal(%q) = %v, want %v", tt.state, got, tt.want) + } + } +} + +func TestJobStatusIsSuccess(t *testing.T) { + s := &JobStatus{Status: armappcontainersv3.JobExecutionRunningStateSucceeded} + if !s.IsSuccess() { + t.Error("Succeeded state should be success") + } + s.Status = armappcontainersv3.JobExecutionRunningStateFailed + if s.IsSuccess() { + t.Error("Failed state should not be success") + } +} + +func TestForwardStream(t *testing.T) { + ctx := context.Background() + ch := make(chan LogEntry, 3) + ch <- LogEntry{Message: "a"} + ch <- LogEntry{Message: "b"} + ch <- LogEntry{Message: "c"} + close(ch) + + var got []string + gotLines, keepGoing := forwardStream(ctx, ch, func(msg string, err error) bool { + if err != nil { + t.Errorf("unexpected err: %v", err) + } + got = append(got, msg) + return true + }) + if !gotLines { + t.Error("gotLines should be true") + } + if !keepGoing { + t.Error("keepGoing should be true after drain") + } + if len(got) != 3 || got[0] != "a" || got[2] != "c" { + t.Errorf("forwarded = %v", got) + } +} + +func TestForwardStreamEmpty(t *testing.T) { + ctx := context.Background() + ch := make(chan LogEntry) + close(ch) + gotLines, keepGoing := forwardStream(ctx, ch, func(string, error) 
bool { return true }) + if gotLines { + t.Error("gotLines should be false for empty stream") + } + if !keepGoing { + t.Error("keepGoing should be true for empty stream") + } +} + +func TestForwardStreamErrorEntry(t *testing.T) { + ctx := context.Background() + ch := make(chan LogEntry, 2) + ch <- LogEntry{Err: context.Canceled} + ch <- LogEntry{Message: "after err"} + close(ch) + + var sawErr bool + var msgCount int + gotLines, keepGoing := forwardStream(ctx, ch, func(msg string, err error) bool { + if err != nil { + sawErr = true + } else { + msgCount++ + } + return true + }) + if !sawErr { + t.Error("expected error to be forwarded") + } + if msgCount != 1 { + t.Errorf("msgCount = %d, want 1", msgCount) + } + if !gotLines { + t.Error("gotLines should be true (non-error line was forwarded)") + } + if !keepGoing { + t.Error("keepGoing should be true") + } +} + +func TestForwardStreamEarlyExit(t *testing.T) { + ctx := context.Background() + ch := make(chan LogEntry, 3) + ch <- LogEntry{Message: "a"} + ch <- LogEntry{Message: "b"} + ch <- LogEntry{Message: "c"} + close(ch) + + count := 0 + gotLines, keepGoing := forwardStream(ctx, ch, func(string, error) bool { + count++ + return count < 2 // stop after second call + }) + if keepGoing { + t.Error("keepGoing should be false when yield returns false") + } + if !gotLines { + t.Error("gotLines should be true") + } +} + +func TestForwardStreamCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ch := make(chan LogEntry, 1) + ch <- LogEntry{Message: "ignored"} + close(ch) + + _, keepGoing := forwardStream(ctx, ch, func(string, error) bool { return true }) + if keepGoing { + t.Error("keepGoing should be false when context is cancelled") + } +} + +func newTestJob() *Job { + return &Job{ + Azure: cloudazure.Azure{ + SubscriptionID: "sub", + Location: cloudazure.LocationWestUS2, + }, + ResourceGroup: "rg", + } +} + +func TestGetJobAuthToken(t *testing.T) { + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "Microsoft.App/jobs/defang-cd/getAuthToken") { + t.Errorf("path = %q", r.URL.Path) + } + _, _ = w.Write([]byte(`{"properties":{"token":"jwt-here"}}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + tok, err := j.getJobAuthToken(context.Background()) + if err != nil { + t.Fatalf("getJobAuthToken: %v", err) + } + if tok != "jwt-here" { + t.Errorf("token = %q", tok) + } +} + +func TestGetCDContainerLogStreamURLRunning(t *testing.T) { + // replicas response: one container in Running state. + resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Running","logStreamEndpoint":"https://example/stream"} + ]}}]}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, "executions/exec-1/replicas") { + t.Errorf("path = %q", r.URL.Path) + } + _, _ = w.Write([]byte(resp)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + url, err := j.getCDContainerLogStreamURL(context.Background(), "exec-1") + if err != nil { + t.Fatalf("getCDContainerLogStreamURL: %v", err) + } + if url != "https://example/stream" { + t.Errorf("url = %q", url) + } +} + +func TestGetCDContainerLogStreamURLWaiting(t *testing.T) { + // Container is still Waiting — expect empty URL (caller retries). 
+ resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Waiting","logStreamEndpoint":"https://example/stream"} + ]}}]}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(resp)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + url, err := j.getCDContainerLogStreamURL(context.Background(), "exec-1") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if url != "" { + t.Errorf("URL should be empty while container is Waiting, got %q", url) + } +} + +func TestGetCDContainerLogStreamURLMissingContainer(t *testing.T) { + resp := `{"value":[]}` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(resp)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + url, err := j.getCDContainerLogStreamURL(context.Background(), "exec-1") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if url != "" { + t.Errorf("URL should be empty when replicas list is empty") + } +} + +func TestGetCDContainerLogStreamURLHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + if _, err := j.getCDContainerLogStreamURL(context.Background(), "exec-1"); err == nil { + t.Error("expected error for 401") + } +} + +func TestStreamJobExecutionLogsNoReplica(t *testing.T) { + // Empty replicas list — streamJobExecutionLogs should surface an error so + // the caller can retry. 
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"value":[]}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, srv.URL, "") + + j := newTestJob() + if _, err := j.streamJobExecutionLogs(context.Background(), "exec-1", 0); err == nil { + t.Error("expected error when no replica") + } +} + +func TestStreamJobExecutionLogs(t *testing.T) { + streamBody := "first\nsecond\nthird\n" + + // Chain two servers: the first serves replicas + auth token, the second + // serves the logstream. The replicas response points at the stream server. + var streamSrv *httptest.Server + streamSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("follow") != "true" { + t.Errorf("follow = %q, want true", r.URL.Query().Get("follow")) + } + _, _ = w.Write([]byte(streamBody)) + })) + defer streamSrv.Close() + + mgmtSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "getAuthToken"): + _, _ = w.Write([]byte(`{"properties":{"token":"stream-tok"}}`)) + case strings.Contains(r.URL.Path, "replicas"): + resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Running","logStreamEndpoint":"` + streamSrv.URL + `"} + ]}}]}` + _, _ = w.Write([]byte(resp)) + default: + t.Errorf("unexpected path %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer mgmtSrv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, mgmtSrv.URL, "") + + j := newTestJob() + ch, err := j.streamJobExecutionLogs(context.Background(), "exec-1", 0) + if err != nil { + t.Fatalf("streamJobExecutionLogs: %v", err) + } + var got []string + for entry := range ch { + if entry.Err != nil { + t.Errorf("entry err: %v", entry.Err) + continue + } + got = append(got, entry.Message) + } + if len(got) != 3 || got[0] != "first" || got[2] != "third" { + 
t.Errorf("got lines %v", got) + } +} + +func TestStreamJobExecutionLogsBackfill(t *testing.T) { + // Verify that backfillLines > 0 adds tailLines query param. + var gotTail string + streamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotTail = r.URL.Query().Get("tailLines") + _, _ = w.Write([]byte("x\n")) + })) + defer streamSrv.Close() + + mgmtSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "getAuthToken"): + _, _ = w.Write([]byte(`{"properties":{"token":"tok"}}`)) + case strings.Contains(r.URL.Path, "replicas"): + resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Running","logStreamEndpoint":"` + streamSrv.URL + `"} + ]}}]}` + _, _ = w.Write([]byte(resp)) + } + })) + defer mgmtSrv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, mgmtSrv.URL, "") + + j := newTestJob() + ch, err := j.streamJobExecutionLogs(context.Background(), "exec-1", 250) + if err != nil { + t.Fatalf("streamJobExecutionLogs: %v", err) + } + for range ch { + } + if gotTail != "250" { + t.Errorf("tailLines = %q, want 250", gotTail) + } +} + +func TestStreamJobExecutionLogsHTTPFailure(t *testing.T) { + streamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer streamSrv.Close() + + mgmtSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "getAuthToken"): + _, _ = w.Write([]byte(`{"properties":{"token":"t"}}`)) + case strings.Contains(r.URL.Path, "replicas"): + resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Running","logStreamEndpoint":"` + streamSrv.URL + `"} + ]}}]}` + _, _ = w.Write([]byte(resp)) + } + })) + defer mgmtSrv.Close() + + useFakeCred(t, "arm", nil) + useTestEndpoints(t, mgmtSrv.URL, "") + + j 
:= newTestJob() + if _, err := j.streamJobExecutionLogs(context.Background(), "exec-1", 0); err == nil { + t.Error("expected error for 403 from stream endpoint") + } +} + +func TestStreamJobExecutionLogsCredError(t *testing.T) { + mgmtSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := `{"value":[{"properties":{"containers":[ + {"name":"defang-cd","runningState":"Running","logStreamEndpoint":"https://example/x"} + ]}}]}` + _, _ = w.Write([]byte(resp)) + })) + defer mgmtSrv.Close() + + useFakeCred(t, "", errors.New("cred fail")) + useTestEndpoints(t, mgmtSrv.URL, "") + + j := newTestJob() + // First call (replicas list) goes through ArmToken → fails. + if _, err := j.streamJobExecutionLogs(context.Background(), "exec-1", 0); err == nil { + t.Error("expected credential error") + } +} + +func TestReadJobLogs(t *testing.T) { + // Log Analytics query returns two rows of (timestamp, line). + laBody := `{"tables":[{"rows":[ + ["2026-04-17T16:00:00Z","hello"], + ["2026-04-17T16:00:01Z","world"] + ]}]}` + laSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v1/workspaces/") || !strings.HasSuffix(r.URL.Path, "/query") { + t.Errorf("LA path = %q", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("LA method = %s", r.Method) + } + _, _ = w.Write([]byte(laBody)) + })) + defer laSrv.Close() + + // The workspace customerID is fetched from ARM via the SDK's workspaces client, + // which is hard to mock. We bypass by only testing the fetchLogsFromWorkspace + // helper directly with a pre-known workspaceID (via a small thin wrapper). + // The helper is private but we can still drive the Log Analytics endpoint + // through the public ReadJobLogs path with a pre-populated workspace. + // Since we cannot populate the SDK client response, we at least verify the + // LA path by calling fetchLogsFromWorkspace indirectly. 
+ useFakeCred(t, "la-tok", nil) + useTestEndpoints(t, "http://unused", laSrv.URL) + + j := newTestJob() + // Call the low-level fetch function directly to avoid the SDK workspace lookup. + got, err := j.fetchLogsByWorkspaceID(context.Background(), "workspace-guid", "exec-1") + if err != nil { + t.Fatalf("fetchLogsByWorkspaceID: %v", err) + } + if !strings.Contains(got, "hello") || !strings.Contains(got, "world") { + t.Errorf("logs = %q", got) + } +} + +func TestReadJobLogsTokenError(t *testing.T) { + useFakeCred(t, "", errors.New("token denied")) + j := newTestJob() + if _, err := j.fetchLogsByWorkspaceID(context.Background(), "ws", "exec"); err == nil { + t.Error("expected token error") + } +} + +func TestReadJobLogsHTTPError(t *testing.T) { + laSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer laSrv.Close() + + useFakeCred(t, "tok", nil) + useTestEndpoints(t, "http://unused", laSrv.URL) + + j := newTestJob() + if _, err := j.fetchLogsByWorkspaceID(context.Background(), "ws", "exec"); err == nil { + t.Error("expected error for 500") + } +} + +func TestReadJobLogsBadJSON(t *testing.T) { + laSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`not-json`)) + })) + defer laSrv.Close() + + useFakeCred(t, "tok", nil) + useTestEndpoints(t, "http://unused", laSrv.URL) + + j := newTestJob() + if _, err := j.fetchLogsByWorkspaceID(context.Background(), "ws", "exec"); err == nil { + t.Error("expected decode error") + } +} + +func TestNewClients(t *testing.T) { + useFakeCred(t, "tok", nil) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if c, err := j.newManagedEnvironmentsClient(); err != nil || c == nil { + t.Errorf("newManagedEnvironmentsClient: %v, client=%v", err, c) + } + if c, err := j.newJobsClient(); err != nil || c == nil { + t.Errorf("newJobsClient: %v, client=%v", err, 
c) + } + if c, err := j.newJobsExecutionsClient(); err != nil || c == nil { + t.Errorf("newJobsExecutionsClient: %v, client=%v", err, c) + } +} + +func TestSetUpManagedIdentityPreconditions(t *testing.T) { + useFakeCred(t, "tok", nil) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + // SystemPrincipalID is not set. + if err := j.SetUpManagedIdentity(context.Background(), "acct"); err == nil { + t.Error("expected error when SystemPrincipalID is empty") + } + + // idempotent when identitySetUp is true. + j.identitySetUp = true + if err := j.SetUpManagedIdentity(context.Background(), "acct"); err != nil { + t.Errorf("identity already set up should short-circuit, got %v", err) + } +} + +func TestSetUpEnvironmentShortCircuit(t *testing.T) { + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg", EnvironmentID: "/already"} + if err := j.SetUpEnvironment(context.Background()); err != nil { + t.Errorf("SetUpEnvironment should short-circuit when EnvironmentID is set, got %v", err) + } +} + +func TestSetUpJobMissingEnvironment(t *testing.T) { + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if err := j.SetUpJob(context.Background(), "image", nil); err == nil { + t.Error("SetUpJob should fail when EnvironmentID is empty") + } +} + +func TestStartJobExecutionCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, err := j.StartJobExecution(context.Background(), JobRequest{ + Image: "img", + Command: []string{"/bin/true"}, + }); err == nil { + t.Error("expected cred error") + } +} + +func TestTailJobLogsCancelled(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + seq, err := j.TailJobLogs(ctx, "exec-1") + 
if err != nil { + t.Fatalf("TailJobLogs: %v", err) + } + for range seq { + // drain + } +} + +func TestGetJobExecutionStatusCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, err := j.GetJobExecutionStatus(context.Background(), "exec"); err == nil { + t.Error("expected cred error") + } +} + +func TestReadJobLogsCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, err := j.ReadJobLogs(context.Background(), "exec"); err == nil { + t.Error("expected cred error") + } +} + +func TestGetLogAnalyticsTokenCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, err := j.getLogAnalyticsToken(context.Background()); err == nil { + t.Error("expected cred error") + } +} + +func TestGetLogWorkspaceCustomerIDCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, err := j.getLogWorkspaceCustomerID(context.Background()); err == nil { + t.Error("expected cred error") + } +} + +func TestSetUpLogWorkspaceCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + if _, _, err := j.setUpLogWorkspace(context.Background()); err == nil { + t.Error("expected cred error") + } +} + +func TestFetchLogsFromWorkspaceSDKError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + // fetchLogsFromWorkspace first calls getLogWorkspaceCustomerID which uses + // the SDK (will fail), then bails. 
+ if _, err := j.fetchLogsFromWorkspace(context.Background(), "exec"); err == nil { + t.Error("expected error from fetchLogsFromWorkspace") + } +} + +func TestFetchLogsFromWorkspaceViaTailJobLogs(t *testing.T) { + // Exercise the non-error return path in forwardStream by producing one + // terminal status and no logs. + useFakeCred(t, "", errors.New("denied")) + j := &Job{Azure: cloudazure.Azure{SubscriptionID: "sub"}, ResourceGroup: "rg"} + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + seq, err := j.TailJobLogs(ctx, "exec") + if err != nil { + t.Fatalf("TailJobLogs: %v", err) + } + for range seq { + } +} diff --git a/src/pkg/clouds/azure/aca/tail.go b/src/pkg/clouds/azure/aca/tail.go new file mode 100644 index 000000000..d47377517 --- /dev/null +++ b/src/pkg/clouds/azure/aca/tail.go @@ -0,0 +1,198 @@ +package aca + +import ( + "bufio" + "context" + "fmt" + "net/http" + "strconv" + "sync" + "time" +) + +const watchInterval = 5 * time.Second + +type LogEntry struct { + Message string + Err error +} + +// ServiceLogEntry is a LogEntry annotated with the Container App name it came from. +type ServiceLogEntry struct { + AppName string + LogEntry +} + +// WatchLogs polls the resource group for Container Apps every watchInterval and streams +// logs from each one as soon as it is discovered. New apps that appear after the initial +// poll are picked up automatically. +func (c *ContainerApp) WatchLogs(ctx context.Context) <-chan ServiceLogEntry { + out := make(chan ServiceLogEntry) + go func() { + // streaming tracks apps that currently have a live tail goroutine. An + // app is re-added to the map on the next poll once its goroutine exits + // (so replicas that roll or streams that drop mid-run are retried). + var mu sync.Mutex + streaming := map[string]struct{}{} + + // senders tracks inner goroutines that send on `out` (per-app tailers + // and pollErr). 
We must wait for all of them before closing `out`, + // otherwise a `case out <- …` racing with our close panics with + // "send on closed channel". + var senders sync.WaitGroup + defer func() { + senders.Wait() + close(out) + }() + + startTailing := func(appName string) { + senders.Add(1) + go func() { + defer senders.Done() + defer func() { + mu.Lock() + delete(streaming, appName) + mu.Unlock() + }() + appCh, err := c.StreamLogs(ctx, appName, "", "", "", true) + if err != nil { + select { + case out <- ServiceLogEntry{AppName: appName, LogEntry: LogEntry{Err: err}}: + case <-ctx.Done(): + } + return + } + for entry := range appCh { + select { + case out <- ServiceLogEntry{AppName: appName, LogEntry: entry}: + case <-ctx.Done(): + return + } + } + }() + } + + sendErr := func(err error) { + select { + case out <- ServiceLogEntry{LogEntry: LogEntry{Err: err}}: + case <-ctx.Done(): + } + } + + poll := func() { + client, err := c.newContainerAppsClient() + if err != nil { + sendErr(fmt.Errorf("WatchLogs: create container apps client: %w", err)) + return + } + pager := client.NewListByResourceGroupPager(c.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + sendErr(fmt.Errorf("WatchLogs: list container apps: %w", err)) + return + } + for _, app := range page.Value { + if app == nil || app.Name == nil { + continue + } + name := *app.Name + mu.Lock() + if _, active := streaming[name]; active { + mu.Unlock() + continue + } + streaming[name] = struct{}{} + mu.Unlock() + startTailing(name) + } + } + } + + poll() + ticker := time.NewTicker(watchInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + poll() + } + } + }() + return out +} + +// StreamLogs streams real-time logs from a Container App container via Server-Sent Events. +// revision, replica, and container may be empty; they will be resolved to the latest active +// revision, first replica, and first container automatically. 
+// When follow is false, the stream ends when there are no more buffered log lines.
+func (c *ContainerApp) StreamLogs(ctx context.Context, appName, revision, replica, container string, follow bool) (<-chan LogEntry, error) {
+	var err error
+	revision, replica, container, err = c.ResolveLogTarget(ctx, appName, revision, replica, container)
+	if err != nil {
+		return nil, err
+	}
+
+	baseURL, err := c.getEventStreamBase(ctx, appName)
+	if err != nil {
+		return nil, err
+	}
+
+	authToken, err := c.getAuthToken(ctx, appName)
+	if err != nil {
+		return nil, err
+	}
+
+	streamURL := fmt.Sprintf(
+		"%s/subscriptions/%s/resourceGroups/%s/containerApps/%s/revisions/%s/replicas/%s/containers/%s/logstream",
+		baseURL, c.SubscriptionID, c.ResourceGroup, appName, revision, replica, container,
+	)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, streamURL, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+authToken)
+
+	q := req.URL.Query()
+	q.Set("follow", strconv.FormatBool(follow))
+	q.Set("output", "text")
+	req.URL.RawQuery = q.Encode()
+
+	resp, err := http.DefaultClient.Do(req) //nolint:bodyclose // resp.Body is closed by the goroutine below via defer resp.Body.Close()
+	if err != nil {
+		return nil, err
+	}
+	if resp.StatusCode != http.StatusOK {
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("log stream: HTTP %s", resp.Status)
+	}
+
+	ch := make(chan LogEntry)
+	go func() {
+		defer close(ch)
+		defer resp.Body.Close()
+
+		scanner := bufio.NewScanner(resp.Body)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if line == "" {
+				continue
+			}
+			select {
+			case ch <- LogEntry{Message: line}:
+			case <-ctx.Done():
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil && ctx.Err() == nil {
+			select {
+			case ch <- LogEntry{Err: err}:
+			case <-ctx.Done():
+			}
+		}
+	}()
+	return ch, nil
+}
diff --git a/src/pkg/clouds/azure/aca/tail_test.go b/src/pkg/clouds/azure/aca/tail_test.go
new file mode 100644
index 000000000..96e694c10
---
/dev/null +++ b/src/pkg/clouds/azure/aca/tail_test.go @@ -0,0 +1,267 @@ +package aca + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + cloudazure "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +func TestWatchLogsCancelled(t *testing.T) { + useFakeCred(t, "tok", nil) + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + ch := c.WatchLogs(ctx) + for range ch { + // drain + } + // If we got here, WatchLogs properly exits on ctx cancellation. +} + +func TestStreamLogsResolveFailure(t *testing.T) { + // With a fake credential, the SDK construct succeeds but API calls fail. + // StreamLogs should surface the error from ResolveLogTarget. + useFakeCred(t, "tok", nil) + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := c.StreamLogs(ctx, "myapp", "", "", "", true); err == nil { + t.Error("StreamLogs should fail when SDK calls can't reach ARM") + } +} + +func TestResolveLogTargetAllProvided(t *testing.T) { + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + // When every arg is non-empty, ResolveLogTarget returns them as-is without + // any API call. 
+ rev, rep, con, err := c.ResolveLogTarget(context.Background(), "app", "rev1", "rep1", "ctr1") + if err != nil { + t.Fatalf("ResolveLogTarget: %v", err) + } + if rev != "rev1" || rep != "rep1" || con != "ctr1" { + t.Errorf("got (%q, %q, %q)", rev, rep, con) + } +} + +func TestResolveLogTargetMissingContainer(t *testing.T) { + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + // revision + replica provided, container empty — no SDK call, but container + // resolution fails with an error. + if _, _, _, err := c.ResolveLogTarget(context.Background(), "app", "rev", "rep", ""); err == nil { + t.Error("ResolveLogTarget should fail when container can't be determined") + } +} + +func TestResolveLogTargetSDKFailure(t *testing.T) { + useFakeCred(t, "tok", nil) + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + // With no revision provided, ResolveLogTarget calls appsClient.Get which + // will attempt a real HTTP request and fail. + if _, _, _, err := c.ResolveLogTarget(ctx, "app", "", "", ""); err == nil { + t.Error("ResolveLogTarget should error when SDK call fails") + } +} + +func TestStreamLogsCredError(t *testing.T) { + // Credential-layer error during ResolveLogTarget. 
+ orig := cloudazure.NewCredsFunc + cloudazure.NewCredsFunc = func(_ cloudazure.Azure) (azcore.TokenCredential, error) { + return nil, errors.New("cred fail") + } + t.Cleanup(func() { cloudazure.NewCredsFunc = orig }) + + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + if _, err := c.StreamLogs(context.Background(), "app", "", "", "", true); err == nil { + t.Error("StreamLogs should fail when credential resolution fails") + } +} + +func TestStreamLogsMissingRevisionSDKError(t *testing.T) { + // With a working fake cred but no access to ARM, the SDK call in + // ResolveLogTarget fails. + useFakeCred(t, "tok", nil) + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + // Pass revision+replica, container empty — bypasses the Get() SDK call but + // still needs container resolution which fails → fast error, no network. + if _, _, _, err := c.ResolveLogTarget(ctx, "app", "rev", "rep", ""); err == nil { + t.Error("ResolveLogTarget should fail when container cannot be resolved") + } +} + +func TestStreamLogsFullPath(t *testing.T) { + // Serve the Container App GET (for eventStreamEndpoint), getAuthToken, + // and the logstream on the same httptest server. + var eventStreamEndpoint string + mgmt := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/getAuthToken"): + _, _ = w.Write([]byte(`{"properties":{"token":"stream-tok"}}`)) + case strings.HasSuffix(r.URL.Path, "/containerApps/myapp"): + // Use the same test server's base URL in the endpoint so the logstream + // request ends up here too. 
+ _, _ = w.Write([]byte(`{"properties":{"eventStreamEndpoint":"` + eventStreamEndpoint + `/subscriptions/sub/foo"}}`)) + case strings.Contains(r.URL.Path, "/logstream"): + if !strings.Contains(r.URL.Path, "containerApps/myapp/revisions/rev/replicas/rep/containers/ctr/logstream") { + t.Errorf("stream path = %q", r.URL.Path) + } + _, _ = w.Write([]byte("a\nb\n")) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer mgmt.Close() + eventStreamEndpoint = mgmt.URL + + useFakeCred(t, "arm", nil) + origMgmt := cloudazure.ManagementEndpoint + cloudazure.ManagementEndpoint = mgmt.URL + t.Cleanup(func() { cloudazure.ManagementEndpoint = origMgmt }) + + c := &ContainerApp{ + Azure: cloudazure.Azure{ + SubscriptionID: "sub", + Location: cloudazure.LocationWestUS2, + }, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + ch, err := c.StreamLogs(ctx, "myapp", "rev", "rep", "ctr", true) + if err != nil { + t.Fatalf("StreamLogs: %v", err) + } + var got []string + for entry := range ch { + if entry.Err != nil { + continue + } + got = append(got, entry.Message) + } + if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Errorf("got lines %v", got) + } +} + +func TestStreamLogsAuthTokenError(t *testing.T) { + mgmt := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/containerApps/myapp"): + _, _ = w.Write([]byte(`{"properties":{"eventStreamEndpoint":"https://unused/subscriptions/s/x"}}`)) + case strings.Contains(r.URL.Path, "/getAuthToken"): + w.WriteHeader(http.StatusForbidden) + } + })) + defer mgmt.Close() + + useFakeCred(t, "arm", nil) + origMgmt := cloudazure.ManagementEndpoint + cloudazure.ManagementEndpoint = mgmt.URL + t.Cleanup(func() { cloudazure.ManagementEndpoint = origMgmt }) + + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: 
cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := c.StreamLogs(ctx, "myapp", "rev", "rep", "ctr", false); err == nil { + t.Error("StreamLogs should fail when getAuthToken returns 403") + } +} + +func TestStreamLogsHTTPError(t *testing.T) { + var base string + mgmt := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/containerApps/myapp"): + _, _ = w.Write([]byte(`{"properties":{"eventStreamEndpoint":"` + base + `/subscriptions/s/x"}}`)) + case strings.Contains(r.URL.Path, "/getAuthToken"): + _, _ = w.Write([]byte(`{"properties":{"token":"t"}}`)) + case strings.Contains(r.URL.Path, "/logstream"): + w.WriteHeader(http.StatusInternalServerError) + } + })) + defer mgmt.Close() + base = mgmt.URL + + useFakeCred(t, "arm", nil) + origMgmt := cloudazure.ManagementEndpoint + cloudazure.ManagementEndpoint = mgmt.URL + t.Cleanup(func() { cloudazure.ManagementEndpoint = origMgmt }) + + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := c.StreamLogs(ctx, "myapp", "rev", "rep", "ctr", false); err == nil { + t.Error("StreamLogs should fail when log stream returns 500") + } +} + +func TestWatchLogsNewClientOK(t *testing.T) { + // With fake cred succeeding, newContainerAppsClient works; WatchLogs + // begins polling and the poll will error on ARM call, but the retry loop + // continues until ctx is cancelled. 
+ useFakeCred(t, "tok", nil) + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + ch := c.WatchLogs(ctx) + for range ch { + } +} + +func TestResolveLogTargetCredError(t *testing.T) { + // Swap in a credential function that actually errors out at construction. + orig := cloudazure.NewCredsFunc + cloudazure.NewCredsFunc = func(_ cloudazure.Azure) (azcore.TokenCredential, error) { + return nil, errors.New("no cred") + } + t.Cleanup(func() { cloudazure.NewCredsFunc = orig }) + + c := &ContainerApp{ + Azure: cloudazure.Azure{SubscriptionID: "sub", Location: cloudazure.LocationWestUS2}, + ResourceGroup: "rg", + } + if _, _, _, err := c.ResolveLogTarget(context.Background(), "app", "", "", ""); err == nil { + t.Error("expected credential error") + } +} diff --git a/src/pkg/clouds/azure/acr/buildlogs.go b/src/pkg/clouds/azure/acr/buildlogs.go new file mode 100644 index 000000000..c2a739474 --- /dev/null +++ b/src/pkg/clouds/azure/acr/buildlogs.go @@ -0,0 +1,264 @@ +package acr + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" + "github.com/DefangLabs/defang/src/pkg/term" +) + +const buildPollInterval = 5 * time.Second + +// BuildLogEntry is a log line from an ACR task run, annotated with the service being built. +type BuildLogEntry struct { + Service string + Line string + Err error +} + +// BuildLogWatcher polls for ACR task runs in a resource group and streams their logs. +type BuildLogWatcher struct { + azure.Azure + ResourceGroup string +} + +// WatchBuildLogs discovers ACR registries in the resource group, polls for active task +// runs, and streams their build log output. 
The registry itself is created lazily by +// Pulumi during the CD run, so registry discovery is retried on every poll until one +// appears (or ctx is cancelled). The returned channel is closed when ctx is cancelled. +func (w *BuildLogWatcher) WatchBuildLogs(ctx context.Context) <-chan BuildLogEntry { + out := make(chan BuildLogEntry) + go func() { + // senders tracks the per-run streaming goroutines started by poll(). + // We must wait for all of them to finish before closing `out`, otherwise + // a `case out <- …` racing with our close panics with + // "send on closed channel". + var senders sync.WaitGroup + defer func() { + senders.Wait() + close(out) + }() + watchStart := time.Now().Add(-2 * time.Minute) // catch builds that started up to 2 min before tailing + + cred, err := w.NewCreds() + if err != nil { + term.Debugf("WatchBuildLogs: failed to get credentials: %v", err) + return + } + + regClient, err := armcontainerregistry.NewRegistriesClient(w.SubscriptionID, cred, nil) + if err != nil { + term.Debugf("WatchBuildLogs: failed to create registries client: %v", err) + return + } + runsClient, err := armcontainerregistry.NewRunsClient(w.SubscriptionID, cred, nil) + if err != nil { + term.Debugf("WatchBuildLogs: failed to create runs client: %v", err) + return + } + + // Registry is discovered lazily — Pulumi creates it partway through the CD run, + // so it is not guaranteed to exist when WatchBuildLogs starts. + var registryName string + // Track runs we're already streaming so we don't duplicate. + streaming := map[string]struct{}{} + // defaultService is learned from any run's OutputImages (populated after + // completion) so we can label in-progress runs with the right service name. 
+ defaultService := "" + + findRegistry := func() string { + pager := regClient.NewListByResourceGroupPager(w.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + term.Debugf("WatchBuildLogs: failed to list registries: %v", err) + return "" + } + for _, reg := range page.Value { + if reg.Name != nil { + return *reg.Name + } + } + } + return "" + } + + poll := func() { + if registryName == "" { + registryName = findRegistry() + if registryName == "" { + return // no registry yet; retry next tick + } + term.Debugf("WatchBuildLogs: found registry %q in %q", registryName, w.ResourceGroup) + } + + // List the most recent runs (no status filter) so we catch builds that + // started and finished between polls. + top := int32(10) + pager := runsClient.NewListPager(w.ResourceGroup, registryName, &armcontainerregistry.RunsClientListOptions{ + Top: &top, + }) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + term.Debugf("WatchBuildLogs: failed to list runs: %v", err) + return + } + for _, run := range page.Value { + if run.Properties == nil || run.Properties.RunID == nil { + continue + } + // Learn service name from any completed run's OutputImages. + if imgs := run.Properties.OutputImages; len(imgs) > 0 && imgs[0].Repository != nil { + defaultService = *imgs[0].Repository + } + runID := *run.Properties.RunID + if _, ok := streaming[runID]; ok { + continue + } + // Only stream runs that started after the watcher was created. 
+ if run.Properties.CreateTime != nil && run.Properties.CreateTime.Before(watchStart) { + continue + } + streaming[runID] = struct{}{} + service := defaultService + if service == "" { + service = runID + } + term.Debugf("WatchBuildLogs: streaming run %s (service %s)", runID, service) + senders.Add(1) + go func() { + defer senders.Done() + w.streamRunLog(ctx, runsClient, w.ResourceGroup, registryName, runID, service, out) + }() + } + } + } + + poll() + ticker := time.NewTicker(buildPollInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + poll() + } + } + }() + return out +} + +// streamRunLog polls GetLogSasURL for a run and streams new log content as it grows. +func (w *BuildLogWatcher) streamRunLog( + ctx context.Context, + runsClient *armcontainerregistry.RunsClient, + rgName, registryName, runID, service string, + out chan<- BuildLogEntry, +) { + var lastLen int + + for { + // Get the log SAS URL (regenerated each call but points to the same growing blob). + logResp, err := runsClient.GetLogSasURL(ctx, rgName, registryName, runID, nil) + if err != nil { + term.Debugf("streamRunLog %s: GetLogSasURL error: %v", runID, err) + select { + case <-ctx.Done(): + return + case <-time.After(buildPollInterval): + } + continue + } + if logResp.LogLink == nil { + select { + case <-ctx.Done(): + return + case <-time.After(buildPollInterval): + } + continue + } + + // Fetch the full log content and emit only new lines. + content, err := fetchLogContent(ctx, *logResp.LogLink) + if err != nil { + term.Debugf("streamRunLog %s: fetch error: %v", runID, err) + } else if len(content) > lastLen { + newContent := content[lastLen:] + lastLen = len(content) + for _, line := range strings.Split(strings.TrimRight(newContent, "\n"), "\n") { + if line == "" { + continue + } + select { + case out <- BuildLogEntry{Service: service, Line: line}: + case <-ctx.Done(): + return + } + } + } + + // Check if run is still active. 
+ runResp, err := runsClient.Get(ctx, rgName, registryName, runID, nil) + if err == nil && runResp.Properties != nil && runResp.Properties.Status != nil { + switch *runResp.Properties.Status { + case armcontainerregistry.RunStatusSucceeded, + armcontainerregistry.RunStatusFailed, + armcontainerregistry.RunStatusError, + armcontainerregistry.RunStatusCanceled, + armcontainerregistry.RunStatusTimeout: + // Do one final fetch to capture any remaining log lines. + if logResp.LogLink != nil { + finalContent, err := fetchLogContent(ctx, *logResp.LogLink) + if err == nil && len(finalContent) > lastLen { + for _, line := range strings.Split(strings.TrimRight(finalContent[lastLen:], "\n"), "\n") { + if line == "" { + continue + } + select { + case out <- BuildLogEntry{Service: service, Line: line}: + case <-ctx.Done(): + return + } + } + } + } + return + } + } + + select { + case <-ctx.Done(): + return + case <-time.After(buildPollInterval): + } + } +} + +func fetchLogContent(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %s", resp.Status) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB max + if err != nil { + return "", err + } + return string(body), nil +} diff --git a/src/pkg/clouds/azure/acr/buildlogs_test.go b/src/pkg/clouds/azure/acr/buildlogs_test.go new file mode 100644 index 000000000..0dfe8e6f2 --- /dev/null +++ b/src/pkg/clouds/azure/acr/buildlogs_test.go @@ -0,0 +1,206 @@ +package acr + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + 
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +func armRunsClientFromCred(cred azcore.TokenCredential) (*armcontainerregistry.RunsClient, error) { + return armcontainerregistry.NewRunsClient("sub", cred, nil) +} + +type fakeCred struct { + tok string + err error +} + +func (f fakeCred) GetToken(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) { + if f.err != nil { + return azcore.AccessToken{}, f.err + } + return azcore.AccessToken{Token: f.tok, ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func useFakeCred(t *testing.T, tok string, gerr error) { + t.Helper() + orig := azure.NewCredsFunc + azure.NewCredsFunc = func(_ azure.Azure) (azcore.TokenCredential, error) { + return fakeCred{tok: tok, err: gerr}, nil + } + t.Cleanup(func() { azure.NewCredsFunc = orig }) +} + +func TestFetchLogContent(t *testing.T) { + body := "line 1\nline 2\nline 3\n" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(body)) + })) + defer srv.Close() + + got, err := fetchLogContent(context.Background(), srv.URL) + if err != nil { + t.Fatalf("fetchLogContent error: %v", err) + } + if got != body { + t.Errorf("content = %q, want %q", got, body) + } +} + +func TestFetchLogContentNotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + _, err := fetchLogContent(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error for 404 response") + } + if !strings.Contains(err.Error(), "404") { + t.Errorf("error should mention 404: %v", err) + } +} + +func TestFetchLogContentTruncation(t *testing.T) { + // The helper caps reads at 1MB — verify it doesn't 
blow up on larger content. + huge := strings.Repeat("x", 2*1024*1024) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(huge)) + })) + defer srv.Close() + + got, err := fetchLogContent(context.Background(), srv.URL) + if err != nil { + t.Fatalf("fetchLogContent error: %v", err) + } + // Should be capped at the 1MB limit. + if len(got) > 1024*1024 { + t.Errorf("content length = %d, expected <= 1MB", len(got)) + } +} + +func TestFetchLogContentCancelledContext(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := fetchLogContent(ctx, srv.URL); err == nil { + t.Error("expected error from cancelled context") + } +} + +func TestBuildLogEntry(t *testing.T) { + e := BuildLogEntry{Service: "app", Line: "hello"} + if e.Service != "app" || e.Line != "hello" { + t.Errorf("BuildLogEntry fields = %+v", e) + } +} + +func TestWatchBuildLogsCredError(t *testing.T) { + // With a bad credential, the goroutine should exit without emitting entries + // and close the channel promptly. + useFakeCred(t, "", errors.New("denied")) + + w := &BuildLogWatcher{ + Azure: azure.Azure{ + SubscriptionID: "sub", + Location: azure.LocationWestUS2, + }, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ch := w.WatchBuildLogs(ctx) + // With good credential but a non-existent RG, the watcher retries indefinitely + // until ctx cancels. With bad credential (like here), clients may still be + // constructed but will error on each call; the goroutine should still exit + // when ctx expires without emitting anything. + select { + case entry, ok := <-ch: + if ok && entry.Err != nil { + // Acceptable: surfaced an error. 
+ } + // Keep draining until close. + for range ch { + } + case <-ctx.Done(): + // Give the goroutine a chance to see cancellation and close the channel. + } + // Drain to make sure the goroutine exits. + for range ch { + } +} + +func TestStreamRunLogCancelled(t *testing.T) { + // streamRunLog retries GetLogSasURL forever while it fails; with ctx + // cancelled it should bail promptly. + useFakeCred(t, "", errors.New("denied")) + + cred, err := azure.Azure{SubscriptionID: "sub"}.NewCreds() + if err != nil { + t.Fatalf("NewCreds: %v", err) + } + runsClient, err := armRunsClientFromCred(cred) + if err != nil { + t.Fatalf("runs client: %v", err) + } + + w := &BuildLogWatcher{ + Azure: azure.Azure{ + SubscriptionID: "sub", + Location: azure.LocationWestUS2, + }, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + out := make(chan BuildLogEntry) + done := make(chan struct{}) + go func() { + w.streamRunLog(ctx, runsClient, "rg", "registry", "run-id", "svc", out) + close(done) + }() + go func() { + for range out { + } + }() + <-done +} + +func TestWatchBuildLogsCancelled(t *testing.T) { + // Verify the watcher exits promptly when ctx is cancelled, even when no + // registry is ever found. + useFakeCred(t, "tok", nil) + + w := &BuildLogWatcher{ + Azure: azure.Azure{ + SubscriptionID: "sub", + Location: azure.LocationWestUS2, + }, + ResourceGroup: "rg", + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + ch := w.WatchBuildLogs(ctx) + for range ch { + } + // If we got here, the channel was closed — good. 
+} diff --git a/src/pkg/clouds/azure/cd/blob.go b/src/pkg/clouds/azure/cd/blob.go new file mode 100644 index 000000000..922780c64 --- /dev/null +++ b/src/pkg/clouds/azure/cd/blob.go @@ -0,0 +1,98 @@ +package cd + +import ( + "context" + "errors" + "fmt" + "io" + "iter" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" +) + +const maxBlobDownloadSize = 32 * 1024 * 1024 // 32 MiB + +// BlobItem represents a blob in the storage account container. +type BlobItem struct { + name string + size int64 +} + +func (b BlobItem) Name() string { return b.name } +func (b BlobItem) Size() int64 { return b.size } + +func (d *Driver) newSharedKeyCredential(ctx context.Context) (*azblob.SharedKeyCredential, error) { + storageKey := os.Getenv("AZURE_STORAGE_KEY") + if storageKey == "" { + accountsClient, err := d.NewStorageAccountsClient() + if err != nil { + return nil, err + } + keys, err := accountsClient.ListKeys(ctx, d.resourceGroupName, d.StorageAccount, nil) + if err != nil { + return nil, err + } + if len(keys.Keys) == 0 || keys.Keys[0].Value == nil { + return nil, errors.New("no storage account keys returned") + } + storageKey = *keys.Keys[0].Value + } + return azblob.NewSharedKeyCredential(d.StorageAccount, storageKey) +} + +func (d *Driver) newBlobContainerClient(ctx context.Context, containerName string) (*container.Client, error) { + keyCred, err := d.newSharedKeyCredential(ctx) + if err != nil { + return nil, err + } + containerURL := fmt.Sprintf("https://%s.blob.core.windows.net/%s", d.StorageAccount, containerName) + return container.NewClientWithSharedKeyCredential(containerURL, keyCred, nil) +} + +// IterateBlobsInContainer is the container-explicit variant of IterateBlobs. 
+func (d *Driver) IterateBlobsInContainer(ctx context.Context, containerName, prefix string) (iter.Seq2[BlobItem, error], error) { + client, err := d.newBlobContainerClient(ctx, containerName) + if err != nil { + return nil, err + } + pager := client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Prefix: &prefix, + }) + return func(yield func(BlobItem, error) bool) { + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + yield(BlobItem{}, err) + return + } + for _, item := range page.Segment.BlobItems { + if item.Name == nil { + continue + } + var size int64 + if item.Properties != nil && item.Properties.ContentLength != nil { + size = *item.Properties.ContentLength + } + if !yield(BlobItem{name: *item.Name, size: size}, nil) { + return + } + } + } + }, nil +} + +// DownloadBlobFromContainer is the container-explicit variant of DownloadBlob. +func (d *Driver) DownloadBlobFromContainer(ctx context.Context, containerName, blobName string) ([]byte, error) { + client, err := d.newBlobContainerClient(ctx, containerName) + if err != nil { + return nil, err + } + resp, err := client.NewBlobClient(blobName).DownloadStream(ctx, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return io.ReadAll(io.LimitReader(resp.Body, maxBlobDownloadSize)) +} diff --git a/src/pkg/clouds/azure/cd/driver.go b/src/pkg/clouds/azure/cd/driver.go new file mode 100644 index 000000000..c9fa5c720 --- /dev/null +++ b/src/pkg/clouds/azure/cd/driver.go @@ -0,0 +1,49 @@ +package cd + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +type Driver struct { + azure.Azure + resourceGroupPrefix string + resourceGroupName string + StorageAccount string + BlobContainerName string +} + +func New(resourceGroupPrefix string, location azure.Location) *Driver { + d := &Driver{ + Azure: azure.Azure{ + Location: location, + }, + resourceGroupPrefix: 
resourceGroupPrefix, + } + d.resourceGroupName = resourceGroupPrefix + "-" + location.String() + return d +} + +func (d *Driver) ResourceGroupName() string { + return d.resourceGroupName +} + +// SetLocation updates the location and recomputes the resource group name. +func (d *Driver) SetLocation(loc azure.Location) { + d.Location = loc + d.resourceGroupName = d.resourceGroupPrefix + "-" + loc.String() +} + +func (d *Driver) newResourceGroupClient() (*armresources.ResourceGroupsClient, error) { + cred, err := d.NewCreds() + if err != nil { + return nil, err + } + client, err := armresources.NewResourceGroupsClient(d.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create resource group client: %w", err) + } + return client, nil +} diff --git a/src/pkg/clouds/azure/cd/driver_test.go b/src/pkg/clouds/azure/cd/driver_test.go new file mode 100644 index 000000000..d11bcd36b --- /dev/null +++ b/src/pkg/clouds/azure/cd/driver_test.go @@ -0,0 +1,319 @@ +package cd + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" +) + +type fakeCred struct { + tok string + err error +} + +func (f fakeCred) GetToken(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) { + if f.err != nil { + return azcore.AccessToken{}, f.err + } + return azcore.AccessToken{Token: f.tok, ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func useFakeCred(t *testing.T, tok string, gerr error) { + t.Helper() + orig := azure.NewCredsFunc + azure.NewCredsFunc = func(_ azure.Azure) (azcore.TokenCredential, error) { + return fakeCred{tok: tok, err: gerr}, nil + } + t.Cleanup(func() { azure.NewCredsFunc = orig }) +} + +func TestNew(t *testing.T) { + d := New("defang-cd", azure.LocationWestUS2) + if d == nil { + t.Fatal("New returned nil") + } + if d.Location != azure.LocationWestUS2 { + 
t.Errorf("Location = %q, want westus2", d.Location) + } + if d.resourceGroupPrefix != "defang-cd" { + t.Errorf("resourceGroupPrefix = %q", d.resourceGroupPrefix) + } + if got := d.ResourceGroupName(); got != "defang-cd-westus2" { + t.Errorf("ResourceGroupName() = %q, want defang-cd-westus2", got) + } +} + +func TestNewEmptyLocation(t *testing.T) { + d := New("defang-cd", "") + if d == nil { + t.Fatal("New returned nil") + } + if got := d.ResourceGroupName(); got != "defang-cd-" { + t.Errorf("ResourceGroupName() = %q, want defang-cd-", got) + } +} + +func TestSetLocation(t *testing.T) { + d := New("defang-cd", azure.LocationEastUS) + if got := d.ResourceGroupName(); got != "defang-cd-eastus" { + t.Errorf("initial ResourceGroupName = %q", got) + } + d.SetLocation(azure.LocationWestUS3) + if d.Location != azure.LocationWestUS3 { + t.Errorf("Location not updated: %q", d.Location) + } + if got := d.ResourceGroupName(); got != "defang-cd-westus3" { + t.Errorf("ResourceGroupName after SetLocation = %q, want defang-cd-westus3", got) + } +} + +func TestBlobItem(t *testing.T) { + b := BlobItem{name: "a/b/c.tar.gz", size: 42} + if b.Name() != "a/b/c.tar.gz" { + t.Errorf("Name() = %q", b.Name()) + } + if b.Size() != 42 { + t.Errorf("Size() = %d", b.Size()) + } +} + +func TestNewResourceGroupClientMissingSubscription(t *testing.T) { + // Default NewCredsFunc requires AZURE_SUBSCRIPTION_ID. 
+ orig := azure.NewCredsFunc + azure.NewCredsFunc = func(a azure.Azure) (azcore.TokenCredential, error) { + if a.SubscriptionID == "" { + return nil, errors.New("AZURE_SUBSCRIPTION_ID is not set") + } + return fakeCred{tok: "x"}, nil + } + t.Cleanup(func() { azure.NewCredsFunc = orig }) + + d := New("defang-cd", azure.LocationWestUS2) + if _, err := d.newResourceGroupClient(); err == nil { + t.Error("newResourceGroupClient should fail without subscription ID") + } +} + +func TestNewResourceGroupClientOK(t *testing.T) { + useFakeCred(t, "x", nil) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + if _, err := d.newResourceGroupClient(); err != nil { + t.Errorf("newResourceGroupClient: %v", err) + } +} + +func TestCreateResourceGroupCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + if err := d.CreateResourceGroup(t.Context(), "rg"); err == nil { + t.Error("CreateResourceGroup should surface credential error") + } +} + +func TestTearDownCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + if err := d.TearDown(t.Context()); err == nil { + t.Error("TearDown should surface credential error") + } +} + +func TestGetStorageAccountFromField(t *testing.T) { + d := New("defang-cd", azure.LocationWestUS2) + d.StorageAccount = "myacct" + got, err := d.getStorageAccount(t.Context(), nil) + if err != nil { + t.Fatalf("getStorageAccount: %v", err) + } + if got != "myacct" { + t.Errorf("got = %q, want myacct", got) + } +} + +func TestGetStorageAccountFromEnv(t *testing.T) { + t.Setenv("AZURE_STORAGE_ACCOUNT", "envacct") + d := New("defang-cd", azure.LocationWestUS2) + got, err := d.getStorageAccount(t.Context(), nil) + if err != nil { + t.Fatalf("getStorageAccount: %v", err) + } + if got != "envacct" { + t.Errorf("got = %q, want envacct", got) + } +} + +func 
TestSetUpStorageAccountIdempotent(t *testing.T) { + d := New("defang-cd", azure.LocationWestUS2) + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + got, err := d.SetUpStorageAccount(t.Context()) + if err != nil { + t.Fatalf("SetUpStorageAccount: %v", err) + } + if got != "acct" { + t.Errorf("SetUpStorageAccount returned %q, want acct", got) + } +} + +func TestSetUpStorageAccountCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + if _, err := d.SetUpStorageAccount(t.Context()); err == nil { + t.Error("SetUpStorageAccount should fail on bad cred") + } +} + +func TestSetUpResourceGroupCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + if err := d.SetUpResourceGroup(t.Context()); err == nil { + t.Error("SetUpResourceGroup should fail on bad cred") + } +} + +func TestCreateUploadURLTooLong(t *testing.T) { + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + long := strings.Repeat("x", 65) + if _, err := d.CreateUploadURL(t.Context(), long); err == nil { + t.Error("CreateUploadURL should reject names > 64 chars") + } +} + +func TestCreateUploadURLStorageAccountSetupFails(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + // No pre-populated StorageAccount — SetUpStorageAccount will be called and fail. + if _, err := d.CreateUploadURL(t.Context(), ""); err == nil { + t.Error("CreateUploadURL should fail when setup fails") + } +} + +func TestCreateUploadURLWithStorageKeyEnv(t *testing.T) { + // With AZURE_STORAGE_KEY set and StorageAccount pre-populated, we skip + // all ARM calls and produce a SAS URL. 
+ t.Setenv("AZURE_STORAGE_KEY", "dGVzdC1rZXk=") // base64-encoded fake key + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + + got, err := d.CreateUploadURL(t.Context(), "myblob") + if err != nil { + t.Fatalf("CreateUploadURL: %v", err) + } + if got == "" { + t.Error("CreateUploadURL returned empty URL") + } + if !strings.Contains(got, "acct.blob.core.windows.net") || !strings.Contains(got, "myblob") { + t.Errorf("URL %q does not look like a SAS URL", got) + } +} + +func TestCreateUploadURLSanitizesSlash(t *testing.T) { + t.Setenv("AZURE_STORAGE_KEY", "dGVzdC1rZXk=") + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + + got, err := d.CreateUploadURL(t.Context(), "sha256:abc/def") + if err != nil { + t.Fatalf("CreateUploadURL: %v", err) + } + // Slash in the digest should be replaced with underscore so it's a safe + // blob name. + if !strings.Contains(got, "sha256%3Aabc_def") && !strings.Contains(got, "sha256:abc_def") { + t.Errorf("URL %q did not sanitize slash", got) + } +} + +func TestNewSharedKeyCredentialFromEnv(t *testing.T) { + t.Setenv("AZURE_STORAGE_KEY", "dGVzdC1rZXk=") + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "myacct" + cred, err := d.newSharedKeyCredential(t.Context()) + if err != nil { + t.Fatalf("newSharedKeyCredential: %v", err) + } + if cred == nil { + t.Error("cred should not be nil") + } +} + +func TestNewSharedKeyCredentialCredError(t *testing.T) { + // No AZURE_STORAGE_KEY and the ARM path fails. 
+ useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + if _, err := d.newSharedKeyCredential(t.Context()); err == nil { + t.Error("newSharedKeyCredential should fail without key") + } +} + +func TestNewBlobContainerClientFromEnv(t *testing.T) { + t.Setenv("AZURE_STORAGE_KEY", "dGVzdC1rZXk=") + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "myacct" + d.BlobContainerName = "uploads" + client, err := d.newBlobContainerClient(t.Context(), "uploads") + if err != nil { + t.Fatalf("newBlobContainerClient: %v", err) + } + if client == nil { + t.Error("client should not be nil") + } +} + +func TestIterateBlobsCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + if _, err := d.IterateBlobsInContainer(t.Context(), "uploads", ".pulumi/stacks/"); err == nil { + t.Error("IterateBlobs should fail without key") + } +} + +func TestDownloadBlobCredError(t *testing.T) { + useFakeCred(t, "", errors.New("denied")) + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + if _, err := d.DownloadBlobFromContainer(t.Context(), "uploads", "blob"); err == nil { + t.Error("DownloadBlob should fail without key") + } +} + +func TestCreateUploadURLGenerateBlobName(t *testing.T) { + // With empty blobName the driver generates a UUID. 
+ t.Setenv("AZURE_STORAGE_KEY", "dGVzdC1rZXk=") + d := New("defang-cd", azure.LocationWestUS2) + d.SubscriptionID = "sub" + d.StorageAccount = "acct" + d.BlobContainerName = "uploads" + url, err := d.CreateUploadURL(t.Context(), "") + if err != nil { + t.Fatalf("CreateUploadURL: %v", err) + } + if url == "" { + t.Error("URL should be non-empty") + } +} diff --git a/src/pkg/clouds/azure/cd/setup.go b/src/pkg/clouds/azure/cd/setup.go new file mode 100644 index 000000000..d3cdaa49b --- /dev/null +++ b/src/pkg/clouds/azure/cd/setup.go @@ -0,0 +1,177 @@ +package cd + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v2" + "github.com/DefangLabs/defang/src/pkg" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" + "github.com/DefangLabs/defang/src/pkg/term" +) + +const storageAccountPrefix = "defangcd" + +// Container names used in the CD storage account. Keep them DNS-safe: +// 3–63 chars, lowercase alphanumeric + hyphens (no leading/trailing hyphen). +const ( + // UploadsContainerName holds per-deploy payloads (etag blobs) and source tarballs. + UploadsContainerName = "uploads" + // PulumiContainerName is the dedicated Pulumi state backend container. + PulumiContainerName = "pulumi" + // ProjectsContainerName holds {project}/{stack}/project.pb audit blobs + // written by the CD task before each deploy. + ProjectsContainerName = "projects" + + // blobContainerName is kept for backward compatibility with existing + // callers that default to the uploads container. + blobContainerName = UploadsContainerName +) + +// CreateResourceGroup creates or updates an Azure resource group with the given name. 
+func (d *Driver) CreateResourceGroup(ctx context.Context, name string) error { + rgClient, err := d.newResourceGroupClient() + if err != nil { + return err + } + _, err = rgClient.CreateOrUpdate(ctx, name, armresources.ResourceGroup{ + Location: d.Location.Ptr(), + }, nil) + if err != nil { + return fmt.Errorf("failed to create resource group %q: %w", name, err) + } + return nil +} + +// SetUpResourceGroup creates or updates the shared CD resource group (defang-cd-{location}). +func (d *Driver) SetUpResourceGroup(ctx context.Context) error { + return d.CreateResourceGroup(ctx, d.resourceGroupName) +} + +func (d *Driver) TearDown(ctx context.Context) error { + rgClient, err := d.newResourceGroupClient() + if err != nil { + return err + } + deletePoller, err := rgClient.BeginDelete(ctx, d.resourceGroupName, nil) + if err != nil { + return fmt.Errorf("failed to delete resource group: %w", err) + } + _, err = deletePoller.PollUntilDone(ctx, azure.PollOptions) + return err +} + +func (d *Driver) getStorageAccount(ctx context.Context, accountsClient *armstorage.AccountsClient) (string, error) { + if d.StorageAccount != "" { + return d.StorageAccount, nil + } + + if sa := os.Getenv("AZURE_STORAGE_ACCOUNT"); sa != "" { + return sa, nil + } + + for pager := accountsClient.NewListByResourceGroupPager(d.resourceGroupName, nil); pager.More(); { + page, err := pager.NextPage(ctx) + if err != nil { + return "", fmt.Errorf("failed to list storage accounts: %w", err) + } + for _, account := range page.Value { + if strings.HasPrefix(*account.Name, storageAccountPrefix) && *account.Location == d.Location.String() { + return *account.Name, nil + } + } + } + return "", nil +} + +// FindStorageAccount is a read-only variant of SetUpStorageAccount: it locates +// the defang CD storage account (and remembers its container) without +// creating anything. 
Returns ("", nil) when the storage account or blob +// container doesn't exist yet — typical for a subscription where defang has +// never been deployed. On success, d.StorageAccount and d.BlobContainerName +// are populated for subsequent DownloadBlob / IterateBlobs calls. +func (d *Driver) FindStorageAccount(ctx context.Context) (string, error) { + if d.StorageAccount != "" && d.BlobContainerName != "" { + return d.StorageAccount, nil + } + accountsClient, err := d.NewStorageAccountsClient() + if err != nil { + return "", err + } + storageAccount, err := d.getStorageAccount(ctx, accountsClient) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == 404 { + return "", nil // resource group doesn't exist yet + } + return "", err + } + if storageAccount == "" { + return "", nil + } + d.StorageAccount = storageAccount + // The blob container is always created with the well-known name; its + // existence is implied by the storage account being present on a + // defang-managed subscription. We don't verify it here — DownloadBlob / + // IterateBlobs will return 404 if it doesn't exist yet. + d.BlobContainerName = blobContainerName + return storageAccount, nil +} + +func (d *Driver) SetUpStorageAccount(ctx context.Context) (string, error) { + // Idempotency: skip if already set up. 
+ if d.StorageAccount != "" && d.BlobContainerName != "" { + return d.StorageAccount, nil + } + + accountsClient, err := d.NewStorageAccountsClient() + if err != nil { + return "", err + } + + storageAccount, err := d.getStorageAccount(ctx, accountsClient) + if err != nil { + return "", err + } + + if storageAccount == "" { + storageAccount = storageAccountPrefix + pkg.RandomID() + createPoller, err := accountsClient.BeginCreate(ctx, d.resourceGroupName, storageAccount, armstorage.AccountCreateParameters{ + Kind: to.Ptr(armstorage.KindStorageV2), + Location: d.Location.Ptr(), + SKU: &armstorage.SKU{Name: to.Ptr(armstorage.SKUNameStandardLRS)}, + }, nil) + if err != nil { + return "", fmt.Errorf("failed to create storage account: %w", err) + } + _, err = createPoller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return "", fmt.Errorf("failed to poll storage account creation: %w", err) + } + } + d.StorageAccount = storageAccount + + containerClient, err := d.NewBlobContainersClient() + if err != nil { + return "", fmt.Errorf("failed to create blob containers client: %w", err) + } + for _, name := range []string{UploadsContainerName, PulumiContainerName, ProjectsContainerName} { + if _, err := containerClient.Create(ctx, d.resourceGroupName, storageAccount, name, armstorage.BlobContainer{}, nil); err != nil { + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) || respErr.ErrorCode != "ContainerAlreadyExists" { + return "", fmt.Errorf("failed to create blob container %q: %w", name, err) + } + } + } + d.BlobContainerName = UploadsContainerName + + term.Infof("Using storage account %s (containers: %s, %s, %s)", storageAccount, UploadsContainerName, PulumiContainerName, ProjectsContainerName) + + return storageAccount, nil +} diff --git a/src/pkg/clouds/azure/cd/upload.go b/src/pkg/clouds/azure/cd/upload.go new file mode 100644 index 000000000..1e4b6d8b8 --- /dev/null +++ b/src/pkg/clouds/azure/cd/upload.go @@ -0,0 +1,69 @@ +package cd + +import 
( + "context" + "errors" + "fmt" + "net/url" + "os" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas" + "github.com/google/uuid" +) + +func (d *Driver) CreateUploadURL(ctx context.Context, blobName string) (string, error) { + if blobName == "" { + blobName = uuid.NewString() + } else { + if len(blobName) > 64 { + return "", errors.New("name must be less than 64 characters") + } + // Sanitize the digest so it's safe to use as a file name + blobName = strings.ReplaceAll(blobName, "/", "_") + } + if _, err := d.SetUpStorageAccount(ctx); err != nil { + return "", err + } + + expiry := time.Now().UTC().Add(1 * time.Hour) + + storageKey := os.Getenv("AZURE_STORAGE_KEY") + if storageKey == "" { + accountsClient, err := d.NewStorageAccountsClient() + if err != nil { + return "", err + } + keys, err := accountsClient.ListKeys(ctx, d.resourceGroupName, d.StorageAccount, nil) + if err != nil { + return "", err + } + if len(keys.Keys) == 0 || keys.Keys[0].Value == nil { + return "", errors.New("no storage account keys returned") + } + storageKey = *keys.Keys[0].Value + } + + keyCred, err := azblob.NewSharedKeyCredential(d.StorageAccount, storageKey) + if err != nil { + return "", err + } + + perms := sas.BlobPermissions{Create: true, Write: true, Read: true} + sasQueryParams, err := sas.BlobSignatureValues{ + BlobName: blobName, + ContainerName: d.BlobContainerName, + ExpiryTime: expiry, + Permissions: perms.String(), + Protocol: sas.ProtocolHTTPS, + }.SignWithSharedKey(keyCred) + if err != nil { + return "", err + } + + serviceURL := fmt.Sprintf("https://%s.blob.core.windows.net/", d.StorageAccount) + sasURL := fmt.Sprintf("%s%s/%s?%s", serviceURL, d.BlobContainerName, url.PathEscape(blobName), sasQueryParams.Encode()) + return sasURL, nil +} diff --git a/src/pkg/clouds/azure/common.go b/src/pkg/clouds/azure/common.go new file mode 100644 index 000000000..937e11b2d --- /dev/null +++ 
b/src/pkg/clouds/azure/common.go @@ -0,0 +1,183 @@ +package azure + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v2" + "github.com/DefangLabs/defang/src/pkg/tokenstore" +) + +var PollOptions = &runtime.PollUntilDoneOptions{Frequency: 5 * time.Second} + +// cliTimeout overrides the default 10s timeout for CLI-based credentials. +// The Azure CLI can be slow to start, especially when installed via Nix. +// cliTimeout overrides the SDK's default 10s budget for AzureCLICredential. +// `az account get-access-token` can exceed 30s when several pollers refresh +// concurrently after token expiry, when `az` cold-starts (Python interpreter + +// module imports), or when AAD round-trips for refresh. 90s gives enough +// headroom for these compounded delays without making startup feel hung. +const cliTimeout = 90 * time.Second + +type Azure struct { + Location Location + SubscriptionID string + // Cred is populated by Authenticate and, when non-nil, is returned by + // NewCreds instead of building a fresh DefaultAzureCredential. + Cred azcore.TokenCredential + // TokenStore persists the AuthenticationRecord returned by the + // device-code flow so future invocations can silently reuse the user's + // session (the actual refresh token lives in the OS-level token cache, + // not the TokenStore). + TokenStore tokenstore.TokenStore +} + +// tokenCredentialWithTimeout wraps an azcore.TokenCredential to ensure +// GetToken has a minimum deadline, overriding the SDK's default 10s CLI timeout. 
+type tokenCredentialWithTimeout struct { + cred azcore.TokenCredential + timeout time.Duration +} + +func (t *tokenCredentialWithTimeout) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, t.timeout) + defer cancel() + } + return t.cred.GetToken(ctx, opts) +} + +// NewCredsFunc builds a TokenCredential for ARM calls. Tests can override this +// to inject a fake credential; the default implementation returns any cred +// populated by Authenticate, falling back to DefaultAzureCredential. +var NewCredsFunc = func(a Azure) (azcore.TokenCredential, error) { + if a.Cred != nil { + return a.Cred, nil + } + if len(a.SubscriptionID) == 0 { + return nil, errors.New("environment variable AZURE_SUBSCRIPTION_ID is not set") + } + + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, fmt.Errorf("failed to create default Azure credentials: %w", err) + } + + return &tokenCredentialWithTimeout{cred: cred, timeout: cliTimeout}, nil +} + +func (a Azure) NewCreds() (azcore.TokenCredential, error) { + return NewCredsFunc(a) +} + +// ManagementEndpoint is the base URL for Azure Resource Manager REST calls. +// It is exposed as a variable so tests can swap in an httptest.Server URL. +var ManagementEndpoint = "https://management.azure.com" + +// ArmToken returns a Bearer token scoped to the Azure management endpoint, +// suitable for direct REST API calls that the ARM SDK does not expose. 
+func (a Azure) ArmToken(ctx context.Context) (string, error) { + cred, err := a.NewCreds() + if err != nil { + return "", err + } + tok, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err != nil { + return "", err + } + return tok.Token, nil +} + +// FetchLogStreamAuthToken POSTs to the `getAuthToken` action on an ACA resource +// (container app or job) and returns the short-lived token that the resource's +// log-stream endpoint accepts. resourcePath is the segment after +// "providers/", e.g. "Microsoft.App/containerApps/{name}" or +// "Microsoft.App/jobs/{name}". +func (a Azure) FetchLogStreamAuthToken(ctx context.Context, resourceGroup, resourcePath, apiVersion string) (string, error) { + armTok, err := a.ArmToken(ctx) + if err != nil { + return "", err + } + + url := fmt.Sprintf( + "%s/subscriptions/%s/resourceGroups/%s/providers/%s/getAuthToken?api-version=%s", + ManagementEndpoint, a.SubscriptionID, resourceGroup, resourcePath, apiVersion, + ) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, http.NoBody) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+armTok) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("getAuthToken: HTTP %s", resp.Status) + } + + var result struct { + Properties struct { + Token string `json:"token"` + } `json:"properties"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("getAuthToken: decode response: %w", err) + } + return result.Properties.Token, nil +} + +func (a Azure) NewStorageAccountsClient() (*armstorage.AccountsClient, error) { + cred, err := a.NewCreds() + if err != nil { + return nil, err + } + + clientFactory, err := armstorage.NewClientFactory(a.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed 
to create storage client: %w", err) + } + + return clientFactory.NewAccountsClient(), nil +} + +func (a Azure) NewBlobContainersClient() (*armstorage.BlobContainersClient, error) { + cred, err := a.NewCreds() + if err != nil { + return nil, err + } + + clientFactory, err := armstorage.NewClientFactory(a.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create storage client: %w", err) + } + + return clientFactory.NewBlobContainersClient(), nil +} + +// func (a Azure) NewRoleAssignmentsClient() (*armauthorization.RoleAssignmentsClient, error) { +// cred, err := a.NewCreds() +// if err != nil { +// return nil, err +// } + +// clientFactory, err := armauthorization.NewRoleAssignmentsClient(a.SubscriptionID, cred, nil) +// if err != nil { +// return nil, fmt.Errorf("failed to create role assignments client: %w", err) +// } + +// return clientFactory, nil +// } diff --git a/src/pkg/clouds/azure/common_test.go b/src/pkg/clouds/azure/common_test.go new file mode 100644 index 000000000..8240e7173 --- /dev/null +++ b/src/pkg/clouds/azure/common_test.go @@ -0,0 +1,154 @@ +package azure + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// fakeCredential is a stub TokenCredential used to bypass azidentity in tests. +type fakeCredential struct { + token string + err error +} + +func (f fakeCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + if f.err != nil { + return azcore.AccessToken{}, f.err + } + return azcore.AccessToken{Token: f.token, ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +// useFakeCred swaps in a fake credential for the duration of the test. 
+func useFakeCred(t *testing.T, tok string, gerr error) { + t.Helper() + orig := NewCredsFunc + NewCredsFunc = func(_ Azure) (azcore.TokenCredential, error) { + return fakeCredential{token: tok, err: gerr}, nil + } + t.Cleanup(func() { NewCredsFunc = orig }) +} + +// useTestEndpoint swaps ManagementEndpoint to the httptest.Server URL. +func useTestEndpoint(t *testing.T, url string) { + t.Helper() + orig := ManagementEndpoint + ManagementEndpoint = url + t.Cleanup(func() { ManagementEndpoint = orig }) +} + +func TestArmToken(t *testing.T) { + useFakeCred(t, "my-arm-token", nil) + a := Azure{SubscriptionID: "sub"} + got, err := a.ArmToken(context.Background()) + if err != nil { + t.Fatalf("ArmToken: %v", err) + } + if got != "my-arm-token" { + t.Errorf("ArmToken = %q", got) + } +} + +func TestArmTokenCredError(t *testing.T) { + useFakeCred(t, "", errors.New("auth failed")) + a := Azure{SubscriptionID: "sub"} + if _, err := a.ArmToken(context.Background()); err == nil { + t.Error("ArmToken should propagate credential error") + } +} + +func TestNewCredsMissingSub(t *testing.T) { + // Reset to the default implementation so we hit the AZURE_SUBSCRIPTION_ID check. 
+ orig := NewCredsFunc + NewCredsFunc = func(a Azure) (azcore.TokenCredential, error) { + if a.SubscriptionID == "" { + return nil, errors.New("AZURE_SUBSCRIPTION_ID is not set") + } + return fakeCredential{token: "x"}, nil + } + t.Cleanup(func() { NewCredsFunc = orig }) + + a := Azure{} + if _, err := a.NewCreds(); err == nil { + t.Error("NewCreds should fail without subscription ID") + } +} + +func TestFetchLogStreamAuthToken(t *testing.T) { + var gotAuth, gotMethod string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotMethod = r.Method + if !strings.Contains(r.URL.Path, "Microsoft.App/jobs/defang-cd/getAuthToken") { + t.Errorf("path = %q, want contains Microsoft.App/jobs/defang-cd/getAuthToken", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"properties": {"token": "stream-token-abc"}}`)) + })) + defer srv.Close() + + useFakeCred(t, "arm-token", nil) + useTestEndpoint(t, srv.URL) + + a := Azure{SubscriptionID: "sub"} + got, err := a.FetchLogStreamAuthToken(context.Background(), "rg", "Microsoft.App/jobs/defang-cd", "2024-02-02-preview") + if err != nil { + t.Fatalf("FetchLogStreamAuthToken: %v", err) + } + if got != "stream-token-abc" { + t.Errorf("token = %q, want stream-token-abc", got) + } + if gotAuth != "Bearer arm-token" { + t.Errorf("Authorization header = %q", gotAuth) + } + if gotMethod != http.MethodPost { + t.Errorf("method = %q, want POST", gotMethod) + } +} + +func TestFetchLogStreamAuthTokenNonOK(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + useFakeCred(t, "arm-token", nil) + useTestEndpoint(t, srv.URL) + + a := Azure{SubscriptionID: "sub"} + _, err := a.FetchLogStreamAuthToken(context.Background(), "rg", "Microsoft.App/jobs/x", "2024-02-02-preview") + if err == nil { + 
t.Error("expected error for 403 response") + } +} + +func TestFetchLogStreamAuthTokenArmTokenError(t *testing.T) { + useFakeCred(t, "", errors.New("arm denied")) + a := Azure{SubscriptionID: "sub"} + if _, err := a.FetchLogStreamAuthToken(context.Background(), "rg", "x", "v"); err == nil { + t.Error("expected error when ArmToken fails") + } +} + +func TestFetchLogStreamAuthTokenBadJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`not-json`)) + })) + defer srv.Close() + + useFakeCred(t, "arm-token", nil) + useTestEndpoint(t, srv.URL) + + a := Azure{SubscriptionID: "sub"} + _, err := a.FetchLogStreamAuthToken(context.Background(), "rg", "x", "v") + if err == nil { + t.Error("expected decode error") + } +} diff --git a/src/pkg/clouds/azure/keyvault/keyvault.go b/src/pkg/clouds/azure/keyvault/keyvault.go new file mode 100644 index 000000000..bfb81d702 --- /dev/null +++ b/src/pkg/clouds/azure/keyvault/keyvault.go @@ -0,0 +1,424 @@ +package keyvault + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/DefangLabs/defang/src/pkg/clouds/azure" + "github.com/DefangLabs/defang/src/pkg/term" + "github.com/google/uuid" +) + +const vaultNameSuffixLen = 8 + +// Key Vault Secrets Officer allows read/write/delete of secrets. 
+const keyVaultSecretsOfficerRoleID = "b86a8fe4-44ce-4948-aee5-eccb2c155cd7" // nolint:gosec + +// VaultName returns a deterministic, globally-unique vault name (max 24 chars) +// for the given resource group in the given subscription. +func VaultName(resourceGroupName, subscriptionID string) string { + h := sha256.Sum256([]byte(subscriptionID + "|" + resourceGroupName)) + suffix := hex.EncodeToString(h[:])[:vaultNameSuffixLen] + name := "defang-config-" + suffix + return name +} + +// VaultURL returns the data-plane URL for the vault. +func VaultURL(vaultName string) string { + return "https://" + vaultName + ".vault.azure.net" +} + +// ToSecretName converts a config key path (e.g. "/Defang/myapp/test/POSTGRES_PASSWORD") +// to a Key Vault-safe secret name. Slashes become "--", underscores become "-". +func ToSecretName(key string) string { + key = strings.TrimPrefix(key, "/") + key = strings.ReplaceAll(key, "/", "--") + key = strings.ReplaceAll(key, "_", "-") + return key +} + +// KeyVault wraps an Azure Key Vault for storing project config secrets. +type KeyVault struct { + azure.Azure + resourceGroupName string + VaultName string + vaultURL string +} + +// New builds a KeyVault client rooted in the given resource group. The Azure +// value is copied in full so that an authenticated credential (Azure.Cred, +// set by Authenticate) propagates to subsequent SDK calls instead of each +// component silently falling back to DefaultAzureCredential. 
+func New(resourceGroupName string, az azure.Azure) *KeyVault { + return &KeyVault{ + Azure: az, + resourceGroupName: resourceGroupName, + } +} + +func (kv *KeyVault) getTenantID(ctx context.Context, cred azcore.TokenCredential) (string, error) { + client, err := armsubscriptions.NewClient(cred, nil) + if err != nil { + return "", fmt.Errorf("creating subscriptions client: %w", err) + } + resp, err := client.Get(ctx, kv.SubscriptionID, nil) + if err != nil { + return "", fmt.Errorf("getting subscription: %w", err) + } + if resp.TenantID == nil || *resp.TenantID == "" { + return "", errors.New("subscription has no tenant ID") + } + return *resp.TenantID, nil +} + +// Find is a read-only variant of SetUp: it binds to an existing Key Vault by +// its deterministic VaultName without creating one. Returns (true, nil) when +// the vault exists, (false, nil) when it or its resource group doesn't, and +// (false, err) on any other failure. +func (kv *KeyVault) Find(ctx context.Context) (bool, error) { + cred, err := kv.NewCreds() + if err != nil { + return false, err + } + client, err := armkeyvault.NewVaultsClient(kv.SubscriptionID, cred, nil) + if err != nil { + return false, err + } + name := VaultName(kv.resourceGroupName, kv.SubscriptionID) + if _, err := client.Get(ctx, kv.resourceGroupName, name, nil); err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && (respErr.StatusCode == 404 || respErr.ErrorCode == "ResourceGroupNotFound" || respErr.ErrorCode == "ResourceNotFound" || respErr.ErrorCode == "VaultNotFound") { + return false, nil + } + return false, fmt.Errorf("looking up Key Vault %q: %w", name, err) + } + kv.VaultName = name + kv.vaultURL = VaultURL(name) + return true, nil +} + +// SetUp creates the Key Vault (using the deterministic VaultName) if it doesn't +// already exist. Uses RBAC authorization mode so the CLI user (who creates the vault) +// and the CD job identity can access secrets via role assignments. 
+func (kv *KeyVault) SetUp(ctx context.Context) error { + cred, err := kv.NewCreds() + if err != nil { + return err + } + + client, err := armkeyvault.NewVaultsClient(kv.SubscriptionID, cred, nil) + if err != nil { + return err + } + + kv.VaultName = VaultName(kv.resourceGroupName, kv.SubscriptionID) + kv.vaultURL = VaultURL(kv.VaultName) + + tenantID, err := kv.getTenantID(ctx, cred) + if err != nil { + return fmt.Errorf("failed to get tenant ID: %w", err) + } + + term.Debugf("Creating or updating Key Vault %s", kv.VaultName) + poller, err := client.BeginCreateOrUpdate(ctx, kv.resourceGroupName, kv.VaultName, armkeyvault.VaultCreateOrUpdateParameters{ + Location: kv.Location.Ptr(), + Properties: &armkeyvault.VaultProperties{ + TenantID: to.Ptr(tenantID), + EnableRbacAuthorization: to.Ptr(true), + EnableSoftDelete: to.Ptr(true), + SoftDeleteRetentionInDays: to.Ptr(int32(7)), + SKU: &armkeyvault.SKU{ + Family: to.Ptr(armkeyvault.SKUFamilyA), + Name: to.Ptr(armkeyvault.SKUNameStandard), + }, + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create Key Vault: %w", err) + } + result, err := poller.PollUntilDone(ctx, azure.PollOptions) + if err != nil { + return fmt.Errorf("failed to poll Key Vault creation: %w", err) + } + + // Assign Key Vault Secrets Officer to the current user so the CLI can + // manage secrets. The vault uses RBAC, so even the creator needs an + // explicit role assignment. + if err := kv.assignSecretsOfficerRole(ctx, cred, *result.ID); err != nil { + return kv.wrapRoleAssignmentError(ctx, cred, *result.ID, err) + } + + return nil +} + +// EnsureSecretsOfficer assigns "Key Vault Secrets Officer" to the current +// caller on the bound vault, idempotently. Required after Find for shared +// stacks where the vault was created by another user (whose SetUp granted +// the role only to themselves). 
RoleAssignmentExists is treated as success, +// so this is safe to call on every config-list and incurs only one ARM PUT +// per CLI invocation (because callers cache the bound KeyVault). +func (kv *KeyVault) EnsureSecretsOfficer(ctx context.Context) error { + if kv.vaultURL == "" { + return errors.New("Key Vault not bound; call Find or SetUp first") + } + cred, err := kv.NewCreds() + if err != nil { + return err + } + vaultResourceID := fmt.Sprintf( + "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.KeyVault/vaults/%s", + kv.SubscriptionID, kv.resourceGroupName, kv.VaultName, + ) + if err := kv.assignSecretsOfficerRole(ctx, cred, vaultResourceID); err != nil { + return kv.wrapRoleAssignmentError(ctx, cred, vaultResourceID, err) + } + return nil +} + +// wrapRoleAssignmentError augments an assignSecretsOfficerRole failure with +// remediation guidance. Shared by SetUp (vault creation) and +// EnsureSecretsOfficer (existing-vault onboarding). +func (kv *KeyVault) wrapRoleAssignmentError(ctx context.Context, cred azcore.TokenCredential, vaultResourceID string, err error) error { + oid := kv.currentUserOID(ctx, cred) + if oid == "" { + oid = "" + } + return fmt.Errorf( + "assigning Key Vault Secrets Officer role failed: %w\n\n"+ + "Your current Azure identity (oid=%s) cannot write role assignments at subscription %s. Possible reasons:\n\n"+ + " 1. Your RBAC role on this subscription is Contributor — which does NOT include Microsoft.Authorization/roleAssignments/write. You need Owner or User Access Administrator. Check with:\n"+ + " az role assignment list --assignee %s --subscription %s -o table\n\n"+ + " 2. You hold an Azure AD / Entra ID directory role (e.g. Global Admin) but haven't elevated to Azure RBAC. Go to Entra ID → Properties → 'Access management for Azure resources' → Yes, then sign in again.\n\n"+ + " 3. 
// currentUserOID returns the object ID of the caller behind cred, extracted
// from an ARM-scoped access token's "oid" claim. Returns empty string if the
// token can't be acquired or parsed — callers should render a placeholder.
func (kv *KeyVault) currentUserOID(ctx context.Context, cred azcore.TokenCredential) string {
	tok, err := cred.GetToken(ctx, policy.TokenRequestOptions{
		Scopes: []string{"https://management.azure.com/.default"},
	})
	if err != nil {
		return ""
	}
	return objectIDFromJWT(tok.Token)
}

// assignSecretsOfficerRole assigns Key Vault Secrets Officer to the current caller.
// The assignment is scoped to the vault resource, not the subscription.
func (kv *KeyVault) assignSecretsOfficerRole(ctx context.Context, cred azcore.TokenCredential, vaultResourceID string) error {
	raClient, err := armauthorization.NewRoleAssignmentsClient(kv.SubscriptionID, cred, nil)
	if err != nil {
		return err
	}

	// Get the caller's object ID from the token.
	token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
		Scopes: []string{"https://management.azure.com/.default"},
	})
	if err != nil {
		return fmt.Errorf("getting token for caller OID: %w", err)
	}
	callerOID := objectIDFromJWT(token.Token)
	if callerOID == "" {
		return errors.New("could not extract object ID from token")
	}

	roleDefID := fmt.Sprintf(
		"/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s",
		kv.SubscriptionID, keyVaultSecretsOfficerRoleID,
	)
	// Role assignment names must be GUIDs; a fresh random name is fine because
	// a duplicate assignment for the same (principal, role, scope) surfaces as
	// RoleAssignmentExists, which we treat as success below — making this call
	// idempotent despite the random name.
	_, err = raClient.Create(ctx, vaultResourceID, uuid.NewString(), armauthorization.RoleAssignmentCreateParameters{
		Properties: &armauthorization.RoleAssignmentProperties{
			PrincipalID:      to.Ptr(callerOID),
			RoleDefinitionID: to.Ptr(roleDefID),
		},
	}, nil)
	if err != nil {
		var respErr *azcore.ResponseError
		if errors.As(err, &respErr) && respErr.ErrorCode == "RoleAssignmentExists" {
			return nil
		}
		return err
	}
	return nil
}
+ token, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err != nil { + return fmt.Errorf("getting token for caller OID: %w", err) + } + callerOID := objectIDFromJWT(token.Token) + if callerOID == "" { + return errors.New("could not extract object ID from token") + } + + roleDefID := fmt.Sprintf( + "/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", + kv.SubscriptionID, keyVaultSecretsOfficerRoleID, + ) + _, err = raClient.Create(ctx, vaultResourceID, uuid.NewString(), armauthorization.RoleAssignmentCreateParameters{ + Properties: &armauthorization.RoleAssignmentProperties{ + PrincipalID: to.Ptr(callerOID), + RoleDefinitionID: to.Ptr(roleDefID), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.ErrorCode == "RoleAssignmentExists" { + return nil + } + return err + } + return nil +} + +// objectIDFromJWT extracts the "oid" claim from a JWT access token without full parsing. +func objectIDFromJWT(token string) string { + parts := strings.SplitN(token, ".", 3) + if len(parts) < 2 { + return "" + } + // Pad base64url to standard base64. + payload := parts[1] + if m := len(payload) % 4; m != 0 { + payload += strings.Repeat("=", 4-m) + } + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return "" + } + // Crude extraction — find "oid":"..." without pulling in encoding/json. 
// newSecretsClient builds a data-plane client for the bound vault.
// Fails fast when neither Find nor SetUp has populated vaultURL yet.
func (kv *KeyVault) newSecretsClient() (*azsecrets.Client, error) {
	if kv.vaultURL == "" {
		return nil, errors.New("Key Vault not set up")
	}
	cred, err := kv.NewCreds()
	if err != nil {
		return nil, err
	}
	return azsecrets.NewClient(kv.vaultURL, cred, nil)
}

// PutSecret creates or updates a secret in the vault. The originalKey tag
// preserves the exact config key name (which may contain underscores that
// were replaced in the secret name).
//
// Immediately after SetUp, Azure RBAC can take up to ~60s to propagate the
// Key Vault Secrets Officer role assignment to the vault's data plane. A
// transient 403 ForbiddenByRbac is therefore retried with backoff before
// giving up.
func (kv *KeyVault) PutSecret(ctx context.Context, name, value, originalKey string) error {
	client, err := kv.newSecretsClient()
	if err != nil {
		return err
	}
	params := azsecrets.SetSecretParameters{
		Value: to.Ptr(value),
		Tags: map[string]*string{
			"original-key": to.Ptr(originalKey),
		},
	}
	return retryOnForbiddenByRbac(ctx, func(ctx context.Context) error {
		_, err := client.SetSecret(ctx, name, params, nil)
		return err
	})
}

// retryOnForbiddenByRbac retries op with exponential backoff while it fails
// with 403 ForbiddenByRbac — the canonical signature of a freshly-assigned
// Key Vault role that hasn't propagated yet. Gives up after ~60s total.
func retryOnForbiddenByRbac(ctx context.Context, op func(context.Context) error) error {
	// 6 attempts with delays 2+4+8+16+32s ≈ 62s worst case.
	const maxAttempts = 6
	delay := 2 * time.Second
	for attempt := 0; ; attempt++ {
		err := op(ctx)
		if err == nil {
			return nil
		}
		var respErr *azcore.ResponseError
		if !errors.As(err, &respErr) || respErr.ErrorCode != "ForbiddenByRbac" || attempt >= maxAttempts-1 {
			return err
		}
		term.Debugf("Key Vault returned ForbiddenByRbac (likely RBAC propagation), retrying in %s (attempt %d/%d)", delay, attempt+1, maxAttempts)
		// Honor cancellation while waiting out the backoff.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(delay):
		}
		delay *= 2
	}
}
// DeleteSecret removes a secret from the vault. A 404 (already gone) is
// treated as success so deletion is idempotent.
func (kv *KeyVault) DeleteSecret(ctx context.Context, name string) error {
	client, err := kv.newSecretsClient()
	if err != nil {
		return err
	}
	return retryOnForbiddenByRbac(ctx, func(ctx context.Context) error {
		_, err := client.DeleteSecret(ctx, name, nil)
		if err != nil {
			var respErr *azcore.ResponseError
			if errors.As(err, &respErr) && respErr.StatusCode == 404 {
				return nil
			}
		}
		return err
	})
}

// SecretEntry holds a secret's metadata returned by ListSecrets.
type SecretEntry struct {
	Name        string
	OriginalKey string
}

// ListSecrets returns secrets whose names start with the given prefix.
// It uses the "original-key" tag to recover the original config key name.
func (kv *KeyVault) ListSecrets(ctx context.Context, prefix string) ([]SecretEntry, error) {
	client, err := kv.newSecretsClient()
	if err != nil {
		return nil, err
	}

	var entries []SecretEntry
	err = retryOnForbiddenByRbac(ctx, func(ctx context.Context) error {
		// Reset on each retry so a partially-filled list from a failed
		// attempt is not duplicated.
		entries = entries[:0]
		pager := client.NewListSecretPropertiesPager(nil)
		for pager.More() {
			page, err := pager.NextPage(ctx)
			if err != nil {
				return fmt.Errorf("failed to list secrets: %w", err)
			}
			for _, props := range page.Value {
				if props.ID == nil {
					continue
				}
				name := props.ID.Name()
				if !strings.HasPrefix(name, prefix) {
					continue
				}
				entry := SecretEntry{Name: name}
				if props.Tags != nil {
					if orig, ok := props.Tags["original-key"]; ok && orig != nil {
						entry.OriginalKey = *orig
					}
				}
				entries = append(entries, entry)
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return entries, nil
}
+func (kv *KeyVault) ListSecrets(ctx context.Context, prefix string) ([]SecretEntry, error) { + client, err := kv.newSecretsClient() + if err != nil { + return nil, err + } + + var entries []SecretEntry + err = retryOnForbiddenByRbac(ctx, func(ctx context.Context) error { + entries = entries[:0] + pager := client.NewListSecretPropertiesPager(nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return fmt.Errorf("failed to list secrets: %w", err) + } + for _, props := range page.Value { + if props.ID == nil { + continue + } + name := props.ID.Name() + if !strings.HasPrefix(name, prefix) { + continue + } + entry := SecretEntry{Name: name} + if props.Tags != nil { + if orig, ok := props.Tags["original-key"]; ok && orig != nil { + entry.OriginalKey = *orig + } + } + entries = append(entries, entry) + } + } + return nil + }) + if err != nil { + return nil, err + } + return entries, nil +} + +// SecretURL returns the Key Vault URL for a specific secret, suitable for +// Container App Key Vault secret references. 
diff --git a/src/pkg/clouds/azure/keyvault/keyvault_test.go b/src/pkg/clouds/azure/keyvault/keyvault_test.go
new file mode 100644
index 000000000..bad6a4438
--- /dev/null
+++ b/src/pkg/clouds/azure/keyvault/keyvault_test.go
@@ -0,0 +1,183 @@
package keyvault

import (
	"context"
	"encoding/base64"
	"strings"
	"testing"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/DefangLabs/defang/src/pkg/clouds/azure"
)

// fakeCred is a TokenCredential stub returning a fixed token (or error).
type fakeCred struct {
	tok string
	err error
}

func (f fakeCred) GetToken(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) {
	if f.err != nil {
		return azcore.AccessToken{}, f.err
	}
	return azcore.AccessToken{Token: f.tok, ExpiresOn: time.Now().Add(time.Hour)}, nil
}

// useFakeCred swaps azure.NewCredsFunc for a stub for the duration of the test.
func useFakeCred(t *testing.T, tok string, gerr error) {
	t.Helper()
	orig := azure.NewCredsFunc
	azure.NewCredsFunc = func(_ azure.Azure) (azcore.TokenCredential, error) {
		return fakeCred{tok: tok, err: gerr}, nil
	}
	t.Cleanup(func() { azure.NewCredsFunc = orig })
}

func TestVaultName(t *testing.T) {
	name := VaultName("my-rg", "sub-id")
	if !strings.HasPrefix(name, "defang-config-") {
		t.Errorf("VaultName = %q, want defang-config- prefix", name)
	}
	// Azure vault names are limited to 24 characters.
	if len(name) > 24 {
		t.Errorf("VaultName %q exceeds 24 chars", name)
	}
	// Deterministic.
	if VaultName("my-rg", "sub-id") != name {
		t.Error("VaultName is not deterministic")
	}
	// Different inputs produce different names.
	if VaultName("other-rg", "sub-id") == name {
		t.Error("VaultName collision for different resource group")
	}
	if VaultName("my-rg", "other-sub") == name {
		t.Error("VaultName collision for different subscription")
	}
}

func TestVaultURL(t *testing.T) {
	got := VaultURL("kv-abc123")
	want := "https://kv-abc123.vault.azure.net"
	if got != want {
		t.Errorf("VaultURL = %q, want %q", got, want)
	}
}

func TestToSecretName(t *testing.T) {
	tests := []struct {
		in, want string
	}{
		{"", ""},
		{"FOO", "FOO"},
		{"/Defang", "Defang"},
		{"/Defang/myapp/test/POSTGRES_PASSWORD", "Defang--myapp--test--POSTGRES-PASSWORD"},
		{"foo_bar", "foo-bar"},
		{"foo/bar", "foo--bar"},
	}
	for _, tt := range tests {
		if got := ToSecretName(tt.in); got != tt.want {
			t.Errorf("ToSecretName(%q) = %q, want %q", tt.in, got, tt.want)
		}
	}
}

func TestNew(t *testing.T) {
	kv := New("rg-name", azure.Azure{Location: azure.LocationWestUS2, SubscriptionID: "sub-id"})
	if kv == nil {
		t.Fatal("New returned nil")
	}
	if kv.SubscriptionID != "sub-id" {
		t.Errorf("SubscriptionID = %q, want sub-id", kv.SubscriptionID)
	}
	if kv.Location != azure.LocationWestUS2 {
		t.Errorf("Location = %q, want westus2", kv.Location)
	}
	if kv.resourceGroupName != "rg-name" {
		t.Errorf("resourceGroupName = %q, want rg-name", kv.resourceGroupName)
	}
}

func TestSecretURL(t *testing.T) {
	kv := &KeyVault{vaultURL: "https://kv-abc.vault.azure.net"}
	got := kv.SecretURL("my-secret")
	want := "https://kv-abc.vault.azure.net/secrets/my-secret"
	if got != want {
		t.Errorf("SecretURL = %q, want %q", got, want)
	}
}

func TestVaultNameAndURLFields(t *testing.T) {
	kv := New("rg", azure.Azure{Location: azure.LocationWestUS2, SubscriptionID: "sub"})
	// VaultName and vaultURL are populated by SetUp; zero values before that.
	if kv.VaultName != "" {
		t.Errorf("VaultName before SetUp = %q, want empty", kv.VaultName)
	}
	if kv.vaultURL != "" {
		t.Errorf("vaultURL before SetUp = %q, want empty", kv.vaultURL)
	}
	// Simulate SetUp populating fields.
	kv.VaultName = VaultName(kv.resourceGroupName, kv.SubscriptionID)
	kv.vaultURL = VaultURL(kv.VaultName)
	if kv.vaultURL == "" {
		t.Error("vaultURL should be populated")
	}
	if got := kv.SecretURL("foo"); got != kv.vaultURL+"/secrets/foo" {
		t.Errorf("SecretURL = %q", got)
	}
}

func TestNewSecretsClientNotSetUp(t *testing.T) {
	useFakeCred(t, "x", nil)
	kv := New("rg", azure.Azure{Location: azure.LocationWestUS2, SubscriptionID: "sub"})
	if _, err := kv.newSecretsClient(); err == nil {
		t.Error("newSecretsClient should fail when vaultURL empty")
	}
}

func TestNewSecretsClientOK(t *testing.T) {
	useFakeCred(t, "tok", nil)
	kv := New("rg", azure.Azure{Location: azure.LocationWestUS2, SubscriptionID: "sub"})
	kv.vaultURL = "https://kv.vault.azure.net"
	if _, err := kv.newSecretsClient(); err != nil {
		t.Errorf("newSecretsClient: %v", err)
	}
}

func TestPutDeleteListSecretNotSetUp(t *testing.T) {
	useFakeCred(t, "x", nil)
	kv := New("rg", azure.Azure{Location: azure.LocationWestUS2, SubscriptionID: "sub"})
	if err := kv.PutSecret(context.Background(), "s", "v", "k"); err == nil {
		t.Error("PutSecret should fail when vault not set up")
	}
	if err := kv.DeleteSecret(context.Background(), "s"); err == nil {
		t.Error("DeleteSecret should fail when vault not set up")
	}
	if _, err := kv.ListSecrets(context.Background(), "prefix"); err == nil {
		t.Error("ListSecrets should fail when vault not set up")
	}
}

func TestObjectIDFromJWT(t *testing.T) {
	// Build a fake JWT with {"oid":"test-oid-value"} payload.
	payload := `{"sub":"x","oid":"test-oid-value","aud":"y"}`
	encoded := base64.RawURLEncoding.EncodeToString([]byte(payload))
	token := "header." + encoded + ".signature"
	if got := objectIDFromJWT(token); got != "test-oid-value" {
		t.Errorf("objectIDFromJWT = %q, want test-oid-value", got)
	}

	// Missing oid claim.
	noOID := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"x"}`))
	if got := objectIDFromJWT("h." + noOID + ".s"); got != "" {
		t.Errorf("objectIDFromJWT without oid = %q, want empty", got)
	}

	// Not a JWT (no '.').
	if got := objectIDFromJWT("not-a-jwt"); got != "" {
		t.Errorf("objectIDFromJWT(bad) = %q, want empty", got)
	}

	// Invalid base64 in payload.
	if got := objectIDFromJWT("h.!!!not-base64!!!.s"); got != "" {
		t.Errorf("objectIDFromJWT(bad base64) = %q, want empty", got)
	}
}
diff --git a/src/pkg/clouds/azure/location.go b/src/pkg/clouds/azure/location.go
new file mode 100644
index 000000000..d1023e243
--- /dev/null
+++ b/src/pkg/clouds/azure/location.go
@@ -0,0 +1,120 @@
package azure

// Location is an Azure region (location) identifier as accepted by ARM,
// e.g. "westus2".
type Location string

// Known Azure location identifiers, including geography groupings (e.g.
// "europe") and staging/EUAP variants.
const (
	LocationAsia                Location = "asia"
	LocationAsiaPacific         Location = "asiapacific"
	LocationAustralia           Location = "australia"
	LocationAustraliaCentral    Location = "australiacentral"
	LocationAustraliaCentral2   Location = "australiacentral2"
	LocationAustraliaEast       Location = "australiaeast"
	LocationAustraliaSouthEast  Location = "australiasoutheast"
	LocationAustriaEast         Location = "austriaeast"
	LocationBrazil              Location = "brazil"
	LocationBrazilSouth         Location = "brazilsouth"
	LocationBrazilSouthEast     Location = "brazilsoutheast"
	LocationBrazilUS            Location = "brazilus"
	LocationCanada              Location = "canada"
	LocationCanadaCentral       Location = "canadacentral"
	LocationCanadaEast          Location = "canadaeast"
	LocationCentralIndia        Location = "centralindia"
	LocationCentralUS           Location = "centralus"
	LocationCentralUSEuap       Location = "centraluseuap"
	LocationCentralUSStage      Location = "centralusstage"
	LocationChileCentral        Location = "chilecentral"
	LocationEastAsia            Location = "eastasia"
	LocationEastAsiaStage       Location = "eastasiastage"
	LocationEastUS              Location = "eastus"
	LocationEastUS2             Location = "eastus2"
	LocationEastUS2Euap         Location = "eastus2euap"
	LocationEastUS2Stage        Location = "eastus2stage"
	LocationEastUSStage         Location = "eastusstage"
	LocationEastUSStg           Location = "eastusstg"
	LocationEurope              Location = "europe"
	LocationFrance              Location = "france"
	LocationFranceCentral       Location = "francecentral"
	LocationFranceSouth         Location = "francesouth"
	LocationGermany             Location = "germany"
	LocationGermanyNorth        Location = "germanynorth"
	LocationGermanyWestCentral  Location = "germanywestcentral"
	LocationGlobal              Location = "global"
	LocationIndia               Location = "india"
	LocationIndonesia           Location = "indonesia"
	LocationIndonesiaCentral    Location = "indonesiacentral"
	LocationIsrael              Location = "israel"
	LocationIsraelCentral       Location = "israelcentral"
	LocationItaly               Location = "italy"
	LocationItalyNorth          Location = "italynorth"
	LocationJapan               Location = "japan"
	LocationJapanEast           Location = "japaneast"
	LocationJapanWest           Location = "japanwest"
	LocationJioIndiaCentral     Location = "jioindiacentral"
	LocationJioIndiaWest        Location = "jioindiawest"
	LocationKorea               Location = "korea"
	LocationKoreaCentral        Location = "koreacentral"
	LocationKoreaSouth          Location = "koreasouth"
	LocationMalaysia            Location = "malaysia"
	LocationMalaysiaWest        Location = "malaysiawest"
	LocationMexico              Location = "mexico"
	LocationMexicoCentral       Location = "mexicocentral"
	LocationNewZealand          Location = "newzealand"
	LocationNewZealandNorth     Location = "newzealandnorth"
	LocationNorthCentralUS      Location = "northcentralus"
	LocationNorthCentralUSStage Location = "northcentralusstage"
	LocationNorthEurope         Location = "northeurope"
	LocationNorway              Location = "norway"
	LocationNorwayEast          Location = "norwayeast"
	LocationNorwayWest          Location = "norwaywest"
	LocationPoland              Location = "poland"
	LocationPolandCentral       Location = "polandcentral"
	LocationQatar               Location = "qatar"
	LocationQatarCentral        Location = "qatarcentral"
	LocationSingapore           Location = "singapore"
	LocationSouthAfrica         Location = "southafrica"
	LocationSouthAfricaNorth    Location = "southafricanorth"
	LocationSouthAfricaWest     Location = "southafricawest"
	LocationSouthCentralUS      Location = "southcentralus"
	LocationSouthCentralUSStage Location = "southcentralusstage"
	LocationSouthCentralUSStg   Location = "southcentralusstg"
	LocationSoutheastAsia       Location = "southeastasia"
	LocationSoutheastAsiaStage  Location = "southeastasiastage"
	LocationSouthIndia          Location = "southindia"
	LocationSpain               Location = "spain"
	LocationSpainCentral        Location = "spaincentral"
	LocationSweden              Location = "sweden"
	LocationSwedenCentral       Location = "swedencentral"
	LocationSwedenSouth         Location = "swedensouth"
	LocationSwitzerland         Location = "switzerland"
	LocationSwitzerlandNorth    Location = "switzerlandnorth"
	LocationSwitzerlandWest     Location = "switzerlandwest"
	LocationTaiwan              Location = "taiwan"
	LocationUae                 Location = "uae"
	LocationUaeCentral          Location = "uaecentral"
	LocationUaeNorth            Location = "uaenorth"
	LocationUK                  Location = "uk"
	LocationUKSouth             Location = "uksouth"
	LocationUKWest              Location = "ukwest"
	LocationUnitedStates        Location = "unitedstates"
	LocationUnitedStatesEuap    Location = "unitedstateseuap"
	LocationWestCentralUS       Location = "westcentralus"
	LocationWestEurope          Location = "westeurope"
	LocationWestIndia           Location = "westindia"
	LocationWestUS              Location = "westus"
	LocationWestUS2             Location = "westus2"
	LocationWestUS2Stage        Location = "westus2stage"
	LocationWestUS3             Location = "westus3"
	LocationWestUSStage         Location = "westusstage"
)

// String returns the raw location identifier.
func (l Location) String() string {
	return string(l)
}

// Ptr returns a *string for SDK parameters, or nil when the location is
// unset so the field is omitted from the request.
func (l Location) Ptr() *string {
	if l == "" {
		return nil
	}
	s := string(l)
	return &s
}
diff --git a/src/pkg/clouds/azure/location_test.go b/src/pkg/clouds/azure/location_test.go
new file mode 100644
index 000000000..d689d5305
--- /dev/null
+++ b/src/pkg/clouds/azure/location_test.go
@@ -0,0 +1,35 @@
package azure

import "testing"

func TestLocationString(t *testing.T) {
	tests := []struct {
		loc  Location
		want string
	}{
		{"", ""},
		{LocationEastUS, "eastus"},
		{LocationWestUS2, "westus2"},
		{LocationWestEurope, "westeurope"},
	}
	for _, tt := range tests {
		if got := tt.loc.String(); got != tt.want {
			t.Errorf("Location(%q).String() = %q, want %q", tt.loc, got, tt.want)
		}
	}
}

func TestLocationPtr(t *testing.T) {
	if p := Location("").Ptr(); p != nil {
		t.Errorf("empty location Ptr() = %v, want nil", p)
	}

	loc := LocationWestUS2
	p := loc.Ptr()
	if p == nil {
		t.Fatalf("Ptr() returned nil for non-empty location")
	}
	if *p != "westus2" {
		t.Errorf("*Ptr() = %q, want %q", *p, "westus2")
	}
}
diff --git a/src/pkg/clouds/azure/login.go b/src/pkg/clouds/azure/login.go
new file mode 100644
index 000000000..9be58a63c
--- /dev/null
+++ b/src/pkg/clouds/azure/login.go
@@ -0,0 +1,307 @@
package azure

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
	"github.com/DefangLabs/defang/src/pkg/term"
	"github.com/DefangLabs/defang/src/pkg/tokenstore"
)

const (
	// managementScope is the OAuth2 scope for ARM (management plane) calls.
	managementScope = "https://management.azure.com/.default"
	// azureCLIClientID is Microsoft's public client ID for Azure CLI. Using
	// it means the user sees the same consent prompt they would from
	// `az login --use-device-code`, and we don't need to register our own app.
	azureCLIClientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
	// defaultTenant routes through Microsoft's "organizations" tenant so any
	// work/school account can authenticate when we can't discover the
	// subscription's specific tenant.
	defaultTenant = "organizations"
	// msalCacheKey is the TokenStore key holding MSAL's serialized cache blob
	// (one blob per defang installation covers all accounts MSAL tracks).
	msalCacheKey = "azure-msal-cache"
)

// defangMSALCache adapts defang's file-based TokenStore to MSAL's
// cache.ExportReplace interface. MSAL calls Replace before every operation
// that consults its cache — multiple times per defang invocation — so we
// keep an in-memory mirror and touch the disk only once on the first
// Replace and again on Export when the cache actually changes.
type defangMSALCache struct {
	store tokenstore.TokenStore
	key   string

	mu       sync.Mutex
	inMemory []byte // last-known serialized cache blob
	loaded   bool   // true once the initial Load from disk has completed
}

// Replace hydrates MSAL's in-memory cache from our mirror (lazily loading
// from disk on the first call).
func (c *defangMSALCache) Replace(_ context.Context, u cache.Unmarshaler, _ cache.ReplaceHints) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.loaded {
		if c.store != nil {
			if data, err := c.store.Load(c.key); err == nil {
				c.inMemory = []byte(data)
			}
			// Load error (file not found etc.) → start empty; first Export will create it.
		}
		c.loaded = true
	}
	if len(c.inMemory) == 0 {
		return nil
	}
	return u.Unmarshal(c.inMemory)
}

// Export persists MSAL's cache blob, skipping the disk write when unchanged.
func (c *defangMSALCache) Export(_ context.Context, m cache.Marshaler, _ cache.ExportHints) error {
	data, err := m.Marshal()
	if err != nil {
		return err
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	// MSAL calls Export after every successful operation even if nothing
	// mutated. Skip the disk write when bytes match the last-seen state.
	if c.loaded && bytes.Equal(data, c.inMemory) {
		return nil
	}
	c.inMemory = data
	c.loaded = true
	if c.store == nil {
		return nil
	}
	return c.store.Save(c.key, string(data))
}

// msalCred is an azcore.TokenCredential backed by an MSAL public client and
// a specific account. GetToken delegates to MSAL's AcquireTokenSilent, which
// handles per-scope token caching, refresh-token rotation, CAE, and claims
// challenges — freeing us from reimplementing any of that.
type msalCred struct {
	client  public.Client
	account public.Account
}

func (c *msalCred) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
	if len(opts.Scopes) == 0 {
		return azcore.AccessToken{}, errors.New("GetToken: at least one scope is required")
	}
	res, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account))
	if err != nil {
		return azcore.AccessToken{}, fmt.Errorf("acquiring Azure token for %v: %w", opts.Scopes, err)
	}
	return azcore.AccessToken{Token: res.AccessToken, ExpiresOn: res.ExpiresOn}, nil
}

// Authenticate sets up Azure credentials for the session in order of preference:
//  1. Existing default Azure credentials — env vars (AZURE_TENANT_ID/CLIENT_ID/
//     CLIENT_SECRET), managed identity, workload identity, an `az login`
//     session picked up via AzureCLICredential, etc.
//  2. Silent token acquisition via MSAL, using its on-disk cache (persisted
//     through defang's TokenStore). Covers the common case of a returning
//     user with a still-valid refresh token.
//  3. Interactive device-code login (equivalent to `az login --use-device-code`).
//     On success the refresh token is written to the cache so step 2 works
//     on the next invocation.
//
// On success a.Cred is populated with an msalCred (for path 2/3) or a
// DefaultAzureCredential wrapper (path 1). Both honor per-scope GetToken
// requests from the Azure SDK.
func (a *Azure) Authenticate(ctx context.Context, interactive bool) error {
	if a.SubscriptionID == "" {
		a.SubscriptionID = os.Getenv("AZURE_SUBSCRIPTION_ID")
	}
	if a.SubscriptionID == "" {
		return errors.New("AZURE_SUBSCRIPTION_ID is required for Azure login")
	}

	// 1. DefaultAzureCredential (az cli session, env vars, managed identity, …).
	term.Debug("checking default Azure credentials...")
	if cred, err := a.tryDefaultCredential(ctx); err != nil {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		term.Debugf("default Azure credentials invalid: %v", err)
	} else if cred != nil {
		term.Debug("found valid default Azure credentials")
		a.Cred = cred
		return nil
	}

	// Resolve the subscription's tenant so MSAL authenticates against the
	// right authority (avoids the InvalidAuthenticationTokenTenant error
	// when the user's home tenant differs from the subscription's tenant).
	tenant := os.Getenv("AZURE_TENANT_ID")
	if tenant == "" {
		if discovered, err := discoverSubscriptionTenant(ctx, a.SubscriptionID); err == nil {
			tenant = discovered
			term.Debugf("discovered tenant %s for subscription %s", tenant, a.SubscriptionID)
		} else {
			term.Debugf("tenant discovery failed, falling back to %q: %v", defaultTenant, err)
			tenant = defaultTenant
		}
	}

	client, err := public.New(azureCLIClientID,
		public.WithAuthority("https://login.microsoftonline.com/"+tenant),
		public.WithCache(&defangMSALCache{store: a.TokenStore, key: msalCacheKey}),
	)
	if err != nil {
		return fmt.Errorf("creating MSAL client: %w", err)
	}

	// 2. Silent token acquisition via any cached account in the right tenant.
	if cred, err := a.trySilentMSAL(ctx, client, tenant); err != nil {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		term.Debugf("silent MSAL acquisition failed: %v", err)
	} else if cred != nil {
		term.Debug("reused cached Azure credentials")
		a.Cred = cred
		return nil
	}

	// 3. Interactive device-code login.
	if !interactive {
		return errors.New("no valid Azure credentials found; run `defang login` or `az login --use-device-code`, or set AZURE_TENANT_ID / CLIENT_ID / CLIENT_SECRET")
	}
	term.Info("no valid Azure credentials found, starting device code login...")

	dc, err := client.AcquireTokenByDeviceCode(ctx, []string{managementScope})
	if err != nil {
		return fmt.Errorf("starting device code flow: %w", err)
	}
	term.Println(dc.Result.Message)

	res, err := dc.AuthenticationResult(ctx)
	if err != nil {
		return fmt.Errorf("device code login failed: %w", err)
	}
	cred := &msalCred{client: client, account: res.Account}
	if err := testAzureCredential(ctx, a.SubscriptionID, cred); err != nil {
		return fmt.Errorf("device code login token failed validation on subscription %q: %w", a.SubscriptionID, err)
	}
	a.Cred = cred
	return nil
}
// trySilentMSAL walks MSAL's cached accounts (filtered by tenant) and
// returns a credential for the first one that can silently mint an
// ARM-scoped token AND pass the subscription permission check. Returns
// (nil, nil) when nothing in the cache works.
func (a *Azure) trySilentMSAL(ctx context.Context, client public.Client, tenant string) (azcore.TokenCredential, error) {
	accounts, err := client.Accounts(ctx)
	if err != nil {
		return nil, fmt.Errorf("listing MSAL accounts: %w", err)
	}
	for _, acct := range accounts {
		if tenant != "" && tenant != defaultTenant && acct.Realm != "" && acct.Realm != tenant {
			continue // account belongs to a different tenant
		}
		if _, err := client.AcquireTokenSilent(ctx, []string{managementScope}, public.WithSilentAccount(acct)); err != nil {
			term.Debugf("silent acquire for %q failed: %v", acct.PreferredUsername, err)
			continue
		}
		cred := &msalCred{client: client, account: acct}
		if err := testAzureCredential(ctx, a.SubscriptionID, cred); err != nil {
			if ctx.Err() != nil {
				return nil, ctx.Err()
			}
			term.Debugf("cached account %q failed subscription check: %v", acct.PreferredUsername, err)
			continue
		}
		return cred, nil
	}
	return nil, nil
}

// tryDefaultCredential constructs a DefaultAzureCredential and tests it
// against the subscription. Returns (cred, nil) when it works, and a non-nil
// error when the cred cannot be built or fails the permission check (the
// caller treats any error as "fall through to the next auth path").
func (a *Azure) tryDefaultCredential(ctx context.Context) (azcore.TokenCredential, error) {
	defaultCred, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		return nil, err
	}
	cred := &tokenCredentialWithTimeout{cred: defaultCred, timeout: cliTimeout}
	if err := testAzureCredential(ctx, a.SubscriptionID, cred); err != nil {
		return nil, err
	}
	return cred, nil
}

// discoverSubscriptionTenant resolves the tenant that owns subscriptionID
// without any credentials. ARM responds to an unauthenticated GET on the
// subscription with a 401 whose WWW-Authenticate header embeds
//
//	authorization_uri="https://login.microsoftonline.com/{tenantId}"
//
// — the same trick `az account show` uses when the CLI needs to pick a
// tenant for MSA/guest scenarios.
func discoverSubscriptionTenant(ctx context.Context, subscriptionID string) (string, error) {
	endpoint := fmt.Sprintf("https://management.azure.com/subscriptions/%s?api-version=2020-01-01", url.PathEscape(subscriptionID))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusUnauthorized {
		return "", fmt.Errorf("unexpected status %s from unauthenticated subscription probe", resp.Status)
	}

	header := resp.Header.Get("WWW-Authenticate")
	const key = `authorization_uri="`
	i := strings.Index(header, key)
	if i < 0 {
		return "", errors.New("WWW-Authenticate missing authorization_uri")
	}
	rest := header[i+len(key):]
	j := strings.IndexByte(rest, '"')
	if j < 0 {
		return "", errors.New("WWW-Authenticate authorization_uri unterminated")
	}
	authURL, err := url.Parse(rest[:j])
	if err != nil {
		return "", fmt.Errorf("parsing authorization_uri: %w", err)
	}
	// The tenant ID is the sole path segment of the authority URL.
	tenant := strings.Trim(authURL.Path, "/")
	if tenant == "" {
		return "", fmt.Errorf("authorization_uri %q has no tenant path", authURL)
	}
	return tenant, nil
}

// testAzureCredential validates cred by asking ARM for the subscription.
// Any 200 response means the token is good and the caller has at least
// read access.
func testAzureCredential(ctx context.Context, subscriptionID string, cred azcore.TokenCredential) error {
	client, err := armsubscriptions.NewClient(cred, nil)
	if err != nil {
		return fmt.Errorf("creating subscriptions client: %w", err)
	}
	if _, err := client.Get(ctx, subscriptionID, nil); err != nil {
		return fmt.Errorf("cannot access subscription %q: %w", subscriptionID, err)
	}
	return nil
}
diff --git a/src/pkg/clouds/azure/login_test.go b/src/pkg/clouds/azure/login_test.go
new file mode 100644
index 000000000..f0245c3f3
--- /dev/null
+++ b/src/pkg/clouds/azure/login_test.go
@@ -0,0 +1,37 @@
package azure

import (
	"context"
	"os"
	"testing"
	"time"
)

func TestAuthenticateMissingSubscriptionID(t *testing.T) {
	// Fully unset AZURE_SUBSCRIPTION_ID for the duration of this test —
	// t.Setenv("", ...) would leave it set-but-empty and LookupEnv returns true.
	if v, ok := os.LookupEnv("AZURE_SUBSCRIPTION_ID"); ok {
		_ = os.Unsetenv("AZURE_SUBSCRIPTION_ID")
		t.Cleanup(func() { _ = os.Setenv("AZURE_SUBSCRIPTION_ID", v) }) //nolint:usetesting
	}

	a := &Azure{}
	if err := a.Authenticate(context.Background(), false); err == nil {
		t.Error("Authenticate should fail when AZURE_SUBSCRIPTION_ID is missing")
	}
	if a.Cred != nil {
		t.Error("Cred should remain nil on error")
	}
}

func TestAuthenticateNonInteractiveFailsWithInvalidSubscription(t *testing.T) {
	// An unknown subscription ID: the test call to ARM fails (either because
	// the subscription doesn't exist or because the caller has no credentials
	// at all). Non-interactive mode must return an error instead of prompting.
	a := &Azure{SubscriptionID: "00000000-0000-0000-0000-000000000000"}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := a.Authenticate(ctx, false); err == nil {
		t.Error("Authenticate(interactive=false) should fail with invalid subscription")
	}
}
+ a := &Azure{SubscriptionID: "00000000-0000-0000-0000-000000000000"} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := a.Authenticate(ctx, false); err == nil { + t.Error("Authenticate(interactive=false) should fail with invalid subscription") + } +} diff --git a/src/pkg/http/post.go b/src/pkg/http/post.go index 3adb34019..67fae830d 100644 --- a/src/pkg/http/post.go +++ b/src/pkg/http/post.go @@ -36,13 +36,15 @@ func PostFormWithContext(ctx context.Context, url string, data url.Values) (*htt return PostWithContext(ctx, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } -func PostWithContext(ctx context.Context, url, contentType string, body io.Reader) (*http.Response, error) { +func PostWithHeader(ctx context.Context, url string, header http.Header, body io.Reader) (*http.Response, error) { hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) if err != nil { return nil, err } - if contentType != "" { - hreq.Header.Set("Content-Type", contentType) - } + hreq.Header = header return DefaultClient.Do(hreq) } + +func PostWithContext(ctx context.Context, url, contentType string, body io.Reader) (*http.Response, error) { + return PostWithHeader(ctx, url, http.Header{"Content-Type": []string{contentType}}, body) +} diff --git a/src/pkg/http/put.go b/src/pkg/http/put.go index 8304b72b4..7048863d5 100644 --- a/src/pkg/http/put.go +++ b/src/pkg/http/put.go @@ -18,10 +18,14 @@ import ( // See the Client.Do method documentation for details on how redirects // are handled. 
func Put(ctx context.Context, url string, contentType string, body io.Reader) (*http.Response, error) { + return PutWithHeader(ctx, url, http.Header{"Content-Type": []string{contentType}}, body) +} + +func PutWithHeader(ctx context.Context, url string, header http.Header, body io.Reader) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body) if err != nil { return nil, err } - req.Header.Set("Content-Type", contentType) + req.Header = header return DefaultClient.Do(req) } diff --git a/src/pkg/session/session.go b/src/pkg/session/session.go index 66ddf4e31..de9d5344e 100644 --- a/src/pkg/session/session.go +++ b/src/pkg/session/session.go @@ -102,6 +102,9 @@ func printProviderMismatchWarnings(ctx context.Context, provider client.Provider if env := pkg.GcpInEnv(); env != "" { term.Warnf("GCP project environment variable was detected (%v); did you forget --provider=gcp or DEFANG_PROVIDER=gcp?", env) } + if env := pkg.AzureInEnv(); env != "" { + term.Warnf("Azure environment variables were detected (%v); did you forget --provider=azure or DEFANG_PROVIDER=azure?", env) + } } switch provider { @@ -117,6 +120,10 @@ func printProviderMismatchWarnings(ctx context.Context, provider client.Provider if env := pkg.GcpInEnv(); env == "" { term.Warnf("GCP provider was selected, but no GCP project environment variable is set (%v)", pkg.GCPProjectEnvVars) } + case client.ProviderAzure: + if env := pkg.AzureInEnv(); env == "" { + term.Warn("Azure provider was selected, but no Azure environment variables are set") + } } } diff --git a/src/pkg/session/session_test.go b/src/pkg/session/session_test.go index 0424916ef..f9ca984bc 100644 --- a/src/pkg/session/session_test.go +++ b/src/pkg/session/session_test.go @@ -65,6 +65,55 @@ func (m *mockStacksManager) TargetDirectory() string { return "" } +func TestPrintProviderMismatchWarnings(t *testing.T) { + tests := []struct { + name string + provider client.ProviderID + env map[string]string + }{ + 
{"defang with no env", client.ProviderDefang, nil}, + {"defang with AWS env", client.ProviderDefang, map[string]string{"AWS_PROFILE": "x"}}, + {"defang with DO env", client.ProviderDefang, map[string]string{"DIGITALOCEAN_TOKEN": "x"}}, + {"defang with Azure env", client.ProviderDefang, map[string]string{"AZURE_SUBSCRIPTION_ID": "x"}}, + {"azure with no env", client.ProviderAzure, nil}, + {"azure with env set", client.ProviderAzure, map[string]string{"AZURE_SUBSCRIPTION_ID": "sub"}}, + {"do with no env", client.ProviderDO, nil}, + {"do with env", client.ProviderDO, map[string]string{"DIGITALOCEAN_TOKEN": "t"}}, + {"gcp with no env", client.ProviderGCP, nil}, + } + + // Unset all provider env vars to give the test deterministic state. + unset := []string{ + "AWS_PROFILE", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_ROLE_ARN", + "DIGITALOCEAN_TOKEN", "DIGITALOCEAN_ACCESS_TOKEN", + "AZURE_SUBSCRIPTION_ID", "AZURE_TENANT_ID", "AZURE_CLIENT_ID", "AZURE_CLIENT_SECRET", + "GOOGLE_CLOUD_PROJECT", "GCP_PROJECT_ID", "GCLOUD_PROJECT", "CLOUDSDK_CORE_PROJECT", + } + saved := map[string]string{} + for _, k := range unset { + if v, ok := os.LookupEnv(k); ok { + saved[k] = v + _ = os.Unsetenv(k) + } + } + t.Cleanup(func() { + for k, v := range saved { + _ = os.Setenv(k, v) //nolint:usetesting // t.Setenv registers another cleanup; restore via os.Setenv + } + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + // Function writes warnings to term but has no return value; we just + // ensure it runs without panicking and exercises each branch. 
+ printProviderMismatchWarnings(context.Background(), tt.provider) + }) + } +} + func TestLoadSession(t *testing.T) { tests := []struct { name string diff --git a/src/pkg/stacks/selector_test.go b/src/pkg/stacks/selector_test.go index 188ab61a4..aa8bbf585 100644 --- a/src/pkg/stacks/selector_test.go +++ b/src/pkg/stacks/selector_test.go @@ -183,7 +183,7 @@ func TestStackSelector_SelectStack_CreateNewStack(t *testing.T) { mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) // Mock wizard parameter collection - provider selection - providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform"} + providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform", "Azure"} mockEC.On("RequestEnum", ctx, "Where do you want to deploy?", "provider", providerOptions).Return("AWS", nil) // Mock wizard parameter collection - region selection (default is us-west-2 for AWS) @@ -259,7 +259,7 @@ func TestStackSelector_SelectStack_NoExistingStacks(t *testing.T) { mockSM.On("List", ctx).Return([]ListItem{}, nil) // Mock wizard parameter collection - provider selection - providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform"} + providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform", "Azure"} mockEC.On("RequestEnum", ctx, "Where do you want to deploy?", "provider", providerOptions).Return("AWS", nil) // Mock wizard parameter collection - region selection @@ -418,7 +418,7 @@ func TestStackSelector_SelectStack_WizardError(t *testing.T) { mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) // Mock wizard parameter collection - provider selection fails - providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform"} + providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform", 
"Azure"} mockEC.On("RequestEnum", ctx, "Where do you want to deploy?", "provider", providerOptions).Return("", errors.New("user cancelled wizard")) selector := NewSelector(mockEC, mockSM) @@ -455,7 +455,7 @@ func TestStackSelector_SelectStack_CreateStackError(t *testing.T) { mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) // Mock wizard parameter collection - provider selection - providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform"} + providerOptions := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform", "Azure"} mockEC.On("RequestEnum", ctx, "Where do you want to deploy?", "provider", providerOptions).Return("AWS", nil) // Mock wizard parameter collection - region selection diff --git a/src/pkg/stacks/wizard.go b/src/pkg/stacks/wizard.go index e572712fd..b82c34c40 100644 --- a/src/pkg/stacks/wizard.go +++ b/src/pkg/stacks/wizard.go @@ -143,6 +143,16 @@ func (w *Wizard) CollectRemainingParameters(ctx context.Context, params *Paramet } params.Variables["GCP_PROJECT_ID"] = projectID } + case client.ProviderAzure: + if params.Variables["AZURE_SUBSCRIPTION_ID"] == "" { + subscriptionID, err := w.ec.RequestString(ctx, "What is your Azure Subscription ID?:", "azure_subscription_id", + elicitations.WithDefault(os.Getenv("AZURE_SUBSCRIPTION_ID")), + ) + if err != nil { + return nil, fmt.Errorf("failed to elicit Azure Subscription ID: %w", err) + } + params.Variables["AZURE_SUBSCRIPTION_ID"] = subscriptionID + } } return params, nil diff --git a/src/pkg/utils.go b/src/pkg/utils.go index fc3eb98a8..8c8926f75 100644 --- a/src/pkg/utils.go +++ b/src/pkg/utils.go @@ -226,3 +226,8 @@ func GcpInEnv() string { env, _ := GetFirstEnv(GCPProjectEnvVars...) 
return env } + +func AzureInEnv() string { + env, _ := GetFirstEnv("AZURE_SUBSCRIPTION_ID", "AZURE_TENANT_ID", "AZURE_CLIENT_ID", "AZURE_CLIENT_SECRET") + return env +} diff --git a/src/pkg/utils_test.go b/src/pkg/utils_test.go index d6a6b8ed3..494d8bd7b 100644 --- a/src/pkg/utils_test.go +++ b/src/pkg/utils_test.go @@ -1,6 +1,7 @@ package pkg import ( + "os" "reflect" "testing" "time" @@ -194,6 +195,74 @@ func TestShellQuote(t *testing.T) { } } +func unsetAll(t *testing.T, keys ...string) { + t.Helper() + saved := map[string]string{} + for _, k := range keys { + if v, ok := os.LookupEnv(k); ok { + saved[k] = v + if err := os.Unsetenv(k); err != nil { + t.Fatalf("unsetenv %s: %v", k, err) + } + } + } + t.Cleanup(func() { + for k, v := range saved { + _ = os.Setenv(k, v) //nolint:usetesting // t.Setenv registers another cleanup; restore via os.Setenv + } + }) +} + +func TestAzureInEnv(t *testing.T) { + unsetAll(t, "AZURE_SUBSCRIPTION_ID", "AZURE_TENANT_ID", "AZURE_CLIENT_ID", "AZURE_CLIENT_SECRET") + if got := AzureInEnv(); got != "" { + t.Errorf("AzureInEnv() with no vars set = %q, want empty", got) + } + t.Setenv("AZURE_CLIENT_ID", "abc") + if got := AzureInEnv(); got != "AZURE_CLIENT_ID" { + t.Errorf("AzureInEnv() = %q, want AZURE_CLIENT_ID", got) + } + t.Setenv("AZURE_SUBSCRIPTION_ID", "sub") // first in list, should win + if got := AzureInEnv(); got != "AZURE_SUBSCRIPTION_ID" { + t.Errorf("AzureInEnv() prefers AZURE_SUBSCRIPTION_ID, got %q", got) + } +} + +func TestAwsInEnv(t *testing.T) { + unsetAll(t, "AWS_PROFILE", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_ROLE_ARN") + if got := AwsInEnv(); got != "" { + t.Errorf("AwsInEnv() with no vars set = %q, want empty", got) + } + t.Setenv("AWS_ROLE_ARN", "arn") + if got := AwsInEnv(); got != "AWS_ROLE_ARN" { + t.Errorf("AwsInEnv() = %q, want AWS_ROLE_ARN", got) + } +} + +func TestDoInEnv(t *testing.T) { + unsetAll(t, "DIGITALOCEAN_ACCESS_TOKEN", "DIGITALOCEAN_TOKEN") + if got := DoInEnv(); got != "" { + 
t.Errorf("DoInEnv() with no vars = %q, want empty", got) + } + t.Setenv("DIGITALOCEAN_TOKEN", "x") + if got := DoInEnv(); got != "DIGITALOCEAN_TOKEN" { + t.Errorf("DoInEnv() = %q, want DIGITALOCEAN_TOKEN", got) + } +} + +func TestGcpInEnv(t *testing.T) { + unsetAll(t, GCPProjectEnvVars...) + if got := GcpInEnv(); got != "" { + t.Errorf("GcpInEnv() with no vars = %q, want empty", got) + } + if len(GCPProjectEnvVars) > 0 { + t.Setenv(GCPProjectEnvVars[0], "proj") + if got := GcpInEnv(); got != GCPProjectEnvVars[0] { + t.Errorf("GcpInEnv() = %q, want %q", got, GCPProjectEnvVars[0]) + } + } +} + func TestGetFirstEnv(t *testing.T) { tests := []struct { name string diff --git a/src/protos/io/defang/v1/fabric.proto b/src/protos/io/defang/v1/fabric.proto index 8f6c3307c..aae3ef9ab 100644 --- a/src/protos/io/defang/v1/fabric.proto +++ b/src/protos/io/defang/v1/fabric.proto @@ -704,7 +704,8 @@ message GetRequest { // was ServiceID } message Service { - option deprecated = true; // still used by pulumi-defang provider in state files + option deprecated = + true; // still used by pulumi-defang provider in state files string name = 1; reserved 2; // was: string image reserved 3; // was: Platform platform