Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/snowflake/connector/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import abc
import collections
import contextlib
import functools
import itertools
Expand Down Expand Up @@ -172,6 +171,24 @@ def get_adapter(self, **override_adapter_factory_kwargs) -> HTTPAdapter:
return self.adapter_factory(**self_kwargs_for_adapter_factory)


class _SessionsMap(dict):
"""A dict subclass that auto-creates SessionPool on missing key access.

Unlike defaultdict with a lambda, this avoids creating a reference cycle
between the SessionManager and the factory function closure, preventing
memory leaks (see GitHub issue #2727).
"""

def __init__(self, manager: SessionManager) -> None:
super().__init__()
self._manager = manager

def __missing__(self, key):
pool = SessionPool(self._manager)
self[key] = pool
return pool


class SessionPool(Generic[SessionT]):
"""
Component responsible for storing and reusing established session instances.
Expand Down Expand Up @@ -408,9 +425,7 @@ def __init__(self, config: HttpConfig | None = None, **http_config_kwargs) -> No
config = HttpConfig(**http_config_kwargs)
self._cfg: HttpConfig = config
# Maps hostname to SessionPool instance for its connections
self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict(
lambda: SessionPool(self)
)
self._sessions_map: dict[str | None, SessionPool] = _SessionsMap(self)

@classmethod
def from_config(cls, cfg: HttpConfig, **overrides: Any) -> SessionManager:
Expand Down Expand Up @@ -575,16 +590,16 @@ def clone(

def __getstate__(self):
state = self.__dict__.copy()
# `_sessions_map` contains a defaultdict with a lambda referencing `self`,
# which is not pickle-able. Convert to a regular dict for serialization.
# `_sessions_map` contains a _SessionsMap referencing `self`,
# which is not directly pickle-able. Convert to a regular dict for serialization.
state["_sessions_map_items"] = list(state.pop("_sessions_map").items())
return state

def __setstate__(self, state):
# Restore attributes except sessions_map
sessions_items = state.pop("_sessions_map_items", [])
self.__dict__.update(state)
self._sessions_map = collections.defaultdict(lambda: SessionPool(self))
self._sessions_map = _SessionsMap(self)
for host, pool in sessions_items:
self._sessions_map[host] = pool

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_context_var_weakref_does_not_leak():
sm = SessionManager(passed_config)
token = set_current_session_manager(sm)

# The context var should return the same object while its alive
# The context var should return the same object while it's alive
assert (
get_current_session_manager(create_default_if_missing=False).config
== passed_config
Expand Down
Loading