Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/app/v1/endpoints/create/data_array_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
check_missing_properties,
handle_datetime_fields,
handle_result_field,
build_self_link,
extract_iot_id,
)
from app.v1.endpoints.functions import set_role
from asyncpg.exceptions import InsufficientPrivilegeError
Expand Down
6 changes: 2 additions & 4 deletions api/tests/test_issue7_exception_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@


class TestIssue7ExceptionHandling:
async def test_create_user_unexpected_error_returns_500_without_details(
self,
):
async def test_create_user_unexpected_error_returns_400_without_details(self):
pool = MagicMock()
tx = MagicMock()
tx.__aenter__ = AsyncMock(return_value=None)
Expand All @@ -56,7 +54,7 @@ async def test_create_user_unexpected_error_returns_500_without_details(
pgpool=pool,
)

assert response.status_code == 500
assert response.status_code == 400
assert "details leaked" not in response.body.decode()

async def test_catch_all_get_stream_error_returns_500(self):
Expand Down
22 changes: 12 additions & 10 deletions api/tests/test_oauth_connection_leak.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ async def mock_connect(**kwargs):
)

async def test_connection_always_closed_even_on_exception(self):
"""After fix: Connection is always closed in finally block."""
"""After fix: Connection is always closed when an error happens after acquire."""
close_called = []

async def mock_connect(**kwargs):
Expand All @@ -428,23 +428,25 @@ async def track_close():
mock_conn.close = track_close
return mock_conn

# Mock pool that raises exception
# Connection used inside pool.acquire()
mock_pool_conn = AsyncMock()
mock_pool_conn.fetchrow = AsyncMock(
side_effect=asyncpg.PostgresConnectionError("Connection lost")
)

mock_pool = MagicMock()

@asynccontextmanager
async def mock_acquire():
raise asyncpg.PostgresConnectionError("Connection lost")
yield mock_pool_conn

mock_pool.acquire = mock_acquire
mock_pool.acquire = MagicMock(return_value=mock_acquire())

with patch("asyncpg.connect", side_effect=mock_connect), patch(
"app.oauth.get_pool", AsyncMock(return_value=mock_pool)
):
with patch("asyncpg.connect", side_effect=mock_connect), \
patch("app.oauth.get_pool", AsyncMock(return_value=mock_pool)):

try:
with pytest.raises(asyncpg.PostgresConnectionError):
await oauth.authenticate_user("test", "pass")
except HTTPException:
pass # Expected

# Verify connection was closed despite exception
assert len(close_called) == 1, "Connection was not closed!"
Expand Down
15 changes: 6 additions & 9 deletions api/tests/test_set_role_sql_safety.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
import sys
from contextlib import asynccontextmanager
Expand All @@ -7,14 +6,10 @@

import pytest

pytestmark = pytest.mark.asyncio(loop_scope="function")

# Ensure api/ is on sys.path so 'app' resolves to api/app
API_DIR = str(Path(__file__).resolve().parents[1])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)

# Patch env vars before importing app
os.environ.setdefault("ISTSOS_ADMIN", "admin")
os.environ.setdefault("ISTSOS_ADMIN_PASSWORD", "secret")
os.environ.setdefault("POSTGRES_HOST", "localhost")
Expand All @@ -34,20 +29,22 @@ async def transaction(self):
yield


def test_set_role_quotes_valid_identifier():
@pytest.mark.asyncio
async def test_set_role_quotes_valid_identifier():
conn = DummyConnection()
current_user = {"username": "test_user"}

asyncio.run(set_role(conn, current_user))
await set_role(conn, current_user)

conn.execute.assert_awaited_once_with('SET ROLE "test_user";')


def test_set_role_rejects_invalid_identifier():
@pytest.mark.asyncio
async def test_set_role_rejects_invalid_identifier():
conn = DummyConnection()
current_user = {"username": 'bad"name'}

with pytest.raises(ValueError, match="Invalid role identifier"):
asyncio.run(set_role(conn, current_user))
await set_role(conn, current_user)

conn.execute.assert_not_awaited()