"""Validation and rendering for custom row-level-security policy predicates.

User-supplied predicates are never interpolated into SQL verbatim.  They are
tokenized, parsed against a small whitelist grammar and re-emitted from the
parse result, so only constructs the grammar knows about can reach the
database::

    or        := and ("OR" and)*
    and       := not ("AND" not)*
    not       := "NOT" not | predicate
    predicate := primary [ "IS" ["NOT"] (NULL | TRUE | FALSE)
                         | ["NOT"] "IN" "(" primary ("," primary)* ")"
                         | ["NOT"] ("LIKE" | "ILIKE") primary
                         | OP primary ]
    primary   := "(" or ")" | IDENT | STRING | NUMBER | TRUE | FALSE | NULL
"""

from __future__ import annotations

import re
from dataclasses import dataclass


_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
_NUMBER_RE = re.compile(r"-?(?:\d+\.\d+|\d+)")

# Reserved words of the grammar, stored upper-case; the tokenizer upper-cases
# bare words before looking them up, so input is matched case-insensitively.
_KEYWORDS = frozenset(
    {
        "AND",
        "FALSE",
        "ILIKE",
        "IN",
        "IS",
        "LIKE",
        "NOT",
        "NULL",
        "OR",
        "TRUE",
    }
)

# Keywords that may be used as literal values.
_CONSTANT_KEYWORDS = frozenset({"TRUE", "FALSE", "NULL"})

# Two-character comparison operators; checked before the one-character ones.
_TWO_CHAR_OPERATORS = ("<=", ">=", "<>", "!=")


@dataclass(frozen=True)
class _Token:
    """A lexical token.

    ``kind`` is one of ``PUNCT``, ``OP``, ``STRING``, ``IDENT``, ``KEYWORD``,
    ``NUMBER`` or ``EOF``; ``value`` is the token text (unquoted and unescaped
    for ``STRING`` and quoted ``IDENT`` tokens).
    """

    kind: str
    value: str


