diff --git a/src/sparqlx/__init__.py b/src/sparqlx/__init__.py index a321a20..0a8719d 100644 --- a/src/sparqlx/__init__.py +++ b/src/sparqlx/__init__.py @@ -8,11 +8,17 @@ SPARQLQueryTypeLiteral, SelectQuery, ) -from sparqlx.utils.utils import SPARQLParseException +from sparqlx.utils.utils import ( + QueryParseException, + SPARQLParseException, + UpdateParseException, +) __all__ = ( "SPARQLWrapper", "SPARQLParseException", + "QueryParseException", + "UpdateParseException", "AskQuery", "ConstructQuery", "DescribeQuery", diff --git a/src/sparqlx/sparqlwrapper.py b/src/sparqlx/sparqlwrapper.py index 66e4b11..a7033e5 100644 --- a/src/sparqlx/sparqlwrapper.py +++ b/src/sparqlx/sparqlwrapper.py @@ -24,7 +24,11 @@ QueryOperationParameters, UpdateOperationParameters, ) -from sparqlx.utils.utils import _get_query_type, _get_response_converter +from sparqlx.utils.utils import ( + _get_query_type, + _get_response_converter, + _parse_udpate_request, +) class SPARQLWrapper(AbstractContextManager, AbstractAsyncContextManager): @@ -42,10 +46,13 @@ def __init__( client_config: dict | None = None, aclient: httpx.AsyncClient | None = None, aclient_config: dict | None = None, + parse: bool = True, ) -> None: self.sparql_endpoint = sparql_endpoint self.update_endpoint = update_endpoint + self.parse = parse + self._client_manager = ClientManager( client=client, client_config=client_config, @@ -76,6 +83,7 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> list[SPARQLResultBinding]: ... @overload @@ -87,6 +95,7 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> bool: ... @overload @@ -98,6 +107,7 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Graph: ... @overload @@ -109,6 +119,7 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> list[SPARQLResultBinding] | Graph | bool: ... @overload @@ -120,6 +131,7 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response: ... def query( @@ -130,8 +142,10 @@ def query( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response | list[SPARQLResultBinding] | Graph | bool: - query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query) + _parse: bool = self.parse if parse is None else parse + query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse) params = QueryOperationParameters( query=query, @@ -169,6 +183,7 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> list[SPARQLResultBinding]: ... @overload @@ -180,6 +195,7 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> bool: ... @overload @@ -191,6 +207,7 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Graph: ... @overload @@ -202,6 +219,7 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> list[SPARQLResultBinding] | Graph | bool: ... @overload @@ -213,6 +231,7 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response: ... async def aquery( @@ -223,8 +242,10 @@ async def aquery( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response | list[SPARQLResultBinding] | Graph | bool: - query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query) + _parse: bool = self.parse if parse is None else parse + query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse) params = QueryOperationParameters( query=query, @@ -264,8 +285,10 @@ def query_stream[T]( [httpx.Response], Iterator[T] ] = httpx.Response.iter_bytes, chunk_size: int | None = None, + parse: bool | None = None, ) -> Iterator[T]: - query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query) + _parse: bool = self.parse if parse is None else parse + query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse) params = QueryOperationParameters( query=query, @@ -305,8 +328,10 @@ async def aquery_stream[T]( [httpx.Response], AsyncIterator[T] ] = httpx.Response.aiter_bytes, chunk_size: int | None = None, + parse: bool | None = None, ) -> AsyncIterator[T]: - query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query) + _parse: bool = self.parse if parse is None else parse + query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse) params = QueryOperationParameters( query=query, @@ -344,6 +369,7 @@ def queries( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Iterator[list[SPARQLResultBinding] | Graph | bool]: ... @overload @@ -355,6 +381,7 @@ def queries( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Iterator[httpx.Response]: ... def queries( @@ -365,9 +392,14 @@ def queries( version: str | None = None, default_graph_uri: RequestDataValue = None, named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Iterator[httpx.Response | list[SPARQLResultBinding] | Graph | bool]: + _parse: bool = self.parse if parse is None else parse + query_component = SPARQLWrapper( - sparql_endpoint=self.sparql_endpoint, aclient=self._client_manager.aclient + sparql_endpoint=self.sparql_endpoint, + aclient=self._client_manager.aclient, + parse=_parse, ) async def _runner() -> Iterator[httpx.Response]: @@ -397,7 +429,13 @@ def update( version: str | None = None, using_graph_uri: RequestDataValue = None, using_named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response: + _parse: bool = self.parse if parse is None else parse + + if _parse: + _parse_udpate_request(update_request=update_request) + params = UpdateOperationParameters( update_request=update_request, version=version, @@ -420,7 +458,13 @@ async def aupdate( version: str | None = None, using_graph_uri: RequestDataValue = None, using_named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> httpx.Response: + _parse: bool = self.parse if parse is None else parse + + if _parse: + _parse_udpate_request(update_request=update_request) + params = UpdateOperationParameters( update_request=update_request, version=version, @@ -443,9 +487,14 @@ def updates( version: str | None = None, using_graph_uri: RequestDataValue = None, using_named_graph_uri: RequestDataValue = None, + parse: bool | None = None, ) -> Iterator[httpx.Response]: + _parse: bool = self.parse if parse is None else parse + update_component = SPARQLWrapper( - update_endpoint=self.update_endpoint, aclient=self._client_manager.aclient + update_endpoint=self.update_endpoint, + aclient=self._client_manager.aclient, + parse=_parse, ) async def _runner() -> Iterator[httpx.Response]: diff --git a/src/sparqlx/utils/utils.py b/src/sparqlx/utils/utils.py index 84d0262..dcf8777 100644 --- a/src/sparqlx/utils/utils.py +++ b/src/sparqlx/utils/utils.py @@ -1,23 +1,80 @@ -from typing import cast +from typing import TypeGuard, get_args from rdflib.plugins.sparql import prepareQuery +from rdflib.plugins.sparql.parser import parseUpdate from rdflib.plugins.sparql.sparql import Query -from sparqlx.types import SPARQLQuery, SPARQLQueryTypeLiteral, SPARQLResponseFormat +from sparqlx.types import ( + AskQuery, + ConstructQuery, + DescribeQuery, + SPARQLQuery, + SPARQLQueryTypeLiteral, + SPARQLResponseFormat, + SelectQuery, +) from sparqlx.utils.converters import _convert_ask, _convert_bindings, _convert_graph class SPARQLParseException(Exception): ... -def _get_query_type(query: SPARQLQuery) -> SPARQLQueryTypeLiteral: +class QueryParseException(SPARQLParseException): ... + + +class UpdateParseException(SPARQLParseException): ... + + +def _parse_udpate_request(update_request: str) -> None: try: - _prepared_query: Query = prepareQuery(query) + parseUpdate(update_request) except Exception as exc: - raise SPARQLParseException(exc) from exc - else: - query_type = _prepared_query.algebra.name + raise UpdateParseException(exc) from exc + + +def _is_sparql_query_type_literal(value) -> TypeGuard[SPARQLQueryTypeLiteral]: + return value in get_args(SPARQLQueryTypeLiteral.__value__) + - return cast(SPARQLQueryTypeLiteral, query_type) +def _get_query_type(query: SPARQLQuery, parse: bool) -> SPARQLQueryTypeLiteral: + def _from_typed_query( + query: SelectQuery | AskQuery | ConstructQuery | DescribeQuery, + ) -> SPARQLQueryTypeLiteral: + match query: + case SelectQuery(): + query_type = "SelectQuery" + case AskQuery(): + query_type = "AskQuery" + case ConstructQuery(): + query_type = "ConstructQuery" + case DescribeQuery(): + query_type = "DescribeQuery" + case _: # pragma: no cover + assert False, "This should never happen." + + return query_type + + def _from_parsed_query(query: str) -> SPARQLQueryTypeLiteral: + try: + _prepared_query: Query = prepareQuery(query) + except Exception as exc: + raise QueryParseException(exc) from exc + else: + query_type = _prepared_query.algebra.name + + assert _is_sparql_query_type_literal(query_type) + return query_type + + is_typed_query: bool = isinstance( + query, SelectQuery | AskQuery | ConstructQuery | DescribeQuery + ) + + if not is_typed_query and not parse: + msg = "Query must be of type SelectQuery | AskQuery | ConstructQuery | DescribeQuery if parse=False." + raise ValueError(msg) + elif is_typed_query and not parse: + return _from_typed_query(query) + else: + return _from_parsed_query(query) def _get_response_converter( diff --git a/tests/test_sparqlwrapper/test_sparqlwrapper_optional_parsing.py b/tests/test_sparqlwrapper/test_sparqlwrapper_optional_parsing.py new file mode 100644 index 0000000..1f9ff0c --- /dev/null +++ b/tests/test_sparqlwrapper/test_sparqlwrapper_optional_parsing.py @@ -0,0 +1,121 @@ +from typing import NamedTuple + +import httpx +import pytest +from sparqlx import QueryParseException, SPARQLWrapper, UpdateParseException +from sparqlx.types import ( + AskQuery, + ConstructQuery, + DescribeQuery, + SPARQLQuery, + SelectQuery, +) + +from utils import acall + + +class OptionalParseParameters(NamedTuple): + exception: type[Exception] + invalid_sparql: str | SPARQLQuery = "INVALID" + + wrapper_parse: bool = True + method_parse: bool | None = None + + +query_params = [ + OptionalParseParameters(exception=QueryParseException), + OptionalParseParameters( + invalid_sparql=SelectQuery("INVALID"), + exception=httpx.HTTPStatusError, + wrapper_parse=False, + ), + OptionalParseParameters( + invalid_sparql=AskQuery("INVALID"), + exception=httpx.HTTPStatusError, + wrapper_parse=False, + ), + OptionalParseParameters( + invalid_sparql=ConstructQuery("INVALID"), + exception=httpx.HTTPStatusError, + wrapper_parse=False, + ), + OptionalParseParameters( + invalid_sparql=DescribeQuery("INVALID"), + exception=httpx.HTTPStatusError, + wrapper_parse=False, + ), + OptionalParseParameters(exception=ValueError, wrapper_parse=False), + OptionalParseParameters( + exception=QueryParseException, wrapper_parse=False, method_parse=True + ), + OptionalParseParameters(exception=QueryParseException, method_parse=True), +] + + +@pytest.mark.parametrize("method", ["query", "aquery"]) +@pytest.mark.parametrize("param", query_params) +@pytest.mark.parametrize("managed_client", [True, False]) +@pytest.mark.asyncio +async def test_sparqlwrapper_query_optional_parse( + method, param, triplestore, managed_client +): + sparql_endpoint: str = triplestore.sparql_endpoint + + client, aclient = ( + (httpx.Client(), httpx.AsyncClient()) if managed_client else (None, None) + ) + + sparqlwrapper = SPARQLWrapper( + sparql_endpoint=sparql_endpoint, + client=client, + aclient=aclient, + parse=param.wrapper_parse, + ) + + with pytest.raises(param.exception): + await acall( + sparqlwrapper, method, query=param.invalid_sparql, parse=param.method_parse + ) + + +update_params = [ + OptionalParseParameters(exception=UpdateParseException), + OptionalParseParameters( + exception=UpdateParseException, wrapper_parse=False, method_parse=True + ), + OptionalParseParameters(exception=UpdateParseException, method_parse=True), + OptionalParseParameters(exception=httpx.HTTPStatusError, wrapper_parse=False), + OptionalParseParameters(exception=httpx.HTTPStatusError, method_parse=False), + OptionalParseParameters( + exception=UpdateParseException, wrapper_parse=False, method_parse=True + ), +] + + +@pytest.mark.parametrize("method", ["update", "aupdate"]) +@pytest.mark.parametrize("param", update_params) +@pytest.mark.parametrize("managed_client", [True, False]) +@pytest.mark.asyncio +async def test_sparqlwrapper_update_optional_parse( + method, param, triplestore, managed_client +): + update_endpoint: str = triplestore.update_endpoint + + client, aclient = ( + (httpx.Client(), httpx.AsyncClient()) if managed_client else (None, None) + ) + + sparqlwrapper = SPARQLWrapper( + update_endpoint=update_endpoint, + client=client, + aclient=aclient, + parse=param.wrapper_parse, + ) + + with pytest.raises(param.exception): + await acall( + sparqlwrapper, + method, + update_request=param.invalid_sparql, + parse=param.method_parse, + )