Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mcp_server_snowflake/object_manager/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def parse_object(target_object: Any, obj_type: supported_objects):


def initialize_object_manager_tools(server: FastMCP, snowflake_service):
root = snowflake_service.root
supported_objects_list = list(get_args(supported_objects))
object_type_annotation = Annotated[
supported_objects,
Expand Down Expand Up @@ -221,7 +220,7 @@ def create_object_tool(
):
# If string is passed, parse JSON and create object
target_object = parse_object(target_object, object_type)
return create_object(target_object, root, mode)
return create_object(target_object, snowflake_service.root, mode)

@server.tool(
name="drop_object",
Expand All @@ -233,7 +232,7 @@ def drop_object_tool(
if_exists: bool = False,
):
target_object = parse_object(target_object, object_type)
return drop_object(target_object, root, if_exists)
return drop_object(target_object, snowflake_service.root, if_exists)

@server.tool(
name="create_or_alter_object",
Expand All @@ -244,7 +243,7 @@ def create_or_alter_object_tool(
target_object: target_object_annotation,
):
target_object = parse_object(target_object, object_type)
return create_or_alter_object(target_object, root)
return create_or_alter_object(target_object, snowflake_service.root)

@server.tool(
name="describe_object",
Expand All @@ -255,7 +254,7 @@ def describe_object_tool(
target_object: target_object_annotation,
):
target_object = parse_object(target_object, object_type)
return describe_object(target_object, root)
return describe_object(target_object, snowflake_service.root)

@server.tool(
name="list_objects",
Expand Down
47 changes: 21 additions & 26 deletions mcp_server_snowflake/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
transport: str,
connection_params: dict,
endpoint: str = "/mcp",
lazy_auth: bool = False,
):
if service_config_file is None:
raise ValueError(
Expand Down Expand Up @@ -141,9 +142,10 @@ def __init__(
self._is_spcs_container = is_running_in_spcs_container()

self.unpack_service_specs()
# Persist connection to avoid closing it after each request
self.connection = self._get_persistent_connection()
self.root = Root(self.connection)
self.connection = None
self.root = None
if not lazy_auth:
self._ensure_connection()

def unpack_service_specs(self) -> None:
"""
Expand Down Expand Up @@ -188,6 +190,11 @@ def unpack_service_specs(self) -> None:
logger.error(f"Error extracting service specifications: {e}")
raise

def _ensure_connection(self) -> None:
if self.connection is None:
self.connection = self._get_persistent_connection()
self.root = Root(self.connection)

def get_api_headers(self) -> Dict[str, str]:
"""
Get authentication headers for REST API calls.
Expand All @@ -205,6 +212,7 @@ def get_api_headers(self) -> Dict[str, str]:
}
else:
# For external environments, we need to use the connection token
self._ensure_connection()
return {
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
Expand All @@ -225,6 +233,7 @@ def get_api_host(self) -> str:
"SNOWFLAKE_HOST", self.connection_params.get("account", "")
)
else:
self._ensure_connection()
return self.connection.host

@staticmethod
Expand Down Expand Up @@ -342,29 +351,7 @@ def get_connection(
"""

try:
if self.connection is None:
# Get connection parameters based on environment
if self._is_spcs_container:
logger.info("Using SPCS container OAuth authentication")
connection_params = {
"host": os.getenv("SNOWFLAKE_HOST"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"token": get_spcs_container_token(),
"authenticator": "oauth",
}
connection_params = {
k: v for k, v in connection_params.items() if v is not None
}
else:
logger.info("Using external authentication")
connection_params = self.connection_params.copy()

self.connection = connect(
**connection_params,
session_parameters=session_parameters,
client_session_keep_alive=False,
paramstyle="qmark",
)
self._ensure_connection()

cursor = (
self.connection.cursor(DictCursor)
Expand Down Expand Up @@ -508,6 +495,13 @@ def parse_arguments():
help="Enable verbose/debug logging",
default=False,
)
parser.add_argument(
"--lazy-auth",
action="store_true",
required=False,
default=bool(os.getenv("SNOWFLAKE_MCP_LAZY_AUTH")),
help="Defer Snowflake authentication until first tool use (default: False)",
)

return parser.parse_args()

Expand Down Expand Up @@ -542,6 +536,7 @@ async def create_snowflake_service(
transport=args.transport,
connection_params=connection_params,
endpoint=endpoint or args.endpoint,
lazy_auth=args.lazy_auth,
)

# Initialize tools and resources now that we have the service
Expand Down
144 changes: 144 additions & 0 deletions mcp_server_snowflake/tests/test_lazy_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock, patch

import pytest
import yaml

from mcp_server_snowflake.server import SnowflakeService, parse_arguments


@pytest.fixture
def minimal_config_file(tmp_path):
config = {"search_services": []}
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config))
return config_file


@pytest.fixture
def mock_snowflake_connect():
with (
patch("mcp_server_snowflake.server.connect") as mock_connect,
patch("mcp_server_snowflake.server.Root") as mock_root,
):
mock_connect.return_value = MagicMock()
mock_root.return_value = MagicMock()
yield mock_connect


# --- lazy_auth CLI argument tests ---


def test_parse_arguments_default_lazy_auth(monkeypatch):
monkeypatch.delenv("SNOWFLAKE_MCP_LAZY_AUTH", raising=False)
with patch("sys.argv", ["prog"]):
args = parse_arguments()
assert args.lazy_auth is False


def test_parse_arguments_lazy_auth_flag():
with patch("sys.argv", ["prog", "--lazy-auth"]):
args = parse_arguments()
assert args.lazy_auth is True


def test_parse_arguments_lazy_auth_from_env(monkeypatch):
monkeypatch.setenv("SNOWFLAKE_MCP_LAZY_AUTH", "1")
with patch("sys.argv", ["prog"]):
args = parse_arguments()
assert args.lazy_auth is True


# --- SnowflakeService lazy_auth behavior tests ---


def test_snowflake_service_eager_auth_connects_on_init(
mock_snowflake_connect, minimal_config_file
):
"""Without lazy_auth, connection is established during __init__."""
service = SnowflakeService(
service_config_file=str(minimal_config_file),
transport="stdio",
connection_params={},
lazy_auth=False,
)
mock_snowflake_connect.assert_called_once()
assert service.connection is not None
assert service.root is not None


def test_snowflake_service_lazy_auth_defers_connection(minimal_config_file):
"""With lazy_auth=True, connect is NOT called during __init__."""
with (
patch("mcp_server_snowflake.server.connect") as mock_connect,
patch("mcp_server_snowflake.server.Root") as mock_root,
):
mock_connect.return_value = MagicMock()
mock_root.return_value = MagicMock()

service = SnowflakeService(
service_config_file=str(minimal_config_file),
transport="stdio",
connection_params={},
lazy_auth=True,
)
mock_connect.assert_not_called()
assert service.connection is None
assert service.root is None


def test_snowflake_service_lazy_auth_connects_on_ensure(minimal_config_file):
"""With lazy_auth=True, _ensure_connection() triggers the connection."""
with (
patch("mcp_server_snowflake.server.connect") as mock_connect,
patch("mcp_server_snowflake.server.Root") as mock_root,
):
mock_connection = MagicMock()
mock_root_instance = MagicMock()
mock_connect.return_value = mock_connection
mock_root.return_value = mock_root_instance

service = SnowflakeService(
service_config_file=str(minimal_config_file),
transport="stdio",
connection_params={},
lazy_auth=True,
)
mock_connect.assert_not_called()

service._ensure_connection()

mock_connect.assert_called_once()
assert service.connection is mock_connection
assert service.root is mock_root_instance


def test_snowflake_service_ensure_connection_idempotent(minimal_config_file):
"""_ensure_connection() called twice only creates one connection."""
with (
patch("mcp_server_snowflake.server.connect") as mock_connect,
patch("mcp_server_snowflake.server.Root"),
):
mock_connect.return_value = MagicMock()

service = SnowflakeService(
service_config_file=str(minimal_config_file),
transport="stdio",
connection_params={},
lazy_auth=True,
)
service._ensure_connection()
service._ensure_connection()

mock_connect.assert_called_once()