Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes
- v4.4.0(March 24,2026)
- Added ECDSA key support (ES256, ES384, ES512) for key-pair authentication.
- Bump the lower boundary of cryptography to 46.0.5 due to CVE-2026-26007.
- Added support for Python 3.14.
- Removed pyOpenSSL upper bound dependency constraint to allow installation of pyOpenSSL 26.0.0+, which includes a fix for GHSA-vp96-hxj8-p424.
Expand Down
51 changes: 40 additions & 11 deletions src/snowflake/connector/auth/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import (
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurvePrivateKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import (
Encoding,
Expand All @@ -28,7 +34,11 @@
class AuthByKeyPair(AuthByPlugin):
"""Key pair based authentication."""

ALGORITHM = "RS256"
ALG_RS256 = "RS256"
ALG_ES256 = "ES256"
ALG_ES384 = "ES384"
ALG_ES512 = "ES512"

ISSUER = "iss"
SUBJECT = "sub"
EXPIRE_TIME = "exp"
Expand All @@ -39,7 +49,7 @@ class AuthByKeyPair(AuthByPlugin):

def __init__(
self,
private_key: bytes | str | RSAPrivateKey,
private_key: bytes | str | RSAPrivateKey | EllipticCurvePrivateKey,
private_key_passphrase: bytes | None = None,
lifetime_in_seconds: int = LIFETIME,
**kwargs,
Expand All @@ -48,7 +58,7 @@ def __init__(

Args:
private_key: a byte array of der formats of private key, or an
object that implements the `RSAPrivateKey` interface.
object that implements the `RSAPrivateKey` or `EllipticCurvePrivateKey` interface.
lifetime_in_seconds: number of seconds the JWT token will be valid
"""
super().__init__(
Expand All @@ -72,7 +82,9 @@ def __init__(
).total_seconds()
)

self._private_key: bytes | str | RSAPrivateKey | None = private_key
self._private_key: (
bytes | str | RSAPrivateKey | EllipticCurvePrivateKey | None
) = private_key
self._private_key_passphrase: bytes | None = private_key_passphrase
self._jwt_token = ""
self._jwt_token_exp = 0
Expand Down Expand Up @@ -109,7 +121,7 @@ def prepare(
except Exception as e:
raise ProgrammingError(
msg=f"Failed to decode private key: {e}\nPlease provide a valid "
"unencrypted rsa private key in base64-encoded DER format as a "
"unencrypted RSA or ECDSA private key in base64-encoded DER format as a "
"str object",
errno=ER_INVALID_PRIVATE_KEY,
)
Expand All @@ -124,23 +136,23 @@ def prepare(
except Exception as e:
raise ProgrammingError(
msg=f"Failed to load private key: {e}\nPlease provide a valid "
"rsa private key in DER format as bytes object. If the key is "
"RSA or ECDSA private key in DER format as bytes object. If the key is "
"encrypted, provide the passphrase via private_key_passphrase",
errno=ER_INVALID_PRIVATE_KEY,
)

if not isinstance(private_key, RSAPrivateKey):
if not isinstance(private_key, (RSAPrivateKey, EllipticCurvePrivateKey)):
raise ProgrammingError(
msg=f"Private key type ({private_key.__class__.__name__}) not supported."
"\nPlease provide a valid rsa private key in DER format as bytes "
"\nPlease provide a valid RSA or ECDSA private key in DER format as bytes "
"object",
errno=ER_INVALID_PRIVATE_KEY,
)
elif isinstance(self._private_key, RSAPrivateKey):
elif isinstance(self._private_key, (RSAPrivateKey, EllipticCurvePrivateKey)):
private_key = self._private_key
else:
raise TypeError(
f"Expected bytes or RSAPrivateKey, got {type(self._private_key)}"
f"Expected bytes, RSAPrivateKey, or EllipticCurvePrivateKey, got {type(self._private_key)}"
)

public_key_fp = self.calculate_public_key_fingerprint(private_key)
Expand All @@ -153,7 +165,24 @@ def prepare(
self.EXPIRE_TIME: self._jwt_token_exp,
}

_jwt_token = jwt.encode(payload, private_key, algorithm=self.ALGORITHM)
# select algorithm based on key type and curve
if isinstance(private_key, EllipticCurvePrivateKey):
curve = private_key.curve
if isinstance(curve, SECP256R1):
algorithm = self.ALG_ES256
elif isinstance(curve, SECP384R1):
algorithm = self.ALG_ES384
elif isinstance(curve, SECP521R1):
algorithm = self.ALG_ES512
else:
raise ProgrammingError(
msg=f"Unsupported EC curve: {curve.name}. Supported: SECP256R1, SECP384R1, SECP521R1",
errno=ER_INVALID_PRIVATE_KEY,
)
else:
algorithm = self.ALG_RS256

_jwt_token = jwt.encode(payload, private_key, algorithm=algorithm)

# jwt.encode() returns bytes in pyjwt 1.x and a string
# in pyjwt 2.x
Expand Down
161 changes: 152 additions & 9 deletions test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import load_der_private_key
from pytest import raises
Expand Down Expand Up @@ -41,7 +42,7 @@ def _mock_auth_key_pair_rest_response(url, headers, body, **kwargs):
@pytest.mark.parametrize("authenticator", ["SNOWFLAKE_JWT", "snowflake_jwt"])
def test_auth_keypair(authenticator):
"""Simple Key Pair test."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
private_key_der, public_key_der_encoded = generate_rsa_key_pair(2048)
application = "testapplication"
account = "testaccount"
user = "testuser"
Expand All @@ -68,7 +69,7 @@ def test_auth_keypair_with_passphrase():
"""Simple Key Pair test with passphrase."""

passphrase = b"test"
private_key_der, public_key_der_encoded = generate_key_pair(
private_key_der, public_key_der_encoded = generate_rsa_key_pair(
2048,
passphrase=passphrase,
)
Expand Down Expand Up @@ -102,7 +103,7 @@ def test_auth_keypair_encrypted_without_passphrase():
from snowflake.connector.errors import ProgrammingError

passphrase = b"test"
private_key_der, _ = generate_key_pair(
private_key_der, _ = generate_rsa_key_pair(
2048,
passphrase=passphrase,
)
Expand All @@ -124,7 +125,7 @@ def test_auth_keypair_wrong_passphrase():
from snowflake.connector.errors import ProgrammingError

passphrase = b"correct_passphrase"
private_key_der, _ = generate_key_pair(
private_key_der, _ = generate_rsa_key_pair(
2048,
passphrase=passphrase,
)
Expand All @@ -145,7 +146,7 @@ def test_auth_keypair_wrong_passphrase():


def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
private_key_der, _ = generate_key_pair(2048)
private_key_der, _ = generate_rsa_key_pair(2048)
auth_class = AuthByKeyPair(private_key=private_key_der)

req_body_before = create_mock_auth_body()
Expand All @@ -162,7 +163,7 @@ def test_auth_prepare_body_does_not_overwrite_client_environment_fields():

def test_auth_keypair_abc():
"""Simple Key Pair test using abstraction layer."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
private_key_der, public_key_der_encoded = generate_rsa_key_pair(2048)
application = "testapplication"
account = "testaccount"
user = "testuser"
Expand Down Expand Up @@ -211,7 +212,7 @@ class Bad:

@patch("snowflake.connector.auth.keypair.AuthByKeyPair.prepare")
def test_renew_token(mockPrepare):
private_key_der, _ = generate_key_pair(2048)
private_key_der, _ = generate_rsa_key_pair(2048)
auth_instance = AuthByKeyPair(private_key=private_key_der)

# force renew condition to be met
Expand Down Expand Up @@ -249,7 +250,7 @@ def _init_rest(application, post_requset):
return rest


def generate_key_pair(key_length: int, *, passphrase: bytes | None = None):
def generate_rsa_key_pair(key_length: int, *, passphrase: bytes | None = None):
private_key = rsa.generate_private_key(
backend=default_backend(), public_exponent=65537, key_size=key_length
)
Expand Down Expand Up @@ -278,6 +279,148 @@ def generate_key_pair(key_length: int, *, passphrase: bytes | None = None):
return private_key_der, public_key_der_encoded


def generate_ec_key_pair(
curve: ec.EllipticCurve,
*,
passphrase: bytes | None = None,
):
private_key = ec.generate_private_key(curve, default_backend())

private_key_der = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=(
serialization.BestAvailableEncryption(passphrase)
if passphrase
else serialization.NoEncryption()
),
)

public_key_pem = (
private_key.public_key()
.public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
)
.decode("utf-8")
)

# strip off header
public_key_der_encoded = "".join(public_key_pem.split("\n")[1:-2])

return private_key_der, public_key_der_encoded


@pytest.mark.skipolddriver
@pytest.mark.parametrize(
"curve",
[ec.SECP256R1(), ec.SECP384R1(), ec.SECP521R1()],
ids=["P-256", "P-384", "P-521"],
)
def test_auth_keypair_ecdsa(curve):
"""ECDSA Key Pair test for supported curves."""
private_key_der, _ = generate_ec_key_pair(curve)
application = "testapplication"
account = "testaccount"
user = "testuser"
auth_instance = AuthByKeyPair(private_key=private_key_der)
auth_instance._retry_ctx.set_start_time()
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


@pytest.mark.skipolddriver
def test_auth_keypair_ecdsa_with_passphrase():
"""ECDSA Key Pair test with passphrase."""
passphrase = b"test"
private_key_der, _ = generate_ec_key_pair(ec.SECP256R1(), passphrase=passphrase)
application = "testapplication"
account = "testaccount"
user = "testuser"
auth_instance = AuthByKeyPair(
private_key=private_key_der,
private_key_passphrase=passphrase,
)
auth_instance._retry_ctx.set_start_time()
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


@pytest.mark.skipolddriver
def test_auth_keypair_ecdsa_abc():
"""ECDSA Key Pair test using abstraction layer."""
private_key_der, _ = generate_ec_key_pair(ec.SECP256R1())
application = "testapplication"
account = "testaccount"
user = "testuser"

private_key = load_der_private_key(
data=private_key_der,
password=None,
backend=default_backend(),
)

assert isinstance(private_key, EllipticCurvePrivateKey)

auth_instance = AuthByKeyPair(private_key=private_key)
auth_instance._retry_ctx.set_start_time()
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


@pytest.mark.skipolddriver
def test_auth_keypair_ecdsa_unsupported_curve():
"""Test that unsupported EC curve raises error."""
from snowflake.connector.errors import ProgrammingError

# SECP192R1 is not supported
private_key_der, _ = generate_ec_key_pair(ec.SECP192R1())
account = "testaccount"
user = "testuser"

auth_instance = AuthByKeyPair(private_key=private_key_der)

with raises(ProgrammingError) as ex:
auth_instance.prepare(account=account, user=user)

assert "Unsupported EC curve" in str(ex.value)


@pytest.mark.skipolddriver
def test_expand_tilde(monkeypatch, tmp_path):
"""Test tilde expansion on both Windows and Linux/Mac"""
Expand Down
Loading