class _PolicyExpressionParser:
    """Recursive-descent parser that re-renders a token stream as SQL.

    Each ``_parse_*`` method consumes tokens and returns the normalized SQL
    text for the construct it recognised.  Any token sequence outside the
    grammar raises ``ValueError``.
    """

    def __init__(self, tokens: list[_Token]):
        # The token list is expected to end with an EOF token (as produced by
        # ``_tokenize``); the parser never advances past it.
        self._tokens = tokens
        self._pos = 0

    def parse(self) -> str:
        """Parse the whole token stream and return the rendered predicate.

        Raises:
            ValueError: if the tokens do not form exactly one expression.
        """
        expression = self._parse_or()
        if self._peek().kind != "EOF":
            raise ValueError("Unexpected token in policy expression")
        return expression

    def _parse_or(self) -> str:
        """Parse a left-associative chain of ``OR``; each pair is wrapped in parens."""
        expression = self._parse_and()
        while self._match_keyword("OR"):
            rhs = self._parse_and()
            expression = f"({expression} OR {rhs})"
        return expression

    def _parse_and(self) -> str:
        """Parse a left-associative chain of ``AND``; each pair is wrapped in parens."""
        expression = self._parse_not()
        while self._match_keyword("AND"):
            rhs = self._parse_not()
            expression = f"({expression} AND {rhs})"
        return expression

    def _parse_not(self) -> str:
        """Parse an optional (possibly repeated) prefix ``NOT``."""
        if self._match_keyword("NOT"):
            return f"(NOT {self._parse_not()})"
        return self._parse_predicate()

    def _parse_predicate(self) -> str:
        """Parse one operand followed by at most one test or comparison."""
        left = self._parse_primary()

        if self._match_keyword("IS"):
            negation = " NOT" if self._match_keyword("NOT") else ""
            constant = self._consume_constant(_CONSTANT_KEYWORDS)
            return f"{left} IS{negation} {constant}"

        # An infix NOT is only valid directly in front of IN / LIKE / ILIKE.
        negated = self._match_keyword("NOT")
        prefix = "NOT " if negated else ""

        if self._match_keyword("IN"):
            values = self._parse_value_list()
            return f"{left} {prefix}IN ({', '.join(values)})"

        for keyword in ("LIKE", "ILIKE"):
            if self._match_keyword(keyword):
                return f"{left} {prefix}{keyword} {self._parse_primary()}"

        if negated:
            raise ValueError("Expected IN, LIKE, or ILIKE after NOT")

        if self._peek().kind == "OP":
            operator = self._advance().value
            right = self._parse_primary()
            return f"{left} {operator} {right}"

        return left

    def _parse_primary(self) -> str:
        """Parse a parenthesized expression, identifier, or literal."""
        if self._match("("):
            expression = self._parse_or()
            self._expect(")")
            return f"({expression})"

        token = self._peek()

        if token.kind == "IDENT":
            return self._quote_identifier(self._advance().value)

        if token.kind == "KEYWORD":
            if token.value in _CONSTANT_KEYWORDS:
                return self._advance().value
            raise ValueError("Unexpected keyword in policy expression")

        if token.kind == "STRING":
            # NOTE(review): doubling single quotes is only a complete escape
            # when the server treats backslashes literally, i.e. with
            # standard_conforming_strings = on (the PostgreSQL default) --
            # confirm the deployment does not turn it off.
            text = self._advance().value
            return "'" + text.replace("'", "''") + "'"

        if token.kind == "NUMBER":
            return self._advance().value

        raise ValueError("Unexpected token in policy expression")

    def _parse_value_list(self) -> list[str]:
        """Parse ``( primary [, primary]* )`` and return the rendered items."""
        self._expect("(")
        values = [self._parse_primary()]
        while self._match(","):
            values.append(self._parse_primary())
        self._expect(")")
        return values

    def _consume_constant(self, allowed: set[str]) -> str:
        """Consume and return a keyword from ``allowed`` or raise ``ValueError``."""
        token = self._peek()
        if token.kind == "KEYWORD" and token.value in allowed:
            return self._advance().value
        raise ValueError("Expected constant in policy expression")

    def _peek(self) -> _Token:
        """Return the current token without consuming it."""
        return self._tokens[self._pos]

    def _advance(self) -> _Token:
        """Consume and return the current token."""
        token = self._peek()
        self._pos += 1
        return token

    def _match(self, value: str) -> bool:
        """Consume the current token if it is the punctuation mark ``value``.

        The token kind is checked as well as its text: a string literal such
        as ``'('`` or a quoted identifier ``","`` carries the same ``value``
        as the punctuation mark but must not be treated as structure.
        """
        token = self._peek()
        if token.kind == "PUNCT" and token.value == value:
            self._advance()
            return True
        return False

    def _match_keyword(self, value: str) -> bool:
        """Consume the current token if it is the keyword ``value``."""
        token = self._peek()
        if token.kind == "KEYWORD" and token.value == value:
            self._advance()
            return True
        return False

    def _expect(self, value: str) -> None:
        """Consume the punctuation mark ``value`` or raise ``ValueError``."""
        if not self._match(value):
            raise ValueError("Malformed policy expression")

    @staticmethod
    def _quote_identifier(value: str) -> str:
        """Return ``value`` as a double-quoted SQL identifier."""
        return '"' + value.replace('"', '""') + '"'


def render_policy_expression(value: str) -> str:
    """Return a normalized SQL predicate from a constrained expression grammar.

    Args:
        value: the user-supplied predicate text.

    Returns:
        The predicate re-rendered from its parse: identifiers double-quoted,
        keywords upper-cased, string literals re-escaped.

    Raises:
        ValueError: if ``value`` is not a string, is blank, contains a NUL
            character, or does not conform to the grammar.
    """
    if not isinstance(value, str):
        raise ValueError("Policy expression must be a string")

    expression = value.strip()
    if expression == "":
        raise ValueError("Policy expression must not be empty")

    # The tokenizer only rejects NUL outside quotes; refuse it everywhere, as
    # the regex-based validator this module replaces did.
    if "\x00" in expression:
        raise ValueError("Unsafe policy expression")

    tokens = _tokenize(expression)
    return _PolicyExpressionParser(tokens).parse()


