Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyhosts/hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ def __getattr__(self, name: str) -> Optional[Host]:
self._ensure_loaded()
return self.find_one(name)

# Context manager support

def __enter__(self) -> 'Hosts':
"""Enter the context manager."""
return self

def __exit__(self, exc_type: type | None, exc_val: BaseException | None,
exc_tb: object) -> None:
"""Exit the context manager, saving on clean exit."""
if exc_type is None:
self.save()

def __repr__(self) -> str:
"""Developer-friendly representation."""
if self._loaded:
Expand Down
36 changes: 36 additions & 0 deletions test/test_new_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,39 @@ def test_persist_backward_compatibility(self):
self.assertIsNotNone(hosts2.find_one('newhost'))
finally:
temp_path.unlink()

def test_context_manager_saves_on_exit(self):
"""Test that context manager auto-saves on clean exit."""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.hosts') as f:
f.write('127.0.0.1 localhost\n')
temp_path = Path(f.name)

try:
with Hosts(file_path=temp_path) as hosts:
hosts.add(Host(ip_address=ip_address('10.0.0.1'), hostname='ctxhost'))

# Verify it was saved automatically
hosts2 = Hosts(file_path=temp_path)
self.assertIsNotNone(hosts2.find_one('ctxhost'))
finally:
temp_path.unlink()

def test_context_manager_no_save_on_exception(self):
"""Test that context manager does not save when an exception occurs."""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.hosts') as f:
f.write('127.0.0.1 localhost\n')
temp_path = Path(f.name)

try:
try:
with Hosts(file_path=temp_path) as hosts:
hosts.add(Host(ip_address=ip_address('10.0.0.1'), hostname='badhost'))
raise RuntimeError("something went wrong")
except RuntimeError:
pass

# Verify it was NOT saved
hosts2 = Hosts(file_path=temp_path)
self.assertIsNone(hosts2.find_one('badhost'))
finally:
temp_path.unlink()
Loading