Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,12 @@ def iter_stage(self, stage_path: StagePath):
path = StagePath.get_user_stage() / file["name"]
elif stage_path.is_git_repo():
path = self.build_path(file["name"])
elif stage_path.is_vstage():
path = stage_path.root_path() / file["name"]
else:
# Snowflake `ls` returns unqualified names; re-attach the original FQN.
# For @-stage paths: Snowflake ls returns unqualified names (e.g.
# "stage_name/path/file.sql"); re-attach the original FQN prefix so
# the result uses the correct qualified stage name.
file_name = file["name"]
parts = file_name.split("/", maxsplit=1)
relative_path = parts[1] if len(parts) > 1 else ""
Expand Down
43 changes: 42 additions & 1 deletion src/snowflake/cli/api/stage_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
# Marker string for the user stage.
USER_STAGE_PREFIX = "~"
# URL-style scheme that introduces snow:// stage paths.
SNOW_PREFIX = "snow://"

# Snowflake unquoted identifiers contain only [A-Za-z0-9_$] and are normalized
# to lowercase in ls output. Components outside this charset came from quoted
# identifiers whose case Snowflake preserves — those must be compared exactly.
#
# The pattern is anchored with \Z rather than $: in Python, $ also matches just
# before a trailing newline, which would let "name\n" pass as an unquoted
# identifier even though the newline is outside the allowed charset.
_UNQUOTED_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9_$]+\Z")