def _tokenize(expression: str) -> list[_Token]:
    """Split ``expression`` into tokens, terminated by an ``EOF`` token.

    Raises:
        ValueError: on any character that cannot start a known token, or on
            an unterminated / empty quoted item.
    """
    tokens: list[_Token] = []
    pos = 0

    while pos < len(expression):
        char = expression[pos]

        if char.isspace():
            pos += 1
            continue

        if char in "(),":
            tokens.append(_Token("PUNCT", char))
            pos += 1
            continue

        # Two-character operators must be tried before "=", "<" and ">".
        if expression.startswith(_TWO_CHAR_OPERATORS, pos):
            tokens.append(_Token("OP", expression[pos : pos + 2]))
            pos += 2
            continue

        if char in "=<>":
            tokens.append(_Token("OP", char))
            pos += 1
            continue

        if char == "'":
            text, pos = _read_string(expression, pos)
            tokens.append(_Token("STRING", text))
            continue

        if char == '"':
            name, pos = _read_quoted_identifier(expression, pos)
            tokens.append(_Token("IDENT", name))
            continue

        identifier = _IDENTIFIER_RE.match(expression, pos)
        if identifier is not None:
            word = identifier.group(0)
            upper_word = word.upper()
            if upper_word in _KEYWORDS:
                tokens.append(_Token("KEYWORD", upper_word))
            else:
                tokens.append(_Token("IDENT", word))
            pos = identifier.end()
            continue

        number = _NUMBER_RE.match(expression, pos)
        if number is not None:
            tokens.append(_Token("NUMBER", number.group(0)))
            pos = number.end()
            continue

        raise ValueError("Unsafe policy expression")

    tokens.append(_Token("EOF", "EOF"))
    return tokens


def _read_string(expression: str, pos: int) -> tuple[str, int]:
    """Read a single-quoted literal starting at ``expression[pos]``.

    A doubled ``''`` inside the literal stands for one quote character.

    Returns:
        ``(text, next_pos)`` where ``text`` is the unescaped content and
        ``next_pos`` is the index just past the closing quote.

    Raises:
        ValueError: if the closing quote is missing.
    """
    chars: list[str] = []
    pos += 1  # skip the opening quote

    while pos < len(expression):
        char = expression[pos]
        if char == "'":
            if pos + 1 < len(expression) and expression[pos + 1] == "'":
                chars.append("'")
                pos += 2
                continue
            return "".join(chars), pos + 1
        chars.append(char)
        pos += 1

    raise ValueError("Unterminated string in policy expression")


def _read_quoted_identifier(expression: str, pos: int) -> tuple[str, int]:
    """Read a double-quoted identifier starting at ``expression[pos]``.

    A doubled ``""`` inside the identifier stands for one quote character.

    Returns:
        ``(name, next_pos)`` where ``name`` is the unescaped identifier and
        ``next_pos`` is the index just past the closing quote.

    Raises:
        ValueError: if the identifier is empty or the closing quote is missing.
    """
    chars: list[str] = []
    pos += 1  # skip the opening quote

    while pos < len(expression):
        char = expression[pos]
        if char == '"':
            if pos + 1 < len(expression) and expression[pos + 1] == '"':
                chars.append('"')
                pos += 2
                continue
            if not chars:
                raise ValueError("Invalid policy expression identifier")
            return "".join(chars), pos + 1
        chars.append(char)
        pos += 1

    raise ValueError("Unterminated identifier in policy expression")
