Skip to content

Commit 745e22e

Browse files
committed
Improve UX with private repos
* Fetch and check credentials stored on the server side if no credentials provided via the command line (otherwise, check the provided credentials as usual) * Detect the default branch using the provided or stored credentials Closes: #3061
1 parent be36bde commit 745e22e

File tree

6 files changed

+193
-121
lines changed

6 files changed

+193
-121
lines changed

src/dstack/_internal/cli/commands/init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from dstack._internal.cli.commands import BaseCommand
77
from dstack._internal.cli.services.repos import (
88
get_repo_from_dir,
9-
get_repo_from_url,
109
is_git_repo_url,
1110
register_init_repo_args,
1211
)
1312
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
1413
from dstack._internal.core.errors import ConfigurationError
14+
from dstack._internal.core.models.repos.remote import RemoteRepo
1515
from dstack._internal.core.services.configs import ConfigManager
1616
from dstack.api import Client
1717

@@ -101,7 +101,7 @@ def _command(self, args: argparse.Namespace):
101101
if repo_url is not None:
102102
# Dummy repo branch to avoid autodetection that fails on private repos.
103103
# We don't need branch/hash for repo_id anyway.
104-
repo = get_repo_from_url(repo_url, repo_branch="master")
104+
repo = RemoteRepo.from_url(repo_url, repo_branch="master")
105105
elif repo_path is not None:
106106
repo = get_repo_from_dir(repo_path, local=local)
107107
else:

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
1818
from dstack._internal.cli.services.repos import (
1919
get_repo_from_dir,
20-
get_repo_from_url,
2120
init_default_virtual_repo,
2221
is_git_repo_url,
2322
register_init_repo_args,
@@ -43,13 +42,19 @@
4342
ServiceConfiguration,
4443
TaskConfiguration,
4544
)
45+
from dstack._internal.core.models.repos import RepoHeadWithCreds
4646
from dstack._internal.core.models.repos.base import Repo
4747
from dstack._internal.core.models.repos.local import LocalRepo
48+
from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds
4849
from dstack._internal.core.models.resources import CPUSpec
4950
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
5051
from dstack._internal.core.services.configs import ConfigManager
5152
from dstack._internal.core.services.diff import diff_models
52-
from dstack._internal.core.services.repos import load_repo
53+
from dstack._internal.core.services.repos import (
54+
InvalidRepoCredentialsError,
55+
get_repo_creds_and_default_branch,
56+
load_repo,
57+
)
5358
from dstack._internal.utils.common import local_time
5459
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
5560
from dstack._internal.utils.logging import get_logger
@@ -524,15 +529,17 @@ def get_repo(
524529
return init_default_virtual_repo(api=self.api)
525530

526531
repo: Optional[Repo] = None
532+
repo_head: Optional[RepoHeadWithCreds] = None
527533
repo_branch: Optional[str] = configurator_args.repo_branch
528534
repo_hash: Optional[str] = configurator_args.repo_hash
535+
repo_creds: Optional[RemoteRepoCreds] = None
536+
git_identity_file: Optional[str] = configurator_args.git_identity_file
537+
git_private_key: Optional[str] = None
538+
oauth_token: Optional[str] = configurator_args.gh_token
529539
# Should we (re)initialize the repo?
530540
# If any Git credentials provided, we reinitialize the repo, as the user may have provided
531541
# updated credentials.
532-
init = (
533-
configurator_args.git_identity_file is not None
534-
or configurator_args.gh_token is not None
535-
)
542+
init = git_identity_file is not None or oauth_token is not None
536543

537544
url: Optional[str] = None
538545
local_path: Optional[Path] = None
@@ -565,15 +572,15 @@ def get_repo(
565572
local_path = Path.cwd()
566573
legacy_local_path = True
567574
if url:
568-
repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash)
569-
if not self.api.repos.is_initialized(repo, by_user=True):
570-
init = True
575+
# "master" is a dummy value, we'll fetch the actual default branch later
576+
repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
577+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
571578
elif local_path:
572579
if legacy_local_path:
573580
if repo_config := config_manager.get_repo_config(local_path):
574581
repo = load_repo(repo_config)
575-
# allow users with legacy configurations use shared repo creds
576-
if self.api.repos.is_initialized(repo, by_user=False):
582+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
583+
if repo_head is not None:
577584
warn(
578585
"The repo is not specified but found and will be used in the run\n"
579586
"Future versions will not load repos automatically\n"
@@ -600,20 +607,55 @@ def get_repo(
600607
)
601608
local: bool = configurator_args.local
602609
repo = get_repo_from_dir(local_path, local=local)
603-
if not self.api.repos.is_initialized(repo, by_user=True):
604-
init = True
610+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
611+
if isinstance(repo, RemoteRepo):
612+
repo_branch = repo.run_repo_data.repo_branch
613+
repo_hash = repo.run_repo_data.repo_hash
605614
else:
606615
assert False, "should not reach here"
607616

608617
if repo is None:
609618
return init_default_virtual_repo(api=self.api)
610619

620+
if isinstance(repo, RemoteRepo):
621+
assert repo.repo_url is not None
622+
623+
if repo_head is not None and repo_head.repo_creds is not None:
624+
if git_identity_file is None and oauth_token is None:
625+
git_private_key = repo_head.repo_creds.private_key
626+
oauth_token = repo_head.repo_creds.oauth_token
627+
else:
628+
init = True
629+
630+
try:
631+
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
632+
repo_url=repo.repo_url,
633+
identity_file=git_identity_file,
634+
private_key=git_private_key,
635+
oauth_token=oauth_token,
636+
)
637+
except InvalidRepoCredentialsError as e:
638+
raise CLIError(*e.args) from e
639+
640+
if repo_branch is None and repo_hash is None:
641+
repo_branch = default_repo_branch
642+
if repo_branch is None:
643+
raise CLIError(
644+
"Failed to automatically detect remote repo branch."
645+
" Specify branch or hash."
646+
)
647+
repo = RemoteRepo.from_url(
648+
repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
649+
)
650+
611651
if init:
612652
self.api.repos.init(
613653
repo=repo,
614-
git_identity_file=configurator_args.git_identity_file,
615-
oauth_token=configurator_args.gh_token,
654+
git_identity_file=git_identity_file,
655+
oauth_token=oauth_token,
656+
creds=repo_creds,
616657
)
658+
617659
if isinstance(repo, LocalRepo):
618660
warn(
619661
f"{repo.repo_dir} is a local repo\n"

src/dstack/_internal/cli/services/repos.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
from typing import Literal, Optional, Union, overload
2+
from typing import Literal, Union, overload
33

44
import git
55

@@ -8,7 +8,6 @@
88
from dstack._internal.core.models.repos.local import LocalRepo
99
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
1010
from dstack._internal.core.models.repos.virtual import VirtualRepo
11-
from dstack._internal.core.services.repos import get_default_branch
1211
from dstack._internal.utils.path import PathLike
1312
from dstack.api._public import Client
1413

@@ -43,22 +42,6 @@ def init_default_virtual_repo(api: Client) -> VirtualRepo:
4342
return repo
4443

4544

46-
def get_repo_from_url(
47-
repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
48-
) -> RemoteRepo:
49-
if repo_branch is None and repo_hash is None:
50-
repo_branch = get_default_branch(repo_url)
51-
if repo_branch is None:
52-
raise CLIError(
53-
"Failed to automatically detect remote repo branch. Specify branch or hash."
54-
)
55-
return RemoteRepo.from_url(
56-
repo_url=repo_url,
57-
repo_branch=repo_branch,
58-
repo_hash=repo_hash,
59-
)
60-
61-
6245
@overload
6346
def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...
6447

0 commit comments

Comments
 (0)