class StagePath:
def __init__(
Expand Down Expand Up @@ -70,6 +75,9 @@ def is_user_stage(self) -> bool:
def is_git_repo(self) -> bool:
    """Return the git-repository flag stored on this stage path."""
    return self._is_git_repo

def is_vstage(self) -> bool:
    """Return the stored flag marking this path as a ``snow://``-prefixed stage.

    This only exposes ``self._is_snow_prefixed_stage``; where that flag is
    assigned is not visible here.
    """
    return self._is_snow_prefixed_stage

@property
def git_ref(self) -> str | None:
    """Return the stored git reference, which may be ``None``."""
    return self._git_ref
Expand Down Expand Up @@ -287,7 +295,40 @@ def quoted_absolute_path(self) -> str:
return to_string_literal(self.absolute_path())

def relative_to(self, stage_path: StagePath) -> PurePosixPath:
    """Return the part of this path that lies below ``stage_path``.

    Snowflake normalizes unquoted identifiers to lowercase in ls output,
    but the user-specified path may use uppercase (e.g. DEPLOYMENT$1 vs
    deployment$1). ``PurePosixPath.relative_to()`` is case-sensitive, so the
    prefix is checked component by component and the tail is returned with
    the casing found in this path.

    Case sensitivity rules:
      - Unquoted identifiers: only [A-Za-z0-9_$] chars, normalized to
        lowercase by Snowflake — compared case-insensitively.
      - Quoted identifiers: may contain any char; Snowflake preserves their
        case — compared exactly. They are detected by the presence of
        characters outside the unquoted charset (dots, spaces, etc.).

    Args:
        stage_path: The candidate ancestor path.

    Returns:
        The remaining components as a ``PurePosixPath``; ``PurePosixPath(".")``
        when both paths have the same components.

    Raises:
        ValueError: If ``stage_path`` is not a prefix of this path under the
            rules above.
    """
    own_parts = self.path.parts
    root_parts = stage_path.path.parts

    def _same_component(own: str, root: str) -> bool:
        # Case-insensitive only when BOTH sides look like unquoted identifiers.
        both_unquoted = _UNQUOTED_IDENTIFIER_RE.match(
            own
        ) and _UNQUOTED_IDENTIFIER_RE.match(root)
        if both_unquoted:
            return own.lower() == root.lower()
        # At least one side has special chars → quoted identifier → exact match.
        return own == root

    # zip() stops at the shorter sequence; the length guard makes sure that
    # every component of the root takes part in the comparison.
    is_subpath = len(own_parts) >= len(root_parts) and all(
        _same_component(own, root) for own, root in zip(own_parts, root_parts)
    )
    if not is_subpath:
        raise ValueError(
            f"{self.path!r} is not in the subpath of {stage_path.path!r}"
        )

    remainder = own_parts[len(root_parts) :]
    if remainder:
        return PurePosixPath(*remainder)
    return PurePosixPath(".")

def get_local_target_path(self, target_dir: Path, stage_root: StagePath):
# Case for downloading @stage/aa/file.py with root @stage/aa
Expand Down
105 changes: 105 additions & 0 deletions tests/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,111 @@ def test_copy_get_recursive_from_git_repo(
]


@pytest.mark.parametrize(
    "stage_path, files_on_stage, expected_stage_path, expected_calls",
    [
        (
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3/",
            [
                "deployments/DEPLOYMENT$3/manifest.yml",
                "deployments/DEPLOYMENT$3/scripts/setup.sql",
            ],
            "'snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3/'",
            [
                "get 'snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3/manifest.yml' file://{}/ parallel=4",
                "get 'snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3/scripts/setup.sql' file://{}/scripts/ parallel=4",
            ],
        ),
        (
            "snow://project/MY_DB.MY_SCHEMA.MY_PROJECT/deployments/deploy_v1/",
            [
                "deployments/deploy_v1/app.py",
                "deployments/deploy_v1/lib/utils.py",
            ],
            "'snow://project/MY_DB.MY_SCHEMA.MY_PROJECT/deployments/deploy_v1/'",
            [
                "get 'snow://project/MY_DB.MY_SCHEMA.MY_PROJECT/deployments/deploy_v1/app.py' file://{}/ parallel=4",
                "get 'snow://project/MY_DB.MY_SCHEMA.MY_PROJECT/deployments/deploy_v1/lib/utils.py' file://{}/lib/ parallel=4",
            ],
        ),
        (
            "snow://streamlit/DB.SCHEMA.MY_APP/",
            [
                "streamlit_app.py",
                "pages/page1.py",
            ],
            "'snow://streamlit/DB.SCHEMA.MY_APP/'",
            [
                "get 'snow://streamlit/DB.SCHEMA.MY_APP/streamlit_app.py' file://{}/ parallel=4",
                "get 'snow://streamlit/DB.SCHEMA.MY_APP/pages/page1.py' file://{}/pages/ parallel=4",
            ],
        ),
    ],
)
@mock.patch(f"{STAGE_MANAGER}.execute_query")
def test_copy_get_recursive_from_vstage(
    mock_execute,
    mock_cursor,
    temporary_directory,
    stage_path,
    files_on_stage,
    expected_stage_path,
    expected_calls,
):
    """Recursive get from a snow:// stage lists once, then fetches every file.

    The mocked ``ls`` result supplies ``files_on_stage``; the first recorded
    query must be the ``ls`` and the rest must be the expected ``get`` commands
    with the temporary directory substituted into each template.
    """
    listing_rows = [{"name": listed_name} for listed_name in files_on_stage]
    mock_execute.return_value = mock_cursor(listing_rows, [])

    download_dir = Path(temporary_directory)
    StageManager().get_recursive(stage_path, download_dir)

    first_query, *later_queries = mock_execute.mock_calls
    wanted_listing = mock.call(f"ls {expected_stage_path}", cursor_class=DictCursor)
    wanted_downloads = [
        mock.call(template.format(temporary_directory)) for template in expected_calls
    ]
    assert first_query == wanted_listing
    assert later_queries == wanted_downloads


@pytest.mark.parametrize(
    "stage_path, files_on_stage, expected_stage_path, expected_calls",
    [
        # Snowflake returns lowercase identifiers in ls output (DEPLOYMENT$1 -> deployment$1)
        (
            "snow://project/TEMP.DEV_PLATFORM_MAA.DEVPLATFORM_MAA_PROJECT/deployments/DEPLOYMENT$1/",
            [
                "deployments/deployment$1/deploy_metadata.json",
                "deployments/deployment$1/scripts/setup.sql",
            ],
            "'snow://project/TEMP.DEV_PLATFORM_MAA.DEVPLATFORM_MAA_PROJECT/deployments/DEPLOYMENT$1/'",
            [
                "get 'snow://project/TEMP.DEV_PLATFORM_MAA.DEVPLATFORM_MAA_PROJECT/deployments/deployment$1/deploy_metadata.json' file://{}/ parallel=4",
                "get 'snow://project/TEMP.DEV_PLATFORM_MAA.DEVPLATFORM_MAA_PROJECT/deployments/deployment$1/scripts/setup.sql' file://{}/scripts/ parallel=4",
            ],
        ),
    ],
)
@mock.patch(f"{STAGE_MANAGER}.execute_query")
def test_copy_get_recursive_from_vstage_case_mismatch(
    mock_execute,
    mock_cursor,
    temporary_directory,
    stage_path,
    files_on_stage,
    expected_stage_path,
    expected_calls,
):
    """Recursive get still works when ls output differs from the request in case.

    The requested path uses ``DEPLOYMENT$1`` while the mocked listing reports
    ``deployment$1``; the ``get`` commands are expected to use the listed
    (lowercase) names and the local targets must still be computed correctly.
    """
    listing_rows = [{"name": listed_name} for listed_name in files_on_stage]
    mock_execute.return_value = mock_cursor(listing_rows, [])

    download_dir = Path(temporary_directory)
    StageManager().get_recursive(stage_path, download_dir)

    first_query, *later_queries = mock_execute.mock_calls
    wanted_listing = mock.call(f"ls {expected_stage_path}", cursor_class=DictCursor)
    wanted_downloads = [
        mock.call(template.format(temporary_directory)) for template in expected_calls
    ]
    assert first_query == wanted_listing
    assert later_queries == wanted_downloads


@mock.patch(f"{STAGE_MANAGER}.execute_query")
def test_stage_create(mock_execute, runner, mock_cursor):
mock_execute.return_value = mock_cursor(["row"], [])
Expand Down
160 changes: 159 additions & 1 deletion tests/stage/test_stage_path.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import tempfile
from pathlib import Path
from pathlib import Path, PurePosixPath

import pytest
from snowflake.cli._plugins.stage.manager import DefaultStagePathParts, VStagePathParts
Expand Down Expand Up @@ -559,6 +559,17 @@ def test_vstage_paths(stage_str):
"deep/nested/structure/",
True,
),
(
"snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3/",
"snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3",
"snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$3",
"PROJECTS",
"snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV",
"project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV",
"project",
"deployments/DEPLOYMENT$3/",
True,
),
],
)
def test_vstage_path_parts_properties(
Expand Down Expand Up @@ -602,6 +613,153 @@ def test_vstage_path_parts_invalid_paths(invalid_path):
VStagePathParts(invalid_path)


def test_vstage_file_path_reconstruction_from_ls_output():
    """Joining an ls-reported name onto the stage root yields the full path."""
    project_prefix = "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/"
    listed_name = "deployments/DEPLOYMENT$3/manifest.yml"

    deployment_dir = StagePath.from_stage_str(
        project_prefix + "deployments/DEPLOYMENT$3/"
    )
    rebuilt = deployment_dir.root_path() / listed_name

    # Both renderings must be the project root followed by the listed name.
    expected_absolute = project_prefix + listed_name
    assert rebuilt.absolute_path() == expected_absolute
    assert rebuilt.path_for_sql() == "'" + expected_absolute + "'"


@pytest.mark.parametrize(
    "file_path_str, stage_root_str, expected_relative",
    [
        # Snowflake lowercases unquoted identifiers in ls output
        (
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/deployment$1/deploy_metadata.json",
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/",
            "deploy_metadata.json",
        ),
        # Same case should still work
        (
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/deploy_metadata.json",
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/",
            "deploy_metadata.json",
        ),
        # Nested path with case mismatch
        (
            "snow://project/TEMP.DEV.PROJECT/deployments/deployment$1/scripts/setup.sql",
            "snow://project/TEMP.DEV.PROJECT/deployments/DEPLOYMENT$1/",
            "scripts/setup.sql",
        ),
        # Regular @stage (no case mismatch — should still work)
        (
            "@my_stage/data/file.csv",
            "@my_stage/data/",
            "file.csv",
        ),
    ],
)
def test_relative_to_case_insensitive(file_path_str, stage_root_str, expected_relative):
    """``relative_to`` returns the tail even when the prefix differs in case."""
    listed_file = StagePath.from_stage_str(file_path_str)
    requested_root = StagePath.from_stage_str(stage_root_str)

    tail = listed_file.relative_to(requested_root)

    assert tail == PurePosixPath(expected_relative)


@pytest.mark.parametrize(
    "file_path_str, stage_root_str, target_dir, expected_local_path",
    [
        # Case mismatch — Snowflake returns lowercase, user specified uppercase
        (
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/deployment$1/deploy_metadata.json",
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/",
            "/tmp/download",
            "/tmp/download",
        ),
        # Nested file with case mismatch
        (
            "snow://project/TEMP.DEV.PROJECT/deployments/deployment$1/scripts/setup.sql",
            "snow://project/TEMP.DEV.PROJECT/deployments/DEPLOYMENT$1/",
            "/tmp/download",
            "/tmp/download/scripts",
        ),
        # Same case (no mismatch)
        (
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/manifest.yml",
            "snow://project/DCM_DEMO.PROJECTS.DCM_PROJECT_DEV/deployments/DEPLOYMENT$1/",
            "/tmp/download",
            "/tmp/download",
        ),
    ],
)
def test_get_local_target_path_case_insensitive(
    file_path_str, stage_root_str, target_dir, expected_local_path
):
    """The local download directory is derived despite prefix case differences."""
    listed_file = StagePath.from_stage_str(file_path_str)
    requested_root = StagePath.from_stage_str(stage_root_str)

    local_dir = listed_file.get_local_target_path(Path(target_dir), requested_root)

    assert local_dir == Path(expected_local_path)


@pytest.mark.parametrize(
    "stage_str, expected",
    [
        ("@my_stage", False),
        ("@db.schema.my_stage/path", False),
        ("~", False),
        ("snow://streamlit/db.schema.name/versions/live", True),
        ("snow://project/MY_DB.MY_SCHEMA.MY_PROJECT/deployments/v1/", True),
        ("snow://notebook/schema.name", True),
    ],
)
def test_is_vstage(stage_str, expected):
    """Only ``snow://`` stage strings are reported as vstages."""
    parsed = StagePath.from_stage_str(stage_str)
    reported = parsed.is_vstage()
    assert reported == expected


def test_is_vstage_preserved_through_operations():
    """Paths derived from a vstage path keep reporting themselves as vstages."""
    base = StagePath.from_stage_str(
        "snow://project/DCM_DEMO.PROJECTS.MY_PROJECT/deployments/v1/"
    )
    # The original path plus its child, parent and root variants.
    derived_paths = (base, base / "subdir", base.parent, base.root_path())
    for candidate in derived_paths:
        assert candidate.is_vstage() is True


@pytest.mark.parametrize(
    "file_path_str, stage_root_str, expected_relative",
    [
        # $ in unquoted identifier — case-insensitive
        ("@s/deployment$1/file.sql", "@s/DEPLOYMENT$1/", "file.sql"),
        # Plain alpha case mismatch — case-insensitive
        ("@s/DATA_DIR/file.csv", "@s/data_dir/", "file.csv"),
        # Component with dots — exact match succeeds
        ("@s/My.Dir/file.sql", "@s/My.Dir/", "file.sql"),
    ],
)
def test_relative_to_quoted_identifier_success(
    file_path_str, stage_root_str, expected_relative
):
    """Matching prefixes resolve for unquoted- and quoted-style components."""
    listed_file = StagePath.from_stage_str(file_path_str)
    requested_root = StagePath.from_stage_str(stage_root_str)

    tail = listed_file.relative_to(requested_root)

    assert tail == PurePosixPath(expected_relative)


@pytest.mark.parametrize(
    "file_path_str, stage_root_str",
    [
        # Component with dots and wrong case — exact match required, must raise
        ("@s/My.Dir/file.sql", "@s/my.dir/"),
        # Completely different paths
        ("@s/foo/file.sql", "@s/bar/"),
    ],
)
def test_relative_to_raises_for_non_subpath(file_path_str, stage_root_str):
    """``relative_to`` rejects roots that are not a prefix of the file path."""
    listed_file = StagePath.from_stage_str(file_path_str)
    unrelated_root = StagePath.from_stage_str(stage_root_str)

    expected_message = "is not in the subpath of"
    with pytest.raises(ValueError, match=expected_message):
        listed_file.relative_to(unrelated_root)


def test_local_dir_with_dot_are_identified_as_dir_not_file():
with tempfile.TemporaryDirectory(suffix="dot.in.name") as dir_path:
assert "." in dir_path
Expand Down
Loading