-import re - from app import POSTGRES_PORT_WRITE from app.db.asyncpg_db import get_pool, get_pool_w from app.oauth import get_current_user +from app.utils.policy_expression import render_policy_expression from app.v1.endpoints.functions import set_role from asyncpg.exceptions import DuplicateObjectError, InsufficientPrivilegeError from fastapi import APIRouter, Body, Depends, status @@ -24,7 +23,6 @@ v1 = APIRouter() -_UNSAFE_POLICY_TOKENS_RE = re.compile(r";|--|/\*|\*/|\x00") _VALID_OPERATION_KEYS = {"select", "insert", "update", "delete"} PAYLOAD_EXAMPLE = { @@ -177,16 +175,6 @@ def quote_identifier(value: str) -> str: raise ValueError("Invalid SQL identifier") return '"' + value.replace('"', '""') + '"' - def validate_policy_expression(value: str) -> str: - if not isinstance(value, str): - raise ValueError("Policy condition must be a string") - expression = value.strip() - if expression == "": - raise ValueError("Policy condition must not be empty") - if _UNSAFE_POLICY_TOKENS_RE.search(expression): - raise ValueError("Unsafe policy condition") - return expression - table_mapping = { "location": "Location", "thing": "Thing", @@ -224,7 +212,7 @@ def validate_policy_expression(value: str) -> str: safe_name = quote_identifier( f"{name}_{table.lower()}_{operation_lc}" ) - safe_condition = validate_policy_expression(condition) + safe_condition = render_policy_expression(condition) if operation_lc in {"select", "delete"}: query = f""" @@ -252,4 +240,4 @@ def validate_policy_expression(value: str) -> str: WITH CHECK ({safe_condition}); """ - await connection.execute(query) \ No newline at end of file + await connection.execute(query) diff --git a/api/app/v1/endpoints/update/policy.py b/api/app/v1/endpoints/update/policy.py index dde93de..1111bda 100644 --- a/api/app/v1/endpoints/update/policy.py +++ b/api/app/v1/endpoints/update/policy.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re - from app import POSTGRES_PORT_WRITE from app.db.asyncpg_db import get_pool, get_pool_w from app.oauth import get_current_user +from app.utils.policy_expression import render_policy_expression from app.utils.utils import pg_quote_ident, validate_payload_keys from app.v1.endpoints.functions import set_role from asyncpg.exceptions import InsufficientPrivilegeError, UndefinedObjectError @@ -32,23 +31,6 @@ "policy", ] -_UNSAFE_POLICY_TOKENS_RE = re.compile(r";|--|/\*|\*/|\x00") - - -def _validate_policy_expression(value: str) -> str: - if not isinstance(value, str): - raise ValueError("Policy expression must be a string") - - expression = value.strip() - if expression == "": - raise ValueError("Policy expression must not be empty") - - if _UNSAFE_POLICY_TOKENS_RE.search(expression): - raise ValueError("Unsafe policy expression") - - return expression - - @v1.api_route( "/Policies", methods=["PATCH"], @@ -98,7 +80,7 @@ async def update_policy( tablename, cmd = row["tablename"], row["cmd"] if payload.get("policy") is not None: - policy_expression = _validate_policy_expression( + safe_condition = render_policy_expression( payload["policy"] ) policy_ident = pg_quote_ident(policy) @@ -106,11 +88,11 @@ async def update_policy( cmd_upper = (cmd or "").upper() policy_sql = { - "SELECT": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} USING ({policy_expression});", - "INSERT": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} WITH CHECK ({policy_expression});", - "UPDATE": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} USING ({policy_expression}) WITH CHECK ({policy_expression});", - "DELETE": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} USING ({policy_expression});", - "ALL": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} USING ({policy_expression}) WITH CHECK ({policy_expression});", + "SELECT": f"ALTER POLICY {policy_ident} ON sensorthings.{table_ident} USING ({safe_condition})", + "INSERT": f"ALTER 
"""Regression tests for custom policy expression handling."""

import asyncio
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from unittest.mock import AsyncMock

import pytest


# Make the ``app`` package importable when the tests are run from any cwd.
API_DIR = str(Path(__file__).resolve().parents[1])
if API_DIR not in sys.path:
    sys.path.insert(0, API_DIR)

# Set before the app imports below so they see a SECRET_KEY in the environment.
os.environ.setdefault("SECRET_KEY", "test_secret_key")

from app.utils.policy_expression import render_policy_expression  # noqa: E402
import app.v1.endpoints.create.policy as create_policy_endpoint  # noqa: E402
import app.v1.endpoints.update.policy as update_policy_endpoint  # noqa: E402


def test_render_policy_expression_normalizes_simple_condition():
    """Identifiers come back quoted and keywords upper-cased."""
    rendered = render_policy_expression(
        "network = 'IDROLOGIA' and public = true"
    )
    assert rendered == '("network" = \'IDROLOGIA\' AND "public" = TRUE)'


def test_render_policy_expression_keeps_quoted_identifier_case():
    """A quoted mixed-case identifier is emitted unchanged."""
    rendered = render_policy_expression('"unitOfMeasurement" is not null')
    assert rendered == '"unitOfMeasurement" IS NOT NULL'


@pytest.mark.parametrize(
    "expression",
    [
        "true) WITH CHECK (false",
        "network = 'x'; DROP TABLE sensorthings.\"User\"",
        "exists (select 1)",
        "network = 'unterminated",
    ],
)
def test_render_policy_expression_rejects_sql_structure(expression):
    """Anything outside the predicate grammar raises ValueError."""
    with pytest.raises(ValueError):
        render_policy_expression(expression)


def _mock_pgpool(connection):
    """Return a stand-in pool whose ``acquire()`` yields ``connection``."""

    @asynccontextmanager
    async def _acquire():
        yield connection

    class _Pool:
        # staticmethod: ``pool.acquire()`` calls ``_acquire()`` with no args.
        acquire = staticmethod(_acquire)

    return _Pool()


def _attach_transaction_cm(connection):
    """Give ``connection`` a no-op async ``transaction()`` context manager."""

    @asynccontextmanager
    async def _transaction():
        yield

    setattr(connection, "transaction", _transaction)


def _datastream_select_connection():
    """Build a mocked connection reporting a SELECT policy on Datastream."""
    connection = AsyncMock()
    connection.execute = AsyncMock()
    connection.fetchrow = AsyncMock(
        return_value={"tablename": "Datastream", "cmd": "SELECT"}
    )
    _attach_transaction_cm(connection)
    return connection


def _patch_policy(connection, policy_text):
    """Run ``update_policy`` for policy ``p1`` as an administrator."""
    coroutine = update_policy_endpoint.update_policy(
        policy="p1",
        payload={"policy": policy_text},
        current_user={"username": "admin_user", "role": "administrator"},
        pgpool=_mock_pgpool(connection),
    )
    return asyncio.run(coroutine)


def _executed_sql(connection):
    """Return the first positional argument of every awaited ``execute``."""
    return [call.args[0] for call in connection.execute.await_args_list]


def test_update_policy_uses_normalized_policy_condition():
    """The ALTER POLICY statement carries the re-rendered predicate."""
    connection = _datastream_select_connection()

    response = _patch_policy(connection, "network = 'IDROLOGIA'")

    expected_sql = (
        'ALTER POLICY "p1" ON sensorthings."Datastream" '
        "USING (\"network\" = 'IDROLOGIA')"
    )
    assert response.status_code == 200
    assert expected_sql in _executed_sql(connection)


def test_update_policy_rejects_unrenderable_policy_condition():
    """An unparseable predicate yields 400 and no ALTER POLICY is executed."""
    connection = _datastream_select_connection()

    response = _patch_policy(connection, "true) WITH CHECK (false")

    assert response.status_code == 400
    assert not any(
        sql.startswith("ALTER POLICY") for sql in _executed_sql(connection)
    )


def test_create_policies_rejects_unrenderable_policy_condition():
    """``create_policies`` raises before executing any SQL."""
    connection = AsyncMock()
    connection.execute = AsyncMock()

    requested_policies = {
        "datastream": {
            "select": "true) TO PUBLIC USING (true",
        }
    }

    with pytest.raises(ValueError):
        asyncio.run(
            create_policy_endpoint.create_policies(
                connection=connection,
                users=["alice"],
                policies=requested_policies,
                name="rbac_test",
            )
        )

    connection.execute.assert_not_awaited()