diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 1cec48c3f50..513d4980da9 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### Notable Changes ### CLI +* Add `databricks dbconnect init` and `databricks dbconnect sync` to provision a local Python environment (Python version, `databricks-connect` pin, and dependency constraints) matched to the selected Databricks compute target. ### Bundles * `bundle run` now prints the modern job run URL (`/jobs//runs/`) so that non-admin users permitted to view the run are taken to the run instead of the workspace homepage. diff --git a/acceptance/dbconnect/cluster-unsupported/out.test.toml b/acceptance/dbconnect/cluster-unsupported/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/cluster-unsupported/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/cluster-unsupported/output.txt b/acceptance/dbconnect/cluster-unsupported/output.txt new file mode 100644 index 00000000000..bbfaf955a4a --- /dev/null +++ b/acceptance/dbconnect/cluster-unsupported/output.txt @@ -0,0 +1,8 @@ +=== Phase 1: preflight === + status=ok uv [UV_VERSION] +=== Phase 2: resolve === + status=ok kind=cluster envKey=dbr/15.4.x-scala2.12 +=== Phase 3: fetch === + status=failed fetch constraints for dbr/15.4.x-scala2.12: GET [DATABRICKS_URL]/dbr/15.4.x-scala2.12/pyproject.toml: unexpected status 404 Not Found +For more detail, re-run with --debug, or --output json to share a structured report. +Error: fetch constraints for dbr/15.4.x-scala2.12: GET [DATABRICKS_URL]/dbr/15.4.x-scala2.12/pyproject.toml: unexpected status 404 Not Found diff --git a/acceptance/dbconnect/cluster-unsupported/script b/acceptance/dbconnect/cluster-unsupported/script new file mode 100644 index 00000000000..c07f6635790 --- /dev/null +++ b/acceptance/dbconnect/cluster-unsupported/script @@ -0,0 +1 @@ +musterr $CLI dbconnect init --cluster test-cluster-id --check diff --git a/acceptance/dbconnect/cluster-unsupported/test.toml b/acceptance/dbconnect/cluster-unsupported/test.toml new file mode 100644 index 00000000000..b152f24ecbf --- /dev/null +++ b/acceptance/dbconnect/cluster-unsupported/test.toml @@ -0,0 +1,22 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] + +[Env] +DATABRICKS_DBCONNECT_CONSTRAINT_SOURCE = "$DATABRICKS_HOST" + +[[Server]] +Pattern = "GET /api/2.1/clusters/get" +Response.Body = ''' +{ + "cluster_id": "test-cluster-id", + "spark_version": "15.4.x-scala2.12" +} +''' + +[[Server]] +Pattern = "GET /dbr/15.4.x-scala2.12/pyproject.toml" +Response.StatusCode = 404 +Response.Body = '{"message": "Not found"}' + +[[Repls]] +Old = 'uv uv \S+(?: \([^)]+\))?' +New = 'uv [UV_VERSION]' diff --git a/acceptance/dbconnect/flag-conflict/out.test.toml b/acceptance/dbconnect/flag-conflict/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/flag-conflict/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/flag-conflict/output.txt b/acceptance/dbconnect/flag-conflict/output.txt new file mode 100644 index 00000000000..141152ae1cb --- /dev/null +++ b/acceptance/dbconnect/flag-conflict/output.txt @@ -0,0 +1 @@ +Error: if any flags in the group [cluster serverless job] are set none of the others can be; [cluster serverless] were all set diff --git a/acceptance/dbconnect/flag-conflict/script b/acceptance/dbconnect/flag-conflict/script new file mode 100644 index 00000000000..9c344219428 --- /dev/null +++ b/acceptance/dbconnect/flag-conflict/script @@ -0,0 +1 @@ +musterr $CLI dbconnect init --cluster abc --serverless v4 diff --git a/acceptance/dbconnect/flag-conflict/test.toml b/acceptance/dbconnect/flag-conflict/test.toml new file mode 100644 index 00000000000..c63fe3fe108 --- /dev/null +++ b/acceptance/dbconnect/flag-conflict/test.toml @@ -0,0 +1 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/help/out.test.toml b/acceptance/dbconnect/help/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/help/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/help/output.txt b/acceptance/dbconnect/help/output.txt new file mode 100644 index 00000000000..bbafed7a057 --- /dev/null +++ b/acceptance/dbconnect/help/output.txt @@ -0,0 +1,40 @@ +Set up a local Python environment matched to your Databricks compute target. + +Derives the Python version, databricks-connect version, and dependency +constraints from the selected compute (cluster, serverless, or job) so that +local resolution matches the Databricks runtime. + +Usage: + databricks dbconnect [command] + +Available Commands: + init Create a fresh pyproject.toml and provision a matched .venv + sync Merge managed dependencies into an existing pyproject.toml and re-provision + +Flags: + -h, --help help for dbconnect + +Global Flags: + --debug enable debug logging + -o, --output type output type: text or json (default text) + -p, --profile string ~/.databrickscfg profile + -t, --target string bundle target to use (if applicable) + +Use "databricks dbconnect [command] --help" for more information about a command. +Create a fresh pyproject.toml and provision a matched .venv + +Usage: + databricks dbconnect init [flags] + +Flags: + --check compute the plan without writing files or provisioning + --cluster string cluster ID to use as the compute target + -h, --help help for init + --job string job ID to use as the compute target + --serverless string serverless version to use as the compute target (e.g. v4) + +Global Flags: + --debug enable debug logging + -o, --output type output type: text or json (default text) + -p, --profile string ~/.databrickscfg profile + -t, --target string bundle target to use (if applicable) diff --git a/acceptance/dbconnect/help/script b/acceptance/dbconnect/help/script new file mode 100644 index 00000000000..962d7c3f64e --- /dev/null +++ b/acceptance/dbconnect/help/script @@ -0,0 +1,2 @@ +$CLI dbconnect --help +$CLI dbconnect init --help diff --git a/acceptance/dbconnect/help/test.toml b/acceptance/dbconnect/help/test.toml new file mode 100644 index 00000000000..c63fe3fe108 --- /dev/null +++ b/acceptance/dbconnect/help/test.toml @@ -0,0 +1 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/json-error/out.test.toml b/acceptance/dbconnect/json-error/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/json-error/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/json-error/output.txt b/acceptance/dbconnect/json-error/output.txt new file mode 100644 index 00000000000..39376a2cfce --- /dev/null +++ b/acceptance/dbconnect/json-error/output.txt @@ -0,0 +1,20 @@ +{ + "mode": "init", + "check": false, + "phases": [ + { + "name": "preflight", + "status": "ok", + "detail": "uv [UV_VERSION]" + }, + { + "name": "resolve", + "status": "failed", + "detail": "No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job" + } + ], + "error": { + "code": "no_target_selected", + "message": "No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job" + } +} diff --git a/acceptance/dbconnect/json-error/script b/acceptance/dbconnect/json-error/script new file mode 100644 index 00000000000..0a6837bb8f3 --- /dev/null +++ b/acceptance/dbconnect/json-error/script @@ -0,0 +1 @@ +musterr $CLI dbconnect init --output json diff --git a/acceptance/dbconnect/json-error/test.toml b/acceptance/dbconnect/json-error/test.toml new file mode 100644 index 00000000000..0d0481fb836 --- /dev/null +++ b/acceptance/dbconnect/json-error/test.toml @@ -0,0 +1,5 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] + +[[Repls]] +Old = 'uv uv \S+(?: \([^)]+\))?' +New = 'uv [UV_VERSION]' diff --git a/acceptance/dbconnect/no-target/out.test.toml b/acceptance/dbconnect/no-target/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/no-target/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/no-target/output.txt b/acceptance/dbconnect/no-target/output.txt new file mode 100644 index 00000000000..c2eb8434f86 --- /dev/null +++ b/acceptance/dbconnect/no-target/output.txt @@ -0,0 +1,6 @@ +=== Phase 1: preflight === + status=ok uv [UV_VERSION] +=== Phase 2: resolve === + status=failed No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job +For more detail, re-run with --debug, or --output json to share a structured report. +Error: No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job diff --git a/acceptance/dbconnect/no-target/script b/acceptance/dbconnect/no-target/script new file mode 100644 index 00000000000..d8f8e147a53 --- /dev/null +++ b/acceptance/dbconnect/no-target/script @@ -0,0 +1 @@ +musterr $CLI dbconnect init diff --git a/acceptance/dbconnect/no-target/test.toml b/acceptance/dbconnect/no-target/test.toml new file mode 100644 index 00000000000..0d0481fb836 --- /dev/null +++ b/acceptance/dbconnect/no-target/test.toml @@ -0,0 +1,5 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] + +[[Repls]] +Old = 'uv uv \S+(?: \([^)]+\))?' +New = 'uv [UV_VERSION]' diff --git a/acceptance/dbconnect/serverless-check/out.test.toml b/acceptance/dbconnect/serverless-check/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/serverless-check/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/serverless-check/output.txt b/acceptance/dbconnect/serverless-check/output.txt new file mode 100644 index 00000000000..98c84da8858 --- /dev/null +++ b/acceptance/dbconnect/serverless-check/output.txt @@ -0,0 +1,17 @@ + +>>> [CLI] dbconnect init --serverless v4 --check +=== Phase 1: preflight === + status=ok uv [UV_VERSION] +=== Phase 2: resolve === + status=ok kind=serverless envKey=serverless/serverless-v4 +=== Phase 3: fetch === + status=ok source=[DATABRICKS_URL]/serverless/serverless-v4/pyproject.toml fromCache=false +=== Phase 4: parse-python-version === + status=ok 3.12 +=== Phase 5: plan === + status=ok changed=requires-python,databricks-connect,tool.uv.constraint-dependencies +Plan: [TEST_TMP_DIR]/pyproject.toml + changed region: requires-python + changed region: databricks-connect + changed region: tool.uv.constraint-dependencies +Check complete. No files were modified. diff --git a/acceptance/dbconnect/serverless-check/script b/acceptance/dbconnect/serverless-check/script new file mode 100644 index 00000000000..f360138e4f3 --- /dev/null +++ b/acceptance/dbconnect/serverless-check/script @@ -0,0 +1 @@ +trace $CLI dbconnect init --serverless v4 --check diff --git a/acceptance/dbconnect/serverless-check/test.toml b/acceptance/dbconnect/serverless-check/test.toml new file mode 100644 index 00000000000..47881839976 --- /dev/null +++ b/acceptance/dbconnect/serverless-check/test.toml @@ -0,0 +1,21 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] + +[Env] +DATABRICKS_DBCONNECT_CONSTRAINT_SOURCE = "$DATABRICKS_HOST" + +[[Server]] +Pattern = "GET /serverless/serverless-v4/pyproject.toml" +Response.Body = ''' +[project] +requires-python = ">=3.12" + +[dependency-groups] +dev = ["databricks-connect~=17.2.0"] + +[tool.uv] +constraint-dependencies = ["pyarrow<19", "pandas<3"] +''' + +[[Repls]] +Old = 'uv uv \S+(?: \([^)]+\))?' +New = 'uv [UV_VERSION]' diff --git a/acceptance/dbconnect/serverless-json/out.test.toml b/acceptance/dbconnect/serverless-json/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/dbconnect/serverless-json/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/dbconnect/serverless-json/output.txt b/acceptance/dbconnect/serverless-json/output.txt new file mode 100644 index 00000000000..966ab4e556a --- /dev/null +++ b/acceptance/dbconnect/serverless-json/output.txt @@ -0,0 +1,57 @@ + +>>> [CLI] dbconnect init --serverless v4 --check --output json +{ + "mode": "init", + "check": true, + "target": { + "kind": "serverless", + "cluster_id": "", + "spark_version": "", + "env_key": "serverless/serverless-v4", + "python_version": "3.12" + }, + "constraints": { + "source_url": "[DATABRICKS_URL]/serverless/serverless-v4/pyproject.toml", + "from_cache": false, + "requires_python": "\u003e=3.12", + "databricks_connect": "databricks-connect~=17.2.0", + "constraint_count": 2 + }, + "plan": { + "pyproject_path": "[TEST_TMP_DIR]/pyproject.toml", + "backup_path": "[TEST_TMP_DIR]/pyproject.toml.bak", + "diff": "--- pyproject.toml\n+++ pyproject.toml\n@@ -1 +1,16 @@\n+[project]\n+name = \"001\"\n+requires-python = \"\u003e=3.12\"\n+\n+[dependency-groups]\n+dev = [\n+ \"databricks-connect~=17.2.0\",\n+]\n+\n+# managed by databricks dbconnect — do not edit\n+[tool.uv]\n+constraint-dependencies = [\n+ \"pyarrow\u003c19\",\n+ \"pandas\u003c3\",\n+]\n+# end managed by databricks dbconnect\n", + "changed_regions": [ + "requires-python", + "databricks-connect", + "tool.uv.constraint-dependencies" + ] + }, + "phases": [ + { + "name": "preflight", + "status": "ok", + "detail": "uv [UV_VERSION]" + }, + { + "name": "resolve", + "status": "ok", + "detail": "kind=serverless envKey=serverless/serverless-v4" + }, + { + "name": "fetch", + "status": "ok", + "detail": "source=[DATABRICKS_URL]/serverless/serverless-v4/pyproject.toml fromCache=false" + }, + { + "name": "parse-python-version", + "status": "ok", + "detail": "3.12" + }, + { + "name": "plan", + "status": "ok", + "detail": "changed=requires-python,databricks-connect,tool.uv.constraint-dependencies" + } + ] +} diff --git a/acceptance/dbconnect/serverless-json/script b/acceptance/dbconnect/serverless-json/script new file mode 100644 index 00000000000..68f96164406 --- /dev/null +++ b/acceptance/dbconnect/serverless-json/script @@ -0,0 +1 @@ +trace $CLI dbconnect init --serverless v4 --check --output json diff --git a/acceptance/dbconnect/serverless-json/test.toml b/acceptance/dbconnect/serverless-json/test.toml new file mode 100644 index 00000000000..47881839976 --- /dev/null +++ b/acceptance/dbconnect/serverless-json/test.toml @@ -0,0 +1,21 @@ +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] + +[Env] +DATABRICKS_DBCONNECT_CONSTRAINT_SOURCE = "$DATABRICKS_HOST" + +[[Server]] +Pattern = "GET /serverless/serverless-v4/pyproject.toml" +Response.Body = ''' +[project] +requires-python = ">=3.12" + +[dependency-groups] +dev = ["databricks-connect~=17.2.0"] + +[tool.uv] +constraint-dependencies = ["pyarrow<19", "pandas<3"] +''' + +[[Repls]] +Old = 'uv uv \S+(?: \([^)]+\))?' +New = 'uv [UV_VERSION]' diff --git a/cmd/cmd.go b/cmd/cmd.go index 718d3a8fda3..47aef433604 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -15,6 +15,7 @@ import ( "github.com/databricks/cli/cmd/cache" "github.com/databricks/cli/cmd/completion" "github.com/databricks/cli/cmd/configure" + "github.com/databricks/cli/cmd/dbconnect" "github.com/databricks/cli/cmd/experimental" "github.com/databricks/cli/cmd/fs" "github.com/databricks/cli/cmd/labs" @@ -120,6 +121,7 @@ func New(ctx context.Context) *cobra.Command { cli.AddCommand(cache.New()) cli.AddCommand(experimental.New()) cli.AddCommand(psql.New()) + cli.AddCommand(dbconnect.New()) cli.AddCommand(configure.New()) cli.AddCommand(fs.New()) cli.AddCommand(labs.New(ctx)) diff --git a/cmd/dbconnect/compute.go b/cmd/dbconnect/compute.go new file mode 100644 index 00000000000..97e702e3190 --- /dev/null +++ b/cmd/dbconnect/compute.go @@ -0,0 +1,60 @@ +package dbconnect + +import ( + "context" + "fmt" + "strconv" + + databricks "github.com/databricks/databricks-sdk-go" +) + +// sdkCompute adapts the Databricks SDK to the dbconnect.ComputeClient interface. +type sdkCompute struct { + w *databricks.WorkspaceClient +} + +// GetClusterSparkVersion returns the Spark version string for a running cluster. +func (c sdkCompute) GetClusterSparkVersion(ctx context.Context, clusterID string) (string, error) { + d, err := c.w.Clusters.GetByClusterId(ctx, clusterID) + if err != nil { + return "", fmt.Errorf("get cluster %s: %w", clusterID, err) + } + return d.SparkVersion, nil +} + +// GetJobSparkVersion inspects the job's configuration to determine compute type. +// +// A job is considered serverless when it has non-empty Environments (JobEnvironment +// entries), which signals the Databricks serverless runtime. A job with classic compute +// uses JobClusters; we read SparkVersion from the first job cluster's NewCluster spec. +// If neither indicator is present the job's compute cannot be determined. +func (c sdkCompute) GetJobSparkVersion(ctx context.Context, jobID string) (sparkVersion string, isServerless bool, version string, err error) { + id, err := strconv.ParseInt(jobID, 10, 64) + if err != nil { + return "", false, "", fmt.Errorf("invalid job ID %q: must be an integer: %w", jobID, err) + } + + job, err := c.w.Jobs.GetByJobId(ctx, id) + if err != nil { + return "", false, "", fmt.Errorf("get job %d: %w", id, err) + } + + if job.Settings == nil { + return "", false, "", fmt.Errorf("job %d has no settings", id) + } + + // Serverless jobs have Environments populated; classic compute uses JobClusters. + if len(job.Settings.Environments) > 0 { + return "", true, "", nil + } + + if len(job.Settings.JobClusters) > 0 { + sv := job.Settings.JobClusters[0].NewCluster.SparkVersion + if sv == "" { + return "", false, "", fmt.Errorf("could not determine compute for job %d: first job cluster has no spark_version", id) + } + return sv, false, sv, nil + } + + return "", false, "", fmt.Errorf("could not determine compute for job %d: no environments or job clusters found", id) +} diff --git a/cmd/dbconnect/dbconnect.go b/cmd/dbconnect/dbconnect.go new file mode 100644 index 00000000000..5ccaeb1ac44 --- /dev/null +++ b/cmd/dbconnect/dbconnect.go @@ -0,0 +1,20 @@ +package dbconnect + +import "github.com/spf13/cobra" + +// New returns the `dbconnect` command group. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "dbconnect", + Short: "Set up a local Python environment matched to your Databricks compute", + GroupID: "development", + Long: `Set up a local Python environment matched to your Databricks compute target. + +Derives the Python version, databricks-connect version, and dependency +constraints from the selected compute (cluster, serverless, or job) so that +local resolution matches the Databricks runtime.`, + } + cmd.AddCommand(newInitCommand()) + cmd.AddCommand(newSyncCommand()) + return cmd +} diff --git a/cmd/dbconnect/init.go b/cmd/dbconnect/init.go new file mode 100644 index 00000000000..c862dee3a4d --- /dev/null +++ b/cmd/dbconnect/init.go @@ -0,0 +1,137 @@ +package dbconnect + +import ( + "context" + "os" + "path/filepath" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + libsdbconnect "github.com/databricks/cli/libs/dbconnect" + "github.com/databricks/cli/libs/env" + "github.com/spf13/cobra" +) + +const ( + // defaultConstraintBaseURL is the default URL for the constraint source. + defaultConstraintBaseURL = "https://raw.githubusercontent.com/pietern/databricks-environments/main" + + // envConstraintSource is the environment variable for overriding the constraint source URL. + envConstraintSource = "DATABRICKS_DBCONNECT_CONSTRAINT_SOURCE" +) + +func newInitCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "init", + Short: "Create a fresh pyproject.toml and provision a matched .venv", + } + cmd.PreRunE = root.MustWorkspaceClient + addTargetFlags(cmd) + cmd.RunE = func(cmd *cobra.Command, args []string) error { + return runPipeline(cmd, libsdbconnect.ModeInit) + } + return cmd +} + +// addTargetFlags adds the shared target flags to a command. +func addTargetFlags(cmd *cobra.Command) { + cmd.Flags().String("cluster", "", "cluster ID to use as the compute target") + cmd.Flags().String("serverless", "", "serverless version to use as the compute target (e.g. v4)") + cmd.Flags().String("job", "", "job ID to use as the compute target") + cmd.Flags().Bool("check", false, "compute the plan without writing files or provisioning") + cmd.Flags().String("constraint-source", "", "URL for the constraint source (overrides "+envConstraintSource+")") + // Hide constraint-source from casual --help output; it is a power-user escape hatch. + _ = cmd.Flags().MarkHidden("constraint-source") + cmd.MarkFlagsMutuallyExclusive("cluster", "serverless", "job") +} + +// runPipeline builds and runs the dbconnect Pipeline for the given mode. +func runPipeline(cmd *cobra.Command, mode libsdbconnect.Mode) error { + ctx := cmd.Context() + + cluster, _ := cmd.Flags().GetString("cluster") + serverless, _ := cmd.Flags().GetString("serverless") + job, _ := cmd.Flags().GetString("job") + check, _ := cmd.Flags().GetBool("check") + constraintSource, _ := cmd.Flags().GetString("constraint-source") + + targetFlags := libsdbconnect.TargetFlags{ + Cluster: cluster, + Serverless: serverless, + Job: job, + } + // ValidateTargetFlags is kept despite MarkFlagsMutuallyExclusive above: + // it also validates the library path (no Cobra equivalent) and guards + // non-Cobra call paths such as tests that invoke runPipeline directly. + if err := libsdbconnect.ValidateTargetFlags(targetFlags); err != nil { + return err + } + + // Resolve constraint base URL: flag → env var → default constant. + constraintBaseURL := resolveConstraintBaseURL(ctx, constraintSource) + + projectDir, err := os.Getwd() + if err != nil { + return err + } + + cacheDir, err := os.UserCacheDir() + if err != nil { + return err + } + cacheDir = filepath.Join(cacheDir, "databricks", "dbconnect") + + bt := bundleTarget(cmd) + + w := cmdctx.WorkspaceClient(ctx) + p := &libsdbconnect.Pipeline{ + Mode: mode, + Check: check, + ProjectDir: projectDir, + ConstraintBaseURL: constraintBaseURL, + CacheDir: cacheDir, + Flags: targetFlags, + Compute: sdkCompute{w: w}, + Bundle: bt, + PM: libsdbconnect.NewUvManager(), + } + + res, pipelineErr := p.Run(ctx) + return renderResult(ctx, cmd, res, pipelineErr) +} + +// resolveConstraintBaseURL returns the constraint base URL using ordered precedence: +// flag → env var → default constant. +func resolveConstraintBaseURL(ctx context.Context, flagValue string) string { + if flagValue != "" { + return flagValue + } + if v, ok := env.Lookup(ctx, envConstraintSource); ok { + return v + } + return defaultConstraintBaseURL +} + +// bundleTarget reads the active bundle (if any) and maps its compute configuration +// to a libsdbconnect.BundleTarget. +// +// Only the top-level bundle.cluster_id field is consulted here; serverless is not +// recorded in the bundle config, so Selected=true is set only when a cluster ID is +// present. If the bundle is absent or has no cluster_id, Selected=false is returned +// so the pipeline falls through to requiring an explicit flag. +// +// TODO: extend once bundle config exposes a serverless field at the bundle level. +func bundleTarget(cmd *cobra.Command) libsdbconnect.BundleTarget { + b := root.TryConfigureBundle(cmd) + if b == nil { + return libsdbconnect.BundleTarget{Selected: false} + } + clusterID := b.Config.Bundle.ClusterId + if clusterID == "" { + return libsdbconnect.BundleTarget{Selected: false} + } + return libsdbconnect.BundleTarget{ + ClusterID: clusterID, + Selected: true, + } +} diff --git a/cmd/dbconnect/output.go b/cmd/dbconnect/output.go new file mode 100644 index 00000000000..57c08f026c2 --- /dev/null +++ b/cmd/dbconnect/output.go @@ -0,0 +1,83 @@ +package dbconnect + +import ( + "context" + "errors" + "fmt" + "path/filepath" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + libsdbconnect "github.com/databricks/cli/libs/dbconnect" + "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" +) + +// renderResult renders the pipeline result to the command's output. +// In JSON mode it renders the full structured result (even on error). +// In text mode it prints phase headers and a summary, then returns the error. +func renderResult(ctx context.Context, cmd *cobra.Command, res *libsdbconnect.Result, pipelineErr error) error { + // Guard against a nil result (e.g. pipeline failed before constructing one). + // Always emit a structured object in JSON mode so callers can rely on the schema. + if res == nil { + res = &libsdbconnect.Result{} + if pipelineErr != nil { + if pe, ok := errors.AsType[*libsdbconnect.PipelineError](pipelineErr); ok { + res.Error = pe + } else { + res.Error = libsdbconnect.NewError(libsdbconnect.ErrProvisionFailed, pipelineErr, "%s", pipelineErr.Error()) + } + } + } + + if root.OutputType(cmd) == flags.OutputJSON { + if err := cmdio.Render(ctx, res); err != nil { + return err + } + // The JSON object is the only thing written to stdout. On failure we still + // need a non-zero exit, but returning pipelineErr would make the root print + // "Error: ..." to stderr. ErrAlreadyPrinted exits non-zero without that. + if pipelineErr != nil { + return root.ErrAlreadyPrinted + } + return nil + } + + // Text mode: print phase headers. + for i, phase := range res.Phases { + cmdio.LogString(ctx, fmt.Sprintf("=== Phase %d: %s ===", i+1, phase.Name)) + if phase.Detail != "" { + cmdio.LogString(ctx, fmt.Sprintf(" status=%s %s", phase.Status, phase.Detail)) + } else { + cmdio.LogString(ctx, " status="+phase.Status) + } + } + + if pipelineErr != nil { + cmdio.LogString(ctx, "For more detail, re-run with --debug, or --output json to share a structured report.") + return pipelineErr + } + + // Print a final success / check summary. + if res.Check { + if res.Plan != nil { + cmdio.LogString(ctx, "Plan: "+filepath.ToSlash(res.Plan.PyprojectPath)) + if len(res.Plan.ChangedRegions) > 0 { + for _, region := range res.Plan.ChangedRegions { + cmdio.LogString(ctx, " changed region: "+region) + } + } + } + cmdio.LogString(ctx, "Check complete. No files were modified.") + return nil + } + + if res.Result != nil { + cmdio.LogString(ctx, fmt.Sprintf("Success: python=%s databricks-connect=%s venv=%s", + res.Result.PythonVersion, + res.Result.DatabricksConnectInstalled, + filepath.ToSlash(res.Result.VenvPath), + )) + } + return nil +} diff --git a/cmd/dbconnect/sync.go b/cmd/dbconnect/sync.go new file mode 100644 index 00000000000..8e23fda5f4a --- /dev/null +++ b/cmd/dbconnect/sync.go @@ -0,0 +1,20 @@ +package dbconnect + +import ( + "github.com/databricks/cli/cmd/root" + libsdbconnect "github.com/databricks/cli/libs/dbconnect" + "github.com/spf13/cobra" +) + +func newSyncCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "sync", + Short: "Merge managed dependencies into an existing pyproject.toml and re-provision", + } + cmd.PreRunE = root.MustWorkspaceClient + addTargetFlags(cmd) + cmd.RunE = func(cmd *cobra.Command, args []string) error { + return runPipeline(cmd, libsdbconnect.ModeSync) + } + return cmd +} diff --git a/libs/dbconnect/constraints.go b/libs/dbconnect/constraints.go new file mode 100644 index 00000000000..517db457dc9 --- /dev/null +++ b/libs/dbconnect/constraints.go @@ -0,0 +1,143 @@ +package dbconnect + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/BurntSushi/toml" + "github.com/databricks/cli/libs/log" +) + +// Constraints holds the parsed contents of a per-environment pyproject.toml. +type Constraints struct { + // EnvKey is the environment key used to look up the constraints. + EnvKey string + // SourceURL is the URL from which the constraints were fetched. + SourceURL string + // FromCache is true when the data came from the on-disk cache rather than a live fetch. + FromCache bool + // RequiresPython is the PEP 440 python version specifier from [project].requires-python. + RequiresPython string + // DatabricksConnect is the full dependency string for databricks-connect from [dependency-groups].dev. + DatabricksConnect string + // ConstraintDeps is the list of entries from [tool.uv].constraint-dependencies. + ConstraintDeps []string +} + +// sanitizeEnvKey replaces path separators with double-underscores to produce a flat filename. +func sanitizeEnvKey(envKey string) string { + return strings.ReplaceAll(envKey, "/", "__") +} + +// FetchConstraints fetches the pyproject.toml for envKey from baseURL, caches it in cacheDir, +// and falls back to the cached copy on network or HTTP errors. +// +// Constraint files are hosted at: +// https://github.com/pietern/databricks-environments +func FetchConstraints(ctx context.Context, baseURL, envKey, cacheDir string) (*Constraints, error) { + url := baseURL + "/" + envKey + "/pyproject.toml" + cachePath := filepath.Join(cacheDir, sanitizeEnvKey(envKey)+".toml") + + data, fetchErr := fetchURL(ctx, url) + if fetchErr == nil { + // Write the cache copy; non-fatal so a read-only cacheDir doesn't break the command. + if err := os.WriteFile(cachePath, data, 0o600); err != nil { + log.Debugf(ctx, "failed to write constraint cache %s: %v", filepath.ToSlash(cachePath), err) + } + rp, dbc, deps, err := parseConstraints(data) + if err != nil { + return nil, fmt.Errorf("parse constraints for %s: %w", envKey, err) + } + return &Constraints{ + EnvKey: envKey, + SourceURL: url, + FromCache: false, + RequiresPython: rp, + DatabricksConnect: dbc, + ConstraintDeps: deps, + }, nil + } + + // Network or HTTP failure: attempt to serve from cache. + cached, readErr := os.ReadFile(cachePath) + if readErr != nil { + return nil, NewError(ErrConstraintFetchFailed, fetchErr, "fetch constraints for %s", envKey) + } + + log.Warnf(ctx, "constraint fetch failed, using cached copy: %v", fetchErr) + rp, dbc, deps, err := parseConstraints(cached) + if err != nil { + return nil, fmt.Errorf("parse cached constraints for %s: %w", envKey, err) + } + return &Constraints{ + EnvKey: envKey, + SourceURL: url, + FromCache: true, + RequiresPython: rp, + DatabricksConnect: dbc, + ConstraintDeps: deps, + }, nil +} + +// fetchURL performs an HTTP GET and returns the body bytes, or an error on non-2xx or transport failure. +func fetchURL(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("build request for %s: %w", url, err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("GET %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("GET %s: unexpected status %s", url, resp.Status) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read body from %s: %w", url, err) + } + return data, nil +} + +// pyprojectTOML mirrors the pyproject.toml fields we care about. +type pyprojectTOML struct { + Project struct { + RequiresPython string `toml:"requires-python"` + } `toml:"project"` + DependencyGroups struct { + Dev []string `toml:"dev"` + } `toml:"dependency-groups"` + Tool struct { + UV struct { + ConstraintDependencies []string `toml:"constraint-dependencies"` + } `toml:"uv"` + } `toml:"tool"` +} + +// parseConstraints parses a pyproject.toml byte slice and extracts requires-python, +// the databricks-connect entry from dependency-groups.dev, and constraint-dependencies. +func parseConstraints(data []byte) (requiresPython, dbconnect string, deps []string, err error) { + var p pyprojectTOML + if err = toml.Unmarshal(data, &p); err != nil { + return "", "", nil, fmt.Errorf("unmarshal pyproject.toml: %w", err) + } + + requiresPython = p.Project.RequiresPython + + for _, entry := range p.DependencyGroups.Dev { + // Despace before matching so whitespace variants like "databricks-connect ~=17" also match. + if strings.HasPrefix(strings.ReplaceAll(entry, " ", ""), "databricks-connect") { + dbconnect = entry + break + } + } + + deps = p.Tool.UV.ConstraintDependencies + return requiresPython, dbconnect, deps, nil +} diff --git a/libs/dbconnect/constraints_test.go b/libs/dbconnect/constraints_test.go new file mode 100644 index 00000000000..9a5275e0e6d --- /dev/null +++ b/libs/dbconnect/constraints_test.go @@ -0,0 +1,64 @@ +package dbconnect + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const sampleToml = `[project] +requires-python = "==3.12.*" + +[dependency-groups] +dev = [ + "databricks-connect~=17.2.0", + "pytest~=8.0", +] + +[tool.uv] +constraint-dependencies = [ + "pydantic~=2.10.6", + "anyio~=4.6.2", +] +` + +func TestParseConstraints(t *testing.T) { + rp, dbc, deps, err := parseConstraints([]byte(sampleToml)) + require.NoError(t, err) + assert.Equal(t, "==3.12.*", rp) + assert.Equal(t, "databricks-connect~=17.2.0", dbc) + assert.Equal(t, []string{"pydantic~=2.10.6", "anyio~=4.6.2"}, deps) +} + +func TestFetchConstraintsHTTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/serverless/serverless-v4/pyproject.toml", r.URL.Path) + _, _ = w.Write([]byte(sampleToml)) + })) + defer srv.Close() + + c, err := FetchConstraints(t.Context(), srv.URL, "serverless/serverless-v4", t.TempDir()) + require.NoError(t, err) + assert.False(t, c.FromCache) + assert.Equal(t, "databricks-connect~=17.2.0", c.DatabricksConnect) + assert.Len(t, c.ConstraintDeps, 2) +} + +func TestFetchConstraintsFallsBackToCache(t *testing.T) { + cacheDir := t.TempDir() + // First, a successful fetch populates the cache. + good := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(sampleToml)) + })) + _, err := FetchConstraints(t.Context(), good.URL, "serverless/serverless-v4", cacheDir) + require.NoError(t, err) + good.Close() + + // Now the server is down; fetch must serve the cache. + c, err := FetchConstraints(t.Context(), good.URL, "serverless/serverless-v4", cacheDir) + require.NoError(t, err) + assert.True(t, c.FromCache) +} diff --git a/libs/dbconnect/envkey.go b/libs/dbconnect/envkey.go new file mode 100644 index 00000000000..3f53bb2b94e --- /dev/null +++ b/libs/dbconnect/envkey.go @@ -0,0 +1,29 @@ +package dbconnect + +import ( + "fmt" + "regexp" + "strings" +) + +var pythonVersionRe = regexp.MustCompile(`(\d+)\.(\d+)`) + +// EnvKeyForServerless returns the environment key for a serverless version. +func EnvKeyForServerless(version string) string { + normalized := strings.TrimPrefix(strings.ToLower(version), "v") + return "serverless/serverless-v" + normalized +} + +// EnvKeyForSparkVersion returns the environment key for a Spark version. +func EnvKeyForSparkVersion(sparkVersion string) string { + return "dbr/" + sparkVersion +} + +// PythonMinorFromRequires parses a PEP 440 requires-python string and extracts MAJOR.MINOR. +func PythonMinorFromRequires(requiresPython string) (string, error) { + match := pythonVersionRe.FindStringSubmatch(requiresPython) + if match == nil { + return "", fmt.Errorf("cannot parse python version from %q", requiresPython) + } + return fmt.Sprintf("%s.%s", match[1], match[2]), nil +} diff --git a/libs/dbconnect/envkey_test.go b/libs/dbconnect/envkey_test.go new file mode 100644 index 00000000000..4c8e368d303 --- /dev/null +++ b/libs/dbconnect/envkey_test.go @@ -0,0 +1,34 @@ +package dbconnect + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnvKeyForServerless(t *testing.T) { + for _, in := range []string{"4", "v4", "V4"} { + assert.Equal(t, "serverless/serverless-v4", EnvKeyForServerless(in)) + } +} + +func TestEnvKeyForSparkVersion(t *testing.T) { + assert.Equal(t, "dbr/15.4.x-scala2.12", EnvKeyForSparkVersion("15.4.x-scala2.12")) +} + +func TestPythonMinorFromRequires(t *testing.T) { + cases := map[string]string{ + "==3.12.*": "3.12", + ">=3.12": "3.12", + "==3.12.3": "3.12", + "~=3.11": "3.11", + } + for in, want := range cases { + got, err := PythonMinorFromRequires(in) + require.NoError(t, err) + assert.Equal(t, want, got) + } + _, err := PythonMinorFromRequires("garbage") + assert.Error(t, err) +} diff --git a/libs/dbconnect/merge.go b/libs/dbconnect/merge.go new file mode 100644 index 00000000000..e6c475ebc92 --- /dev/null +++ b/libs/dbconnect/merge.go @@ -0,0 +1,354 @@ +package dbconnect + +import ( + "fmt" + "regexp" + "strings" +) + +// managedMarkerStart and managedMarkerEnd bracket the region of pyproject.toml that +// "databricks dbconnect" owns. Everything between them is rewritten on each merge; +// everything outside is preserved byte-for-byte. +const ( + managedMarkerStart = "# managed by databricks dbconnect — do not edit" + managedMarkerEnd = "# end managed by databricks dbconnect" +) + +// Region names reported back to the caller via MergeManaged's regions return value. +const ( + regionRequiresPython = "requires-python" + regionDatabricksConnect = "databricks-connect" + regionToolUv = "tool.uv.constraint-dependencies" +) + +var ( + // tableHeaderRe matches a TOML table header line such as "[project]" or "[tool.uv]". + tableHeaderRe = regexp.MustCompile(`^\s*\[[^\]]+\]\s*$`) + // requiresPythonRe captures the leading whitespace of a requires-python assignment so it + // can be preserved when the value is replaced. + requiresPythonRe = regexp.MustCompile(`^(\s*)requires-python\s*=`) +) + +// MergeManaged applies the three managed transforms to target, preserving every other +// byte (comments, ordering, whitespace). It returns the merged bytes and the list of +// regions that actually changed. The operation is idempotent: feeding its own output +// back in produces identical bytes. +func MergeManaged(target []byte, c Constraints) (merged []byte, regions []string, err error) { + s := string(target) + + // Detect and normalize line endings. We process on "\n" and restore "\r\n" on exit. + crlf := strings.Contains(s, "\r\n") + if crlf { + s = strings.ReplaceAll(s, "\r\n", "\n") + } + + lines := strings.Split(s, "\n") + + lines, rpChanged := mergeRequiresPython(lines, c.RequiresPython) + if rpChanged { + regions = append(regions, regionRequiresPython) + } + + lines, dbcChanged := mergeDatabricksConnect(lines, c.DatabricksConnect) + if dbcChanged { + regions = append(regions, regionDatabricksConnect) + } + + lines, uvChanged := mergeToolUv(lines, c.ConstraintDeps) + if uvChanged { + regions = append(regions, regionToolUv) + } + + out := strings.Join(lines, "\n") + if crlf { + out = strings.ReplaceAll(out, "\n", "\r\n") + } + return []byte(out), regions, nil +} + +// tableBounds returns the line index of the header matching name (e.g. "[project]") and +// the index of the first line after the table body (the next table header or EOF). If the +// table is absent, found is false. +func tableBounds(lines []string, name string) (header, end int, found bool) { + header = -1 + for i, line := range lines { + if strings.TrimSpace(line) == name { + header = i + break + } + } + if header == -1 { + return -1, -1, false + } + end = len(lines) + for i := header + 1; i < len(lines); i++ { + if tableHeaderRe.MatchString(lines[i]) { + end = i + break + } + } + return header, end, true +} + +// mergeRequiresPython replaces the value of requires-python within [project], preserving +// the line's leading whitespace. If the key is absent, it is inserted directly under the +// [project] header. Returns whether the line slice changed. +func mergeRequiresPython(lines []string, value string) ([]string, bool) { + header, end, found := tableBounds(lines, "[project]") + if !found { + return lines, false + } + + want := func(indent string) string { + return fmt.Sprintf(`%srequires-python = "%s"`, indent, value) + } + + for i := header + 1; i < end; i++ { + m := requiresPythonRe.FindStringSubmatch(lines[i]) + if m == nil { + continue + } + replacement := want(m[1]) + if lines[i] == replacement { + return lines, false + } + lines[i] = replacement + return lines, true + } + + // Key absent: insert directly under the [project] header. + inserted := make([]string, 0, len(lines)+1) + inserted = append(inserted, lines[:header+1]...) + inserted = append(inserted, want("")) + inserted = append(inserted, lines[header+1:]...) + return inserted, true +} + +// dbconnectLineRe captures, for a line holding a databricks-connect dependency element: +// (1) the leading whitespace, and (3) any trailing comma (with optional trailing space), +// so that indentation and comma style are preserved when the quoted token is replaced. +var dbconnectLineRe = regexp.MustCompile(`^(\s*)"databricks-connect[^"]*"(\s*,?\s*)$`) + +// mergeDatabricksConnect replaces the databricks-connect element inside +// [dependency-groups].dev. It handles both the multi-line array form (one element per +// line) and the single-line array form (dev = ["databricks-connect~=..."]). +func mergeDatabricksConnect(lines []string, value string) ([]string, bool) { + header, end, found := tableBounds(lines, "[dependency-groups]") + if !found { + return lines, false + } + + for i := header + 1; i < end; i++ { + // Multi-line element form: a standalone line holding only the quoted token. + if m := dbconnectLineRe.FindStringSubmatch(lines[i]); m != nil { + replacement := fmt.Sprintf(`%s"%s"%s`, m[1], value, m[2]) + if lines[i] == replacement { + return lines, false + } + lines[i] = replacement + return lines, true + } + // Single-line array form: replace the quoted databricks-connect token in place. + if strings.Contains(lines[i], `"databricks-connect`) { + replaced := dbconnectTokenRe.ReplaceAllString(lines[i], `"`+value+`"`) + if replaced == lines[i] { + return lines, false + } + lines[i] = replaced + return lines, true + } + } + return lines, false +} + +// dbconnectTokenRe matches a quoted databricks-connect element anywhere in a line, used +// for the single-line array form. +var dbconnectTokenRe = regexp.MustCompile(`"databricks-connect[^"]*"`) + +// mergeToolUv rewrites the managed [tool.uv] constraint-dependencies block. If a +// marker-bracketed block already exists, its contents are replaced in place. Otherwise any +// plain [tool.uv] table is removed and a fresh marker-bracketed block is appended at EOF. +func mergeToolUv(lines, deps []string) ([]string, bool) { + block := renderToolUvBlock(deps) + + start, stop, found := markerBounds(lines) + if found { + existing := lines[start : stop+1] + if equalLines(existing, block) { + return lines, false + } + out := make([]string, 0, len(lines)-(stop-start+1)+len(block)) + out = append(out, lines[:start]...) + out = append(out, block...) + out = append(out, lines[stop+1:]...) + return out, true + } + + // No managed block: reconcile any plain [tool.uv] table, then append a fresh managed + // block at EOF. If the table is effectively ours (its only meaningful key is + // constraint-dependencies, from a pre-marker run), drop it whole. Otherwise the table + // holds user-authored keys, so we preserve it and strip only our constraint-dependencies. + if header, end, ok := tableBounds(lines, "[tool.uv]"); ok { + if toolUvHasOnlyConstraintDeps(lines, header, end) { + out := make([]string, 0, len(lines)) + out = append(out, lines[:header]...) + out = append(out, lines[end:]...) + lines = out + } else { + lines = removeConstraintDeps(lines, header, end) + } + } + + lines = appendManagedBlock(lines, block) + return lines, true +} + +// constraintDepsRe matches the start of a constraint-dependencies assignment within a +// [tool.uv] table, capturing its leading whitespace. +var constraintDepsRe = regexp.MustCompile(`^\s*constraint-dependencies\s*=`) + +// toolUvHasOnlyConstraintDeps reports whether the [tool.uv] table body spanning +// (header, end) contains no meaningful key other than constraint-dependencies. Blank lines +// and comment-only lines are ignored when deciding "only". +func toolUvHasOnlyConstraintDeps(lines []string, header, end int) bool { + for i := header + 1; i < end; i++ { + trimmed := strings.TrimSpace(lines[i]) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + if !constraintDepsRe.MatchString(lines[i]) { + return false + } + // Multi-line array form: skip the continuation lines through the closing "]" + // so the whole managed key counts as ignorable (mirrors removeConstraintDeps). + // The single-line form already holds the "]" and does not advance i. + if !strings.Contains(lines[i], "]") { + for i++; i < end; i++ { + if strings.TrimSpace(lines[i]) == "]" { + break + } + } + } + } + return true +} + +// removeConstraintDeps strips a constraint-dependencies key from the [tool.uv] table body +// spanning (header, end), leaving the table header and all other user keys in place. It +// handles both the single-line array form and the multi-line array form (the value spans +// several lines until a line whose trimmed content is "]"). +func removeConstraintDeps(lines []string, header, end int) []string { + for i := header + 1; i < end; i++ { + if !constraintDepsRe.MatchString(lines[i]) { + continue + } + // Multi-line array form: extend through the closing "]" line. The single-line form + // already contains the closing bracket, so this loop does not advance. + end2 := i + 1 + if !strings.Contains(lines[i], "]") { + for j := i + 1; j < end; j++ { + end2 = j + 1 + if strings.TrimSpace(lines[j]) == "]" { + break + } + } + } + out := make([]string, 0, len(lines)-(end2-i)) + out = append(out, lines[:i]...) + out = append(out, lines[end2:]...) + return out + } + return lines +} + +// markerBounds returns the indices of the managed marker start and end lines, if present. +func markerBounds(lines []string) (start, stop int, found bool) { + start, stop = -1, -1 + for i, line := range lines { + if strings.TrimSpace(line) == managedMarkerStart { + start = i + break + } + } + if start == -1 { + return -1, -1, false + } + for i := start + 1; i < len(lines); i++ { + if strings.TrimSpace(lines[i]) == managedMarkerEnd { + stop = i + break + } + } + if stop == -1 { + return -1, -1, false + } + return start, stop, true +} + +// renderToolUvBlock builds the marker-bracketed [tool.uv] block lines (no surrounding +// blank lines). +func renderToolUvBlock(deps []string) []string { + block := []string{ + managedMarkerStart, + "[tool.uv]", + "constraint-dependencies = [", + } + for _, d := range deps { + block = append(block, fmt.Sprintf(" %q,", d)) + } + block = append(block, "]", managedMarkerEnd) + return block +} + +// appendManagedBlock appends block to lines, ensuring exactly one blank line separates it +// from prior content and the file ends with a single trailing newline. +func appendManagedBlock(lines, block []string) []string { + // strings.Split on a trailing "\n" leaves a final empty element; drop trailing empty + // lines so we control the spacing precisely. + for len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + + out := make([]string, 0, len(lines)+len(block)+2) + out = append(out, lines...) + if len(out) > 0 { + out = append(out, "") // exactly one blank line before the managed block + } + out = append(out, block...) + out = append(out, "") // trailing newline after final join + return out +} + +// equalLines reports whether two line slices are identical. +func equalLines(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// RenderFreshPyproject produces a complete managed pyproject.toml for a project that has +// none, with [project], [dependency-groups].dev (carrying the databricks-connect pin), and +// the marker-bracketed [tool.uv] constraint block. +func RenderFreshPyproject(projectName string, c Constraints) []byte { + var b strings.Builder + b.WriteString("[project]\n") + fmt.Fprintf(&b, "name = %q\n", projectName) + fmt.Fprintf(&b, "requires-python = %q\n", c.RequiresPython) + b.WriteString("\n") + b.WriteString("[dependency-groups]\n") + b.WriteString("dev = [\n") + fmt.Fprintf(&b, " %q,\n", c.DatabricksConnect) + b.WriteString("]\n") + b.WriteString("\n") + for _, line := range renderToolUvBlock(c.ConstraintDeps) { + b.WriteString(line) + b.WriteString("\n") + } + return []byte(b.String()) +} diff --git a/libs/dbconnect/merge_test.go b/libs/dbconnect/merge_test.go new file mode 100644 index 00000000000..caf8ce3ad43 --- /dev/null +++ b/libs/dbconnect/merge_test.go @@ -0,0 +1,248 @@ +package dbconnect + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testConstraints() Constraints { + return Constraints{ + RequiresPython: "==3.12.*", + DatabricksConnect: "databricks-connect~=17.2.0", + ConstraintDeps: []string{"pydantic~=2.10.6", "anyio~=4.6.2"}, + } +} + +func TestMergeReplacesRequiresPythonPreservingComments(t *testing.T) { + in := []byte(`[project] +name = "demo" +# keep this comment +requires-python = ">=3.10" + +[dependency-groups] +dev = [ + "databricks-connect~=16.0.0", + "pytest~=8.0", +] +`) + out, regions, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + assert.Contains(t, string(out), `requires-python = "==3.12.*"`) + assert.Contains(t, string(out), "# keep this comment") + assert.Contains(t, string(out), `"databricks-connect~=17.2.0",`) + assert.Contains(t, string(out), `"pytest~=8.0",`) + assert.Contains(t, regions, "requires-python") + assert.Contains(t, regions, "databricks-connect") + assert.Contains(t, regions, "tool.uv.constraint-dependencies") + assert.Contains(t, string(out), "pydantic~=2.10.6") +} + +func TestMergeIsIdempotent(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = [ + "databricks-connect~=16.0.0", +] +`) + once, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + twice, _, err := MergeManaged(once, testConstraints()) + require.NoError(t, err) + assert.Equal(t, string(once), string(twice)) +} + +func TestMergeInsertsRequiresPythonWhenMissing(t *testing.T) { + in := []byte(`[project] +name = "demo" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + assert.Contains(t, string(out), `requires-python = "==3.12.*"`) +} + +func TestMergeReplacesExistingManagedToolUvBlock(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] + +` + managedMarkerStart + ` +[tool.uv] +constraint-dependencies = [ + "stale~=1.0.0", +] +` + managedMarkerEnd + ` +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + assert.NotContains(t, string(out), "stale~=1.0.0") + assert.Contains(t, string(out), "pydantic~=2.10.6") + // Only one managed block remains. + assert.Equal(t, 1, countOccurrences(string(out), managedMarkerStart)) +} + +func TestMergePreservesCRLF(t *testing.T) { + in := []byte("[project]\r\nrequires-python = \">=3.10\"\r\n\r\n[dependency-groups]\r\ndev = [\"databricks-connect~=16.0.0\"]\r\n") + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + assert.Contains(t, string(out), "\r\n") + assert.Contains(t, string(out), `requires-python = "==3.12.*"`) + // Merging the CRLF output again must be byte-identical (idempotent under \r\n). + twice, _, err := MergeManaged(out, testConstraints()) + require.NoError(t, err) + assert.Equal(t, string(out), string(twice)) +} + +func TestMergePreservesUserToolUvKeys(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] + +[tool.uv] +package = true +dev-dependencies = ["ruff"] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + s := string(out) + assert.Contains(t, s, "[tool.uv]") + assert.Contains(t, s, "package = true") + assert.Contains(t, s, `dev-dependencies = ["ruff"]`) + assert.Contains(t, s, managedMarkerStart) + assert.Contains(t, s, "pydantic~=2.10.6") + // The user's keys must live outside the managed marker block. + start := strings.Index(s, managedMarkerStart) + require.GreaterOrEqual(t, start, 0) + assert.NotContains(t, s[start:], "package = true") + assert.NotContains(t, s[start:], `dev-dependencies = ["ruff"]`) +} + +func TestMergeStripsStaleConstraintDepsFromUserToolUv(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] + +[tool.uv] +package = true +constraint-dependencies = ["old~=1.0"] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + s := string(out) + assert.Contains(t, s, "package = true") + // The stale constraint must be gone from the user table; the managed block has the new deps. + assert.NotContains(t, s, "old~=1.0") + assert.Contains(t, s, "pydantic~=2.10.6") + // Merge-twice is byte-identical. + twice, _, err := MergeManaged(out, testConstraints()) + require.NoError(t, err) + assert.Equal(t, string(out), string(twice)) +} + +func TestMergeRemovesOwnedOnlyToolUv(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] + +[tool.uv] +constraint-dependencies = ["old~=1.0"] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + s := string(out) + assert.NotContains(t, s, "old~=1.0") + assert.Contains(t, s, "pydantic~=2.10.6") + // The plain table was removed and replaced by exactly one managed block. + assert.Equal(t, 1, countOccurrences(s, "[tool.uv]")) + assert.Equal(t, 1, countOccurrences(s, managedMarkerStart)) +} + +func TestMergeRemovesOwnedOnlyMultiLineToolUv(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] + +[tool.uv] +constraint-dependencies = [ + "old~=1.0", +] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + s := string(out) + assert.NotContains(t, s, "old~=1.0") + assert.Contains(t, s, "pydantic~=2.10.6") + // The multi-line owned-only table was removed whole, leaving exactly one + // [tool.uv] (inside the managed block) and no stray empty header. + assert.Equal(t, 1, countOccurrences(s, "[tool.uv]")) + assert.Equal(t, 1, countOccurrences(s, managedMarkerStart)) + // Merge-twice is byte-identical. + twice, _, err := MergeManaged(out, testConstraints()) + require.NoError(t, err) + assert.Equal(t, string(out), string(twice)) +} + +func TestMergeReplacesSingleLineDevArray(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0", "pytest~=8.0"] +`) + out, regions, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + // Sibling element and single-line array layout are preserved. + assert.Contains(t, string(out), `dev = ["databricks-connect~=17.2.0", "pytest~=8.0"]`) + assert.Contains(t, regions, "databricks-connect") +} + +func TestMergePreservesMultiLineTrailingComma(t *testing.T) { + in := []byte(`[project] +requires-python = ">=3.10" + +[dependency-groups] +dev = [ + "databricks-connect~=16.0.0", +] +`) + out, _, err := MergeManaged(in, testConstraints()) + require.NoError(t, err) + // The trailing comma on the managed element is preserved. + assert.Contains(t, string(out), ` "databricks-connect~=17.2.0",`) +} + +func TestRenderFreshPyproject(t *testing.T) { + out := RenderFreshPyproject("demo", testConstraints()) + s := string(out) + assert.Contains(t, s, `name = "demo"`) + assert.Contains(t, s, `requires-python = "==3.12.*"`) + assert.Contains(t, s, `"databricks-connect~=17.2.0",`) + assert.Contains(t, s, managedMarkerStart) + assert.Contains(t, s, managedMarkerEnd) + assert.Contains(t, s, "pydantic~=2.10.6") + // A fresh render is itself a no-op under MergeManaged (already fully managed). + merged, _, err := MergeManaged(out, testConstraints()) + require.NoError(t, err) + assert.Equal(t, s, string(merged)) +} + +func countOccurrences(s, substr string) int { + return strings.Count(s, substr) +} diff --git a/libs/dbconnect/pipeline.go b/libs/dbconnect/pipeline.go new file mode 100644 index 00000000000..3df333cefea --- /dev/null +++ b/libs/dbconnect/pipeline.go @@ -0,0 +1,467 @@ +package dbconnect + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" +) + +// Pipeline orchestrates the dbconnect init/sync phases against a project directory. +type Pipeline struct { + Mode Mode + Check bool + ProjectDir string + ConstraintBaseURL string + CacheDir string + Flags TargetFlags + Bundle BundleTarget + Compute ComputeClient + PM PackageManager +} + +// Run executes all pipeline phases in order and returns a fully populated Result. +// On a phase error, Result.Error is set and the same error is also returned. +func (p *Pipeline) Run(ctx context.Context) (*Result, error) { + log.Debugf(ctx, "dbconnect: mode=%s project=%s cacheDir=%s constraintBaseURL=%s flags=%+v", + p.Mode, + filepath.ToSlash(p.ProjectDir), + filepath.ToSlash(p.CacheDir), + p.ConstraintBaseURL, + p.Flags, + ) + + res := &Result{ + Mode: p.Mode.String(), + Check: p.Check, + } + + // Phase 0: ensure the package manager is available. + phase := PhaseResult{Name: "preflight"} + version, err := p.PM.EnsureAvailable(ctx) + if err != nil { + phase.Status = "failed" + phase.Detail = err.Error() + res.Phases = append(res.Phases, phase) + pe := NewError(ErrUvUnavailable, err, "%s unavailable", p.PM.Name()) + res.Error = pe + return res, pe + } + phase.Status = "ok" + phase.Detail = p.PM.Name() + " " + version + res.Phases = append(res.Phases, phase) + + // Phase 1: resolve the compute target. + target, err := p.resolve(ctx, res) + if err != nil { + return res, err + } + + // Phase 2: fetch constraints. + c, err := p.fetch(ctx, res, target) + if err != nil { + return res, err + } + + // Phase 2b: fill in the python version on the target info from the constraints. + pyMinor, err := PythonMinorFromRequires(c.RequiresPython) + if err != nil { + pe := NewError(ErrValidationFailed, err, "failed to parse python version from constraints") + res.Phases = append(res.Phases, PhaseResult{Name: "parse-python-version", Status: "failed", Detail: pe.Error()}) + res.Error = pe + return res, pe + } + res.Phases = append(res.Phases, PhaseResult{Name: "parse-python-version", Status: "ok", Detail: pyMinor}) + target.PythonVersion = pyMinor + + // Phase 3: compute the merge plan (in-memory, no disk writes yet). + plan, mergedBytes, err := p.mergePlan(ctx, res, c) + if err != nil { + return res, err + } + res.Plan = plan + + // Check mode stops here — phases 4+ mutate disk. + if p.Check { + return res, nil + } + + // Phase 4: write the merged content to disk (mode-specific backup/restore). + if err := p.applyMerge(ctx, res, mergedBytes); err != nil { + return res, err + } + + // Phase 5: ensure the required Python version is installed. + if err := p.ensurePython(ctx, res, pyMinor); err != nil { + return res, err + } + + // Phase 6: provision the virtual environment. + if err := p.provision(ctx, res); err != nil { + return res, err + } + + // Phase 7: post-provision (pip seed). + if err := p.postProvision(ctx, res); err != nil { + return res, err + } + + // Phase 8: validate the environment. + if err := p.validate(ctx, res, pyMinor, c.DatabricksConnect); err != nil { + return res, err + } + + return res, nil +} + +// resolve runs ResolveTarget and appends a phase result. +func (p *Pipeline) resolve(ctx context.Context, res *Result) (*TargetInfo, error) { + phase := PhaseResult{Name: "resolve"} + target, err := ResolveTarget(ctx, p.Flags, p.Compute, p.Bundle) + if err != nil { + phase.Status = "failed" + phase.Detail = err.Error() + res.Phases = append(res.Phases, phase) + pe, ok := errors.AsType[*PipelineError](err) + if !ok { + pe = NewError(ErrNoTargetSelected, err, "target resolution failed") + } + res.Error = pe + return nil, pe + } + phase.Status = "ok" + phase.Detail = fmt.Sprintf("kind=%s envKey=%s", target.Kind, target.EnvKey) + res.Phases = append(res.Phases, phase) + res.Target = target + return target, nil +} + +// fetch fetches constraints for the resolved target and appends a phase result. +func (p *Pipeline) fetch(ctx context.Context, res *Result, target *TargetInfo) (*Constraints, error) { + phase := PhaseResult{Name: "fetch"} + c, err := FetchConstraints(ctx, p.ConstraintBaseURL, target.EnvKey, p.CacheDir) + if err != nil { + phase.Status = "failed" + phase.Detail = err.Error() + res.Phases = append(res.Phases, phase) + pe, ok := errors.AsType[*PipelineError](err) + if !ok { + pe = NewError(ErrConstraintFetchFailed, err, "fetch constraints failed") + } + res.Error = pe + return nil, pe + } + phase.Status = "ok" + phase.Detail = fmt.Sprintf("source=%s fromCache=%v", c.SourceURL, c.FromCache) + res.Phases = append(res.Phases, phase) + res.Constraints = &ConstraintInfo{ + SourceURL: c.SourceURL, + FromCache: c.FromCache, + RequiresPython: c.RequiresPython, + DatabricksConnect: c.DatabricksConnect, + ConstraintCount: len(c.ConstraintDeps), + } + return c, nil +} + +// pyprojectPath returns the path to pyproject.toml in the project directory. +func (p *Pipeline) pyprojectPath() string { + return filepath.Join(p.ProjectDir, "pyproject.toml") +} + +// backupPath returns the path to the pyproject.toml backup file. +func (p *Pipeline) backupPath() string { + return filepath.Join(p.ProjectDir, "pyproject.toml.bak") +} + +// mergePlan computes the merged pyproject.toml bytes (without writing to disk) +// and builds the Plan with a unified diff. +func (p *Pipeline) mergePlan(_ context.Context, res *Result, c *Constraints) (*Plan, []byte, error) { + phase := PhaseResult{Name: "plan"} + pyproject := p.pyprojectPath() + backup := p.backupPath() + + // Determine base bytes for the merge. For sync with a backup, the backup is + // the canonical base so the merge starts from the original unmanaged state. + var baseBytes []byte + if p.Mode == ModeSync { + if data, err := os.ReadFile(backup); err == nil { + baseBytes = data + } + } + + // Fall back to the current pyproject.toml if no base was found above. + if baseBytes == nil { + if data, err := os.ReadFile(pyproject); err == nil { + baseBytes = data + } + } + + var mergedBytes []byte + var changedRegions []string + + if baseBytes == nil { + // No existing pyproject.toml — render a fresh one. + // Extract the project name from the directory name as a reasonable default. + projectName := filepath.Base(p.ProjectDir) + mergedBytes = RenderFreshPyproject(projectName, *c) + changedRegions = []string{regionRequiresPython, regionDatabricksConnect, regionToolUv} + } else { + var err error + mergedBytes, changedRegions, err = MergeManaged(baseBytes, *c) + if err != nil { + pe := NewError(ErrMergeFailed, err, "merge managed regions failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return nil, nil, pe + } + } + + // Build a unified diff for the plan. + oldStr := "" + newStr := string(mergedBytes) + oldName := "pyproject.toml" + newName := "pyproject.toml" + if baseBytes != nil { + oldStr = string(baseBytes) + oldName = "pyproject.toml" + newName = "pyproject.toml.new" + } + edits := myers.ComputeEdits(span.URIFromPath(oldName), oldStr, newStr) + diff := fmt.Sprint(gotextdiff.ToUnified(oldName, newName, oldStr, edits)) + + plan := &Plan{ + PyprojectPath: pyproject, + BackupPath: backup, + Diff: diff, + ChangedRegions: changedRegions, + } + + phase.Status = "ok" + phase.Detail = "changed=" + strings.Join(changedRegions, ",") + res.Phases = append(res.Phases, phase) + return plan, mergedBytes, nil +} + +// applyMerge writes the merged bytes to disk, performing the mode-specific +// backup or restore first. +func (p *Pipeline) applyMerge(_ context.Context, res *Result, mergedBytes []byte) error { + phase := PhaseResult{Name: "apply"} + pyproject := p.pyprojectPath() + backup := p.backupPath() + + switch p.Mode { + case ModeInit: + // Back up only if a pyproject.toml already exists. + if _, err := os.Stat(pyproject); err == nil { + if err := copyFile(pyproject, backup); err != nil { + pe := NewError(ErrMergeFailed, err, "backup pyproject.toml failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + } + case ModeSync: + if _, err := os.Stat(backup); err != nil { + // No backup yet — create one from the current pyproject.toml. + if _, statErr := os.Stat(pyproject); statErr == nil { + if err := copyFile(pyproject, backup); err != nil { + pe := NewError(ErrMergeFailed, err, "backup pyproject.toml failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + } + } + // When a backup already exists, mergePlan already used it as the base — no + // additional restore step is needed here. + } + + if err := os.WriteFile(pyproject, mergedBytes, 0o644); err != nil { + pe := NewError(ErrMergeFailed, err, "write pyproject.toml failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + + phase.Status = "ok" + res.Phases = append(res.Phases, phase) + return nil +} + +// ensurePython ensures the required Python version is installed. +func (p *Pipeline) ensurePython(ctx context.Context, res *Result, pyMinor string) error { + phase := PhaseResult{Name: "ensure-python"} + if err := p.PM.EnsurePython(ctx, pyMinor); err != nil { + pe := NewError(ErrProvisionFailed, err, "ensure python %s failed", pyMinor) + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + phase.Status = "ok" + phase.Detail = pyMinor + res.Phases = append(res.Phases, phase) + return nil +} + +// provision installs project dependencies into the virtual environment. +func (p *Pipeline) provision(ctx context.Context, res *Result) error { + phase := PhaseResult{Name: "provision"} + if err := p.PM.Provision(ctx, p.ProjectDir); err != nil { + pe := NewError(ErrProvisionFailed, err, "provision failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + phase.Status = "ok" + res.Phases = append(res.Phases, phase) + return nil +} + +// postProvision seeds pip into the virtual environment. +func (p *Pipeline) postProvision(ctx context.Context, res *Result) error { + phase := PhaseResult{Name: "post-provision"} + if err := p.PM.PostProvision(ctx, p.ProjectDir); err != nil { + pe := NewError(ErrProvisionFailed, err, "post-provision failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + phase.Status = "ok" + res.Phases = append(res.Phases, phase) + return nil +} + +// validate reads the Python and databricks-connect versions from the venv and +// populates Result.Result. +func (p *Pipeline) validate(ctx context.Context, res *Result, expectedPyMinor, dbcPin string) error { + phase := PhaseResult{Name: "validate"} + pyVer, dbcVer, err := p.PM.Validate(ctx, p.ProjectDir) + if err != nil { + pe := NewError(ErrValidationFailed, err, "validation failed") + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + + // Assert the installed Python minor matches the target. + if pyVer != expectedPyMinor { + pe := NewError(ErrValidationFailed, nil, + "python version mismatch: want %s, got %s", expectedPyMinor, pyVer) + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + + // Assert the installed databricks-connect major matches the pin's major. + // dbcPin is e.g. "databricks-connect~=17.2.0"; dbcVer is e.g. "17.2.0". + if dbcPin != "" { + pinMajor := dbcMajorFromPin(dbcPin) + if pinMajor == "" { + pe := NewError(ErrValidationFailed, nil, + "cannot determine databricks-connect major version from pin %q", dbcPin) + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + installedMajor := majorVersion(dbcVer) + if installedMajor == "" { + pe := NewError(ErrValidationFailed, nil, + "cannot determine installed databricks-connect major version from %q", dbcVer) + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + if pinMajor != installedMajor { + pe := NewError(ErrValidationFailed, nil, + "databricks-connect major version mismatch: want %s.x, got %s", pinMajor, dbcVer) + phase.Status = "failed" + phase.Detail = pe.Error() + res.Phases = append(res.Phases, phase) + res.Error = pe + return pe + } + } + + phase.Status = "ok" + phase.Detail = fmt.Sprintf("python=%s databricks-connect=%s", pyVer, dbcVer) + res.Phases = append(res.Phases, phase) + + venvPath := filepath.Join(p.ProjectDir, ".venv") + res.Result = &ResultDetail{ + Status: "success", + VenvPath: venvPath, + PythonVersion: pyVer, + DatabricksConnectInstalled: dbcVer, + } + return nil +} + +// dbcMajorFromPin extracts the major version number from a databricks-connect +// pin string such as "databricks-connect~=17.2.0". Returns "" if unparseable. +func dbcMajorFromPin(pin string) string { + // Strip the "databricks-connect" prefix and any operator (~=, ==, >=, etc.). + // The first digit sequence is the major version. + for i, c := range pin { + if c >= '0' && c <= '9' { + return majorVersion(pin[i:]) + } + } + return "" +} + +// majorVersion returns the major portion of a version string (digits before the +// first dot), e.g. "17" from "17.2.0". A bare integer like "17" returns "17". +// Returns "" for an empty string. +func majorVersion(v string) string { + if v == "" { + return "" + } + before, _, ok := strings.Cut(v, ".") + if !ok { + // No dot — the whole string is the major component. + return v + } + return before +} + +// copyFile copies src to dst, creating or overwriting dst. +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return fmt.Errorf("read %s: %w", src, err) + } + if err := os.WriteFile(dst, data, 0o644); err != nil { + return fmt.Errorf("write %s: %w", dst, err) + } + return nil +} diff --git a/libs/dbconnect/pipeline_test.go b/libs/dbconnect/pipeline_test.go new file mode 100644 index 00000000000..e27a22b94f4 --- /dev/null +++ b/libs/dbconnect/pipeline_test.go @@ -0,0 +1,273 @@ +package dbconnect + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakePM struct{ py, dbc string } + +func (fakePM) Name() string { return "fake" } +func (fakePM) EnsureAvailable(context.Context) (string, error) { return "fake 1.0", nil } +func (fakePM) EnsurePython(context.Context, string) error { return nil } +func (fakePM) Provision(context.Context, string) error { return nil } +func (fakePM) PostProvision(context.Context, string) error { return nil } +func (f fakePM) Validate(context.Context, string) (string, string, error) { + return f.py, f.dbc, nil +} + +func writeProject(t *testing.T) string { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "pyproject.toml"), []byte(`[project] +name = "demo" +requires-python = ">=3.10" + +[dependency-groups] +dev = ["databricks-connect~=16.0.0"] +`), 0o644)) + return dir +} + +func newTestServer(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(sampleToml)) + })) +} + +func TestPipelineCheckMutatesNothing(t *testing.T) { + dir := writeProject(t) + before, _ := os.ReadFile(filepath.Join(dir, "pyproject.toml")) + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, Check: true, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + assert.True(t, res.Check) + require.NotNil(t, res.Plan) + assert.Contains(t, res.Plan.Diff, "==3.12.*") + after, _ := os.ReadFile(filepath.Join(dir, "pyproject.toml")) + assert.Equal(t, string(before), string(after)) // unchanged +} + +func TestPipelineSyncProvisionsAndValidates(t *testing.T) { + dir := writeProject(t) + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + require.NotNil(t, res.Result) + assert.Equal(t, "success", res.Result.Status) + assert.Equal(t, "3.12", res.Result.PythonVersion) + merged, _ := os.ReadFile(filepath.Join(dir, "pyproject.toml")) + assert.Contains(t, string(merged), `"databricks-connect~=17.2.0"`) + assert.FileExists(t, filepath.Join(dir, "pyproject.toml.bak")) +} + +func TestPipelineInitCreatesNewPyproject(t *testing.T) { + dir := t.TempDir() + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeInit, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + require.NotNil(t, res.Result) + assert.Equal(t, "success", res.Result.Status) + data, readErr := os.ReadFile(filepath.Join(dir, "pyproject.toml")) + require.NoError(t, readErr) + assert.Contains(t, string(data), `"databricks-connect~=17.2.0",`) + // No backup created when pyproject.toml did not previously exist. + assert.NoFileExists(t, filepath.Join(dir, "pyproject.toml.bak")) +} + +func TestPipelineInitBacksUpExistingPyproject(t *testing.T) { + dir := writeProject(t) + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeInit, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + require.NotNil(t, res.Result) + assert.FileExists(t, filepath.Join(dir, "pyproject.toml.bak")) +} + +func TestPipelineNoTarget(t *testing.T) { + dir := writeProject(t) + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{}, + Compute: stubCompute{}, PM: fakePM{}, + } + res, err := p.Run(t.Context()) + require.Error(t, err) + require.NotNil(t, res) + require.NotNil(t, res.Error) + assert.Equal(t, ErrNoTargetSelected, res.Error.Code) +} + +func TestPipelineSyncRestoresBackupBeforeMerge(t *testing.T) { + dir := t.TempDir() + // Write an original pyproject.toml and a pre-existing .bak. + original := []byte(`[project] +name = "demo" +requires-python = ">=3.9" + +[dependency-groups] +dev = ["databricks-connect~=15.0.0"] +`) + require.NoError(t, os.WriteFile(filepath.Join(dir, "pyproject.toml.bak"), original, 0o644)) + // Current pyproject.toml has been mutated by a previous run. + mutated := []byte(`[project] +name = "demo" +requires-python = "==3.12.*" + +[dependency-groups] +dev = ["databricks-connect~=17.2.0"] +`) + require.NoError(t, os.WriteFile(filepath.Join(dir, "pyproject.toml"), mutated, 0o644)) + + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + require.NotNil(t, res) + // The bak content (requires-python = ">=3.9") was the base; merged result should + // contain the newly pinned version. + data, _ := os.ReadFile(filepath.Join(dir, "pyproject.toml")) + assert.Contains(t, string(data), `"databricks-connect~=17.2.0"`) + assert.Contains(t, string(data), `requires-python = "==3.12.*"`) +} + +func TestPipelineResultPopulatesConstraintInfo(t *testing.T) { + dir := writeProject(t) + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, Check: true, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.NoError(t, err) + require.NotNil(t, res.Constraints) + assert.Equal(t, "==3.12.*", res.Constraints.RequiresPython) + assert.Equal(t, "databricks-connect~=17.2.0", res.Constraints.DatabricksConnect) + assert.Equal(t, 2, res.Constraints.ConstraintCount) +} + +// newServerWithDBC returns a test server that serves a constraints TOML with the +// given databricks-connect pin value in the dev dependency group. +func newServerWithDBC(t *testing.T, dbcPin string) *httptest.Server { + t.Helper() + body := `[project] +requires-python = "==3.12.*" + +[dependency-groups] +dev = ["` + dbcPin + `"] + +[tool.uv] +constraint-dependencies = [] +` + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(body)) + })) +} + +func TestPipelineValidateRejectsUnparseablePin(t *testing.T) { + dir := writeProject(t) + // Serve a TOML whose dev group has a malformed databricks-connect entry + // (no version digits after the package name). + srv := newServerWithDBC(t, "databricks-connect") + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: "17.2.0"}, + } + res, err := p.Run(t.Context()) + require.Error(t, err) + require.NotNil(t, res.Error) + assert.Equal(t, ErrValidationFailed, res.Error.Code) +} + +func TestPipelineValidateRejectsUnparseableInstalledVersion(t *testing.T) { + dir := writeProject(t) + // sampleToml has databricks-connect~=17.2.0 as the pin; fakePM returns a + // bare integer "17" as the installed version — majorVersion("17") must now + // return "17" (not ""), so this actually passes. Use an empty installed + // version string to simulate an installed version that can't be parsed. + srv := newTestServer(t) + defer srv.Close() + + p := &Pipeline{ + Mode: ModeSync, ProjectDir: dir, + ConstraintBaseURL: srv.URL, CacheDir: t.TempDir(), + Flags: TargetFlags{Serverless: "v4"}, + Compute: stubCompute{}, PM: fakePM{py: "3.12", dbc: ""}, + } + res, err := p.Run(t.Context()) + require.Error(t, err) + require.NotNil(t, res.Error) + assert.Equal(t, ErrValidationFailed, res.Error.Code) +} + +func TestMajorVersion(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"17.2.0", "17"}, + {"17", "17"}, + {"", ""}, + {"3.12", "3"}, + } + for _, tc := range cases { + assert.Equal(t, tc.want, majorVersion(tc.input), "input=%q", tc.input) + } +} diff --git a/libs/dbconnect/pkgmanager.go b/libs/dbconnect/pkgmanager.go new file mode 100644 index 00000000000..84d7c04b47d --- /dev/null +++ b/libs/dbconnect/pkgmanager.go @@ -0,0 +1,31 @@ +package dbconnect + +import "context" + +// PackageManager manages the Python environment for a dbconnect project. +type PackageManager interface { + // Name returns the name of the package manager (e.g. "uv"). + Name() string + + // EnsureAvailable ensures the package manager binary is present, installing + // it if necessary. It returns the version string on success. + EnsureAvailable(ctx context.Context) (version string, err error) + + // EnsurePython ensures the requested Python minor version (e.g. "3.12") is + // available via the package manager. + EnsurePython(ctx context.Context, minor string) error + + // Provision installs the project dependencies inside projectDir. + Provision(ctx context.Context, projectDir string) error + + // PostProvision seeds pip into the virtual environment inside projectDir. + // This step is required because VS Code's ms-python.vscode-python-envs + // extension falls back to `python -m pip list` when its `uv --version` + // probe fails on the GUI PATH; uv venvs contain no pip; and `uv sync` + // strips pip, so seeding must run after every sync. + PostProvision(ctx context.Context, projectDir string) error + + // Validate reads the Python minor version and databricks-connect version + // from the virtual environment inside projectDir. + Validate(ctx context.Context, projectDir string) (pythonVersion, dbconnectVersion string, err error) +} diff --git a/libs/dbconnect/result.go b/libs/dbconnect/result.go new file mode 100644 index 00000000000..b0917b062be --- /dev/null +++ b/libs/dbconnect/result.go @@ -0,0 +1,112 @@ +package dbconnect + +import "fmt" + +// Mode represents the dbconnect operation mode. +type Mode int + +const ( + ModeInit Mode = iota + ModeSync +) + +// String returns the string representation of the Mode. +func (m Mode) String() string { + if m == ModeInit { + return "init" + } + return "sync" +} + +// ErrorCode represents a dbconnect error code. +type ErrorCode string + +const ( + ErrNoTargetSelected ErrorCode = "no_target_selected" + ErrConstraintFetchFailed ErrorCode = "constraint_fetch_failed" + ErrMergeFailed ErrorCode = "merge_failed" + ErrProvisionFailed ErrorCode = "provision_failed" + ErrValidationFailed ErrorCode = "validation_failed" + ErrUvUnavailable ErrorCode = "uv_unavailable" +) + +// PipelineError represents an error during the dbconnect pipeline. +type PipelineError struct { + Code ErrorCode `json:"code"` + Msg string `json:"message"` + Err error `json:"-"` +} + +func (e *PipelineError) Error() string { + if e.Err != nil { + return e.Msg + ": " + e.Err.Error() + } + return e.Msg +} + +func (e *PipelineError) Unwrap() error { + return e.Err +} + +// NewError creates a new PipelineError. The message is formatted using fmt.Sprintf(format, args...), +// and err may be nil. +func NewError(code ErrorCode, err error, format string, args ...any) *PipelineError { + return &PipelineError{ + Code: code, + Msg: fmt.Sprintf(format, args...), + Err: err, + } +} + +// TargetInfo contains information about the target environment. +type TargetInfo struct { + Kind string `json:"kind"` + ClusterID string `json:"cluster_id"` + SparkVersion string `json:"spark_version"` + EnvKey string `json:"env_key"` + PythonVersion string `json:"python_version"` +} + +// ConstraintInfo contains constraint information. +type ConstraintInfo struct { + SourceURL string `json:"source_url"` + FromCache bool `json:"from_cache"` + RequiresPython string `json:"requires_python"` + DatabricksConnect string `json:"databricks_connect"` + ConstraintCount int `json:"constraint_count"` +} + +// Plan contains the deployment plan. +type Plan struct { + PyprojectPath string `json:"pyproject_path"` + BackupPath string `json:"backup_path"` + Diff string `json:"diff"` + ChangedRegions []string `json:"changed_regions"` +} + +// PhaseResult contains the result of a single phase. +type PhaseResult struct { + Name string `json:"name"` + Status string `json:"status"` + Detail string `json:"detail"` +} + +// ResultDetail contains the final result details. +type ResultDetail struct { + Status string `json:"status"` + VenvPath string `json:"venv_path"` + PythonVersion string `json:"python_version"` + DatabricksConnectInstalled string `json:"databricks_connect_installed"` +} + +// Result contains the overall result of the dbconnect operation. +type Result struct { + Mode string `json:"mode"` + Check bool `json:"check"` + Target *TargetInfo `json:"target,omitempty"` + Constraints *ConstraintInfo `json:"constraints,omitempty"` + Plan *Plan `json:"plan,omitempty"` + Phases []PhaseResult `json:"phases,omitempty"` + Result *ResultDetail `json:"result,omitempty"` + Error *PipelineError `json:"error,omitempty"` +} diff --git a/libs/dbconnect/result_test.go b/libs/dbconnect/result_test.go new file mode 100644 index 00000000000..643f505151f --- /dev/null +++ b/libs/dbconnect/result_test.go @@ -0,0 +1,21 @@ +package dbconnect + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPipelineErrorWrapsAndExposesCode(t *testing.T) { + base := errors.New("boom") + err := NewError(ErrConstraintFetchFailed, base, "fetch %s", "x") + assert.Equal(t, "fetch x: boom", err.Error()) + assert.Equal(t, ErrConstraintFetchFailed, err.Code) + assert.ErrorIs(t, err, base) +} + +func TestModeString(t *testing.T) { + assert.Equal(t, "init", ModeInit.String()) + assert.Equal(t, "sync", ModeSync.String()) +} diff --git a/libs/dbconnect/target.go b/libs/dbconnect/target.go new file mode 100644 index 00000000000..b151cce49dc --- /dev/null +++ b/libs/dbconnect/target.go @@ -0,0 +1,134 @@ +package dbconnect + +import ( + "context" + "fmt" + "strings" +) + +// ComputeClient is a narrow seam over the SDK so tests can stub it. +// The real adapter is wired in Task 9. +type ComputeClient interface { + // GetClusterSparkVersion returns the Spark version string for a cluster. + GetClusterSparkVersion(ctx context.Context, clusterID string) (string, error) + // GetJobSparkVersion returns either a Spark version (isServerless=false) or a + // serverless marker (isServerless=true) for a job, plus a recorded version string. + GetJobSparkVersion(ctx context.Context, jobID string) (sparkVersion string, isServerless bool, version string, err error) +} + +// TargetFlags holds the mutually-exclusive compute target flags from the CLI. +type TargetFlags struct { + Cluster string + Serverless string + Job string +} + +// BundleTarget is the three-state result of reading the bundle's configured +// target. Selected=false means nothing was configured. +type BundleTarget struct { + ClusterID string + Serverless bool + Selected bool +} + +// ValidateTargetFlags returns an error if more than one of the three flags is set. +// Cobra marks them mutually exclusive too; this guards the library path. +func ValidateTargetFlags(f TargetFlags) error { + var set []string + if f.Cluster != "" { + set = append(set, "--cluster") + } + if f.Serverless != "" { + set = append(set, "--serverless") + } + if f.Job != "" { + set = append(set, "--job") + } + if len(set) > 1 { + return fmt.Errorf("flags %s are mutually exclusive; specify at most one", strings.Join(set, " and ")) + } + return nil +} + +// ResolveTarget resolves the compute target using ordered precedence: +// --cluster flag → --serverless flag → --job flag → bundle target. +// PythonVersion is left empty; it is filled later from constraint data. +func ResolveTarget(ctx context.Context, f TargetFlags, c ComputeClient, bt BundleTarget) (*TargetInfo, error) { + if f.Cluster != "" { + v, err := c.GetClusterSparkVersion(ctx, f.Cluster) + if err != nil { + return nil, fmt.Errorf("resolving cluster %s: %w", f.Cluster, err) + } + return &TargetInfo{ + Kind: "cluster", + ClusterID: f.Cluster, + SparkVersion: v, + EnvKey: EnvKeyForSparkVersion(v), + }, nil + } + + if f.Serverless != "" { + return &TargetInfo{ + Kind: "serverless", + EnvKey: EnvKeyForServerless(f.Serverless), + }, nil + } + + if f.Job != "" { + _, isServerless, version, err := c.GetJobSparkVersion(ctx, f.Job) + if err != nil { + return nil, fmt.Errorf("resolving job %s: %w", f.Job, err) + } + if isServerless { + // Default to v4 when the job is serverless; the serverless env version + // is not recorded in the bundle/project (documented stand-in from the + // original script). + v := version + if v == "" { + v = "v4" + } + return &TargetInfo{ + Kind: "serverless", + EnvKey: EnvKeyForServerless(v), + }, nil + } + return &TargetInfo{ + Kind: "cluster", + SparkVersion: version, + EnvKey: EnvKeyForSparkVersion(version), + }, nil + } + + // Fall back to bundle target. + if !bt.Selected { + return nil, NewError(ErrNoTargetSelected, nil, + "No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job") + } + + if bt.Serverless { + // Default to serverless-v4: the serverless env version is not recorded + // in the bundle/project (documented stand-in from the original script). + return &TargetInfo{ + Kind: "serverless", + EnvKey: EnvKeyForServerless("v4"), + }, nil + } + + if bt.ClusterID != "" { + v, err := c.GetClusterSparkVersion(ctx, bt.ClusterID) + if err != nil { + return nil, fmt.Errorf("resolving bundle cluster %s: %w", bt.ClusterID, err) + } + return &TargetInfo{ + Kind: "cluster", + ClusterID: bt.ClusterID, + SparkVersion: v, + EnvKey: EnvKeyForSparkVersion(v), + }, nil + } + + // Bundle target is selected but has neither serverless nor a cluster ID — + // treat this the same as nothing selected so the user gets a clear message. + return nil, NewError(ErrNoTargetSelected, nil, + "No compute target is selected. Select a cluster or serverless target, or pass --cluster/--serverless/--job.") +} diff --git a/libs/dbconnect/target_test.go b/libs/dbconnect/target_test.go new file mode 100644 index 00000000000..24c1250f95a --- /dev/null +++ b/libs/dbconnect/target_test.go @@ -0,0 +1,57 @@ +package dbconnect + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type stubCompute struct { + clusterVersion string + clusterErr error +} + +func (s stubCompute) GetClusterSparkVersion(_ context.Context, _ string) (string, error) { + return s.clusterVersion, s.clusterErr +} + +func (s stubCompute) GetJobSparkVersion(_ context.Context, _ string) (string, bool, string, error) { + return "", false, "", nil +} + +func TestResolveServerlessFlag(t *testing.T) { + ti, err := ResolveTarget(t.Context(), TargetFlags{Serverless: "v4"}, stubCompute{}, BundleTarget{}) + require.NoError(t, err) + assert.Equal(t, "serverless", ti.Kind) + assert.Equal(t, "serverless/serverless-v4", ti.EnvKey) +} + +func TestResolveClusterFlag(t *testing.T) { + c := stubCompute{clusterVersion: "15.4.x-scala2.12"} + ti, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "abc"}, c, BundleTarget{}) + require.NoError(t, err) + assert.Equal(t, "cluster", ti.Kind) + assert.Equal(t, "15.4.x-scala2.12", ti.SparkVersion) + assert.Equal(t, "dbr/15.4.x-scala2.12", ti.EnvKey) + assert.Equal(t, "abc", ti.ClusterID) +} + +func TestResolveBundleNothingSelected(t *testing.T) { + _, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: false}) + var pe *PipelineError + require.ErrorAs(t, err, &pe) + assert.Equal(t, ErrNoTargetSelected, pe.Code) +} + +func TestResolveBundleServerless(t *testing.T) { + ti, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: true, Serverless: true}) + require.NoError(t, err) + assert.Equal(t, "serverless/serverless-v4", ti.EnvKey) +} + +func TestValidateTargetFlagsMutuallyExclusive(t *testing.T) { + assert.Error(t, ValidateTargetFlags(TargetFlags{Cluster: "a", Serverless: "v4"})) + assert.NoError(t, ValidateTargetFlags(TargetFlags{Cluster: "a"})) +} diff --git a/libs/dbconnect/uv.go b/libs/dbconnect/uv.go new file mode 100644 index 00000000000..2348a1a9e46 --- /dev/null +++ b/libs/dbconnect/uv.go @@ -0,0 +1,265 @@ +package dbconnect + +import ( + "bufio" + "context" + "errors" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/process" +) + +// uvManager implements PackageManager using the uv tool. +// https://docs.astral.sh/uv/ +type uvManager struct { + bin string +} + +// newUvManager returns a uvManager whose binary path is resolved lazily via +// EnsureAvailable. +func newUvManager() *uvManager { + return &uvManager{} +} + +// NewUvManager returns a PackageManager backed by the uv tool. +// This is the exported constructor for use outside this package. +func NewUvManager() PackageManager { + return newUvManager() +} + +// Name returns "uv". +func (m *uvManager) Name() string { + return "uv" +} + +// EnsureAvailable discovers or installs uv and records the binary path. +// It runs the official installer when uv is not found on the PATH or in the +// standard candidate locations. +// https://docs.astral.sh/uv/getting-started/installation/ +func (m *uvManager) EnsureAvailable(ctx context.Context) (string, error) { + bin, err := discoverUv(ctx) + if err != nil { + // Install uv using the official installer script. + // https://astral.sh/uv/install.sh + _, installErr := process.Background(ctx, []string{"sh", "-c", "curl -LsSf https://astral.sh/uv/install.sh | sh"}) + if installErr != nil { + return "", NewError(ErrUvUnavailable, installErr, "uv installation failed") + } + bin, err = discoverUv(ctx) + if err != nil { + return "", err + } + } + log.Debugf(ctx, "uv: discovered binary at %s", bin) + m.bin = bin + + // Use --version (not "version") to avoid project-scoped sub-command that requires pyproject.toml. + version, err := process.Background(ctx, []string{m.bin, "--version"}) + if err != nil { + return "", uvFailure(ErrProvisionFailed, err, "uv version check") + } + return strings.TrimSpace(version), nil +} + +// EnsurePython installs the requested Python minor version via uv. +func (m *uvManager) EnsurePython(ctx context.Context, minor string) error { + args := append([]string{m.bin}, m.pythonInstallArgs(minor)...) + indexURL := m.resolveIndexURL(ctx) + var err error + if indexURL != "" { + _, err = process.Background(ctx, args, process.WithEnv("UV_INDEX_URL", indexURL)) + } else { + _, err = process.Background(ctx, args) + } + if err != nil { + return uvFailure(ErrProvisionFailed, err, "uv python install "+minor) + } + return nil +} + +// Provision runs `uv sync` inside projectDir to install project dependencies. +func (m *uvManager) Provision(ctx context.Context, projectDir string) error { + args := append([]string{m.bin}, m.syncArgs()...) + indexURL := m.resolveIndexURL(ctx) + var err error + if indexURL != "" { + _, err = process.Background(ctx, args, process.WithDir(projectDir), process.WithEnv("UV_INDEX_URL", indexURL)) + } else { + _, err = process.Background(ctx, args, process.WithDir(projectDir)) + } + if err != nil { + return uvFailure(ErrProvisionFailed, err, "uv sync") + } + return nil +} + +// venvPython returns the path to the virtualenv's Python interpreter, +// accounting for the Windows (Scripts/python.exe) vs Unix (bin/python) layout. +func venvPython(projectDir string) string { + if runtime.GOOS == "windows" { + return filepath.Join(projectDir, ".venv", "Scripts", "python.exe") + } + return filepath.Join(projectDir, ".venv", "bin", "python") +} + +// PostProvision seeds pip into the project's virtual environment. +// +// VS Code's ms-python.vscode-python-envs extension falls back to +// `python -m pip list` when its `uv --version` probe fails on the GUI PATH. +// uv virtual environments do not include pip by default, and `uv sync` strips +// pip if it was previously present. Seeding pip after every sync ensures the +// VS Code integration works correctly regardless of how the environment was +// activated. +func (m *uvManager) PostProvision(ctx context.Context, projectDir string) error { + args := append([]string{m.bin}, m.pipSeedArgs(venvPython(projectDir))...) + indexURL := m.resolveIndexURL(ctx) + var err error + if indexURL != "" { + _, err = process.Background(ctx, args, process.WithDir(projectDir), process.WithEnv("UV_INDEX_URL", indexURL)) + } else { + _, err = process.Background(ctx, args, process.WithDir(projectDir)) + } + if err != nil { + return uvFailure(ErrProvisionFailed, err, "uv pip seed") + } + return nil +} + +// Validate reads the Python minor version and databricks-connect package +// version from the project's virtual environment. +func (m *uvManager) Validate(ctx context.Context, projectDir string) (string, string, error) { + pyCode := `import sys, importlib.metadata; print(f"{sys.version_info.major}.{sys.version_info.minor}"); print(importlib.metadata.version("databricks-connect"))` + // --no-project runs the interpreter from the created .venv without re-resolving/syncing + // the project's declared dependencies, so validation observes exactly what was installed. + out, err := process.Background(ctx, + []string{m.bin, "run", "--no-project", "python", "-c", pyCode}, + process.WithDir(projectDir), + ) + if err != nil { + return "", "", uvFailure(ErrValidationFailed, err, "uv run python validation") + } + lines := strings.Split(strings.TrimSpace(out), "\n") + if len(lines) < 2 { + return "", "", NewError(ErrValidationFailed, nil, "unexpected output from uv run: %q", out) + } + return strings.TrimSpace(lines[0]), strings.TrimSpace(lines[1]), nil +} + +// syncArgs returns the argument slice for `uv sync` (without the binary). +func (m *uvManager) syncArgs() []string { + return []string{"sync"} +} + +// pythonInstallArgs returns the argument slice for `uv python install `. +func (m *uvManager) pythonInstallArgs(minor string) []string { + return []string{"python", "install", minor} +} + +// pipSeedArgs returns the argument slice for seeding pip into the venv. +func (m *uvManager) pipSeedArgs(venvPython string) []string { + return []string{"pip", "install", "pip", "--python", venvPython} +} + +// pipIndexURLRe matches `index-url = ` lines in pip.conf. +var pipIndexURLRe = regexp.MustCompile(`(?i)^\s*index-url\s*=\s*(\S+)`) + +// pipConfIndexURL reads ~/.config/pip/pip.conf and returns the index-url value. +// uv ignores pip.conf; on Databricks-managed machines pypi.org is blocked and +// the corporate PyPI proxy is declared via pip.conf. Bridging the value through +// UV_INDEX_URL lets uv reach the proxy. +// https://pip.pypa.io/en/stable/topics/configuration/ +func pipConfIndexURL(ctx context.Context) string { + home, err := env.UserHomeDir(ctx) + if err != nil || home == "" { + return "" + } + confPath := filepath.Join(home, ".config", "pip", "pip.conf") + f, err := os.Open(confPath) + if err != nil { + return "" + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + if m := pipIndexURLRe.FindStringSubmatch(scanner.Text()); m != nil { + return strings.TrimSpace(m[1]) + } + } + return "" +} + +// resolveIndexURL returns a UV_INDEX_URL value to inject, or "" when none is +// needed. It returns "" when UV_INDEX_URL is already set in the context env +// (so the caller's explicit value is never overridden) and also when pip.conf +// has no index-url entry. +func (m *uvManager) resolveIndexURL(ctx context.Context) string { + if _, ok := env.Lookup(ctx, "UV_INDEX_URL"); ok { + log.Debugf(ctx, "uv: UV_INDEX_URL already set in environment, not overriding") + return "" + } + url := pipConfIndexURL(ctx) + if url != "" { + log.Debugf(ctx, "uv: using package index %s from pip.conf", url) + } else { + log.Debugf(ctx, "uv: no UV_INDEX_URL and no index-url in pip.conf; uv will use its default index (pypi.org)") + } + return url +} + +// uvFailure builds a PipelineError from a failed uv invocation, appending uv's +// stderr to the message so callers can see the actual failure reason (e.g. +// "Connection refused") rather than just the exit code. +func uvFailure(code ErrorCode, err error, action string) *PipelineError { + msg := action + " failed" + if perr, ok := errors.AsType[*process.ProcessError](err); ok && strings.TrimSpace(perr.Stderr) != "" { + msg = msg + ": " + strings.TrimSpace(perr.Stderr) + } + return NewError(code, err, "%s", msg) +} + +// discoverUv searches for the uv binary on PATH and in well-known install +// locations. It returns NewError(ErrUvUnavailable, ...) if uv is not found. +// +// Candidate locations follow the uv installer defaults: +// https://docs.astral.sh/uv/getting-started/installation/ +// XDG_BIN_HOME is specified by the XDG Base Directory Specification: +// https://specifications.freedesktop.org/basedir-spec/latest/ +func discoverUv(ctx context.Context) (string, error) { + // Prefer PATH lookup first; it respects user customisation. + if p, err := exec.LookPath("uv"); err == nil { + return p, nil + } + + home, _ := env.UserHomeDir(ctx) + + // XDG_BIN_HOME defaults to $HOME/.local/bin when unset. + xdgBinHome, _ := env.Lookup(ctx, "XDG_BIN_HOME") + + candidates := []string{ + filepath.Join(home, ".local", "bin", "uv"), + filepath.Join(xdgBinHome, "uv"), + "/opt/homebrew/bin/uv", + "/usr/local/bin/uv", + } + + for _, c := range candidates { + if c == "/uv" || c == "" { + // Skip degenerate paths produced when home or xdgBinHome is empty. + continue + } + if _, err := os.Stat(c); err == nil { + return c, nil + } + } + + return "", NewError(ErrUvUnavailable, nil, + "uv not found on PATH or in well-known locations (%s)", strings.Join(candidates, ", ")) +} diff --git a/libs/dbconnect/uv_test.go b/libs/dbconnect/uv_test.go new file mode 100644 index 00000000000..113a3d43726 --- /dev/null +++ b/libs/dbconnect/uv_test.go @@ -0,0 +1,126 @@ +package dbconnect + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/process" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUvArgs(t *testing.T) { + m := &uvManager{bin: "uv"} + assert.Equal(t, []string{"sync"}, m.syncArgs()) + assert.Equal(t, []string{"python", "install", "3.12"}, m.pythonInstallArgs("3.12")) + assert.Equal(t, []string{"pip", "install", "pip", "--python", "/p/.venv/bin/python"}, m.pipSeedArgs("/p/.venv/bin/python")) +} + +func TestDiscoverUvFindsBinOnPath(t *testing.T) { + dir := t.TempDir() + bin := filepath.Join(dir, "uv") + require.NoError(t, os.WriteFile(bin, []byte("#!/bin/sh\n"), 0o755)) + t.Setenv("PATH", dir) + got, err := discoverUv(t.Context()) + require.NoError(t, err) + assert.Equal(t, bin, got) +} + +func TestPipConfIndexURL(t *testing.T) { + t.Run("returns_url_from_pip_conf", func(t *testing.T) { + tmp := t.TempDir() + confDir := filepath.Join(tmp, ".config", "pip") + require.NoError(t, os.MkdirAll(confDir, 0o755)) + confContent := "[global]\nindex-url = https://proxy.example/simple\n" + require.NoError(t, os.WriteFile(filepath.Join(confDir, "pip.conf"), []byte(confContent), 0o644)) + + ctx := env.WithUserHomeDir(t.Context(), tmp) + got := pipConfIndexURL(ctx) + assert.Equal(t, "https://proxy.example/simple", got) + }) + + t.Run("returns_empty_when_no_pip_conf", func(t *testing.T) { + tmp := t.TempDir() + ctx := env.WithUserHomeDir(t.Context(), tmp) + got := pipConfIndexURL(ctx) + assert.Empty(t, got) + }) + + t.Run("returns_empty_when_no_index_url_in_conf", func(t *testing.T) { + tmp := t.TempDir() + confDir := filepath.Join(tmp, ".config", "pip") + require.NoError(t, os.MkdirAll(confDir, 0o755)) + confContent := "[global]\nextra-index-url = https://other.example/simple\n" + require.NoError(t, os.WriteFile(filepath.Join(confDir, "pip.conf"), []byte(confContent), 0o644)) + + ctx := env.WithUserHomeDir(t.Context(), tmp) + got := pipConfIndexURL(ctx) + assert.Empty(t, got) + }) +} + +func TestResolveIndexURLRespectsExistingEnv(t *testing.T) { + m := &uvManager{} + + t.Run("returns_empty_when_UV_INDEX_URL_already_set", func(t *testing.T) { + // When UV_INDEX_URL is in ctx, resolveIndexURL must not override it. + ctx := env.Set(t.Context(), "UV_INDEX_URL", "https://explicit.example/simple") + + // Set up a pip.conf that would otherwise be used. + tmp := t.TempDir() + confDir := filepath.Join(tmp, ".config", "pip") + require.NoError(t, os.MkdirAll(confDir, 0o755)) + confContent := "[global]\nindex-url = https://proxy.example/simple\n" + require.NoError(t, os.WriteFile(filepath.Join(confDir, "pip.conf"), []byte(confContent), 0o644)) + ctx = env.WithUserHomeDir(ctx, tmp) + + got := m.resolveIndexURL(ctx) + assert.Empty(t, got) + }) + + t.Run("returns_pip_conf_url_when_UV_INDEX_URL_unset", func(t *testing.T) { + tmp := t.TempDir() + confDir := filepath.Join(tmp, ".config", "pip") + require.NoError(t, os.MkdirAll(confDir, 0o755)) + confContent := "[global]\nindex-url = https://proxy.example/simple\n" + require.NoError(t, os.WriteFile(filepath.Join(confDir, "pip.conf"), []byte(confContent), 0o644)) + + ctx := env.WithUserHomeDir(t.Context(), tmp) + got := m.resolveIndexURL(ctx) + assert.Equal(t, "https://proxy.example/simple", got) + }) +} + +func TestUvFailureIncludesStderr(t *testing.T) { + t.Run("includes_stderr_when_present", func(t *testing.T) { + underlying := &process.ProcessError{ + Command: "uv sync", + Err: errors.New("exit status 2"), + Stderr: "error: Connection refused\n", + } + pe := uvFailure(ErrProvisionFailed, underlying, "uv sync") + assert.Equal(t, ErrProvisionFailed, pe.Code) + assert.Contains(t, pe.Msg, "Connection refused") + assert.NotEqual(t, '\n', pe.Msg[len(pe.Msg)-1], "Msg must not end with a newline") + }) + + t.Run("omits_stderr_suffix_when_empty", func(t *testing.T) { + underlying := &process.ProcessError{ + Command: "uv sync", + Err: errors.New("exit status 2"), + Stderr: "", + } + pe := uvFailure(ErrProvisionFailed, underlying, "uv sync") + assert.Equal(t, ErrProvisionFailed, pe.Code) + assert.Equal(t, "uv sync failed", pe.Msg) + }) + + t.Run("non_process_error_uses_action_only", func(t *testing.T) { + pe := uvFailure(ErrProvisionFailed, errors.New("some other error"), "uv sync") + assert.Equal(t, ErrProvisionFailed, pe.Code) + assert.Equal(t, "uv sync failed", pe.Msg) + }) +}