Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/sparqlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@
SPARQLQueryTypeLiteral,
SelectQuery,
)
from sparqlx.utils.utils import SPARQLParseException
from sparqlx.utils.utils import (
QueryParseException,
SPARQLParseException,
UpdateParseException,
)

__all__ = (
"SPARQLWrapper",
"SPARQLParseException",
"QueryParseException",
"UpdateParseException",
"AskQuery",
"ConstructQuery",
"DescribeQuery",
Expand Down
63 changes: 56 additions & 7 deletions src/sparqlx/sparqlwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
QueryOperationParameters,
UpdateOperationParameters,
)
from sparqlx.utils.utils import _get_query_type, _get_response_converter
from sparqlx.utils.utils import (
_get_query_type,
_get_response_converter,
_parse_udpate_request,
)


class SPARQLWrapper(AbstractContextManager, AbstractAsyncContextManager):
Expand All @@ -42,10 +46,13 @@ def __init__(
client_config: dict | None = None,
aclient: httpx.AsyncClient | None = None,
aclient_config: dict | None = None,
parse: bool = True,
) -> None:
self.sparql_endpoint = sparql_endpoint
self.update_endpoint = update_endpoint

self.parse = parse

self._client_manager = ClientManager(
client=client,
client_config=client_config,
Expand Down Expand Up @@ -76,6 +83,7 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> list[SPARQLResultBinding]: ...

@overload
Expand All @@ -87,6 +95,7 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> bool: ...

@overload
Expand All @@ -98,6 +107,7 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Graph: ...

@overload
Expand All @@ -109,6 +119,7 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> list[SPARQLResultBinding] | Graph | bool: ...

@overload
Expand All @@ -120,6 +131,7 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response: ...

def query(
Expand All @@ -130,8 +142,10 @@ def query(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response | list[SPARQLResultBinding] | Graph | bool:
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query)
_parse: bool = self.parse if parse is None else parse
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse)

params = QueryOperationParameters(
query=query,
Expand Down Expand Up @@ -169,6 +183,7 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> list[SPARQLResultBinding]: ...

@overload
Expand All @@ -180,6 +195,7 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> bool: ...

@overload
Expand All @@ -191,6 +207,7 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Graph: ...

@overload
Expand All @@ -202,6 +219,7 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> list[SPARQLResultBinding] | Graph | bool: ...

@overload
Expand All @@ -213,6 +231,7 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response: ...

async def aquery(
Expand All @@ -223,8 +242,10 @@ async def aquery(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response | list[SPARQLResultBinding] | Graph | bool:
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query)
_parse: bool = self.parse if parse is None else parse
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse)

params = QueryOperationParameters(
query=query,
Expand Down Expand Up @@ -264,8 +285,10 @@ def query_stream[T](
[httpx.Response], Iterator[T]
] = httpx.Response.iter_bytes,
chunk_size: int | None = None,
parse: bool | None = None,
) -> Iterator[T]:
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query)
_parse: bool = self.parse if parse is None else parse
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse)

params = QueryOperationParameters(
query=query,
Expand Down Expand Up @@ -305,8 +328,10 @@ async def aquery_stream[T](
[httpx.Response], AsyncIterator[T]
] = httpx.Response.aiter_bytes,
chunk_size: int | None = None,
parse: bool | None = None,
) -> AsyncIterator[T]:
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query)
_parse: bool = self.parse if parse is None else parse
query_type: SPARQLQueryTypeLiteral = _get_query_type(query=query, parse=_parse)

params = QueryOperationParameters(
query=query,
Expand Down Expand Up @@ -344,6 +369,7 @@ def queries(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Iterator[list[SPARQLResultBinding] | Graph | bool]: ...

@overload
Expand All @@ -355,6 +381,7 @@ def queries(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Iterator[httpx.Response]: ...

def queries(
Expand All @@ -365,9 +392,14 @@ def queries(
version: str | None = None,
default_graph_uri: RequestDataValue = None,
named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Iterator[httpx.Response | list[SPARQLResultBinding] | Graph | bool]:
_parse: bool = self.parse if parse is None else parse

query_component = SPARQLWrapper(
sparql_endpoint=self.sparql_endpoint, aclient=self._client_manager.aclient
sparql_endpoint=self.sparql_endpoint,
aclient=self._client_manager.aclient,
parse=_parse,
)

async def _runner() -> Iterator[httpx.Response]:
Expand Down Expand Up @@ -397,7 +429,13 @@ def update(
version: str | None = None,
using_graph_uri: RequestDataValue = None,
using_named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response:
_parse: bool = self.parse if parse is None else parse

if _parse:
_parse_udpate_request(update_request=update_request)

params = UpdateOperationParameters(
update_request=update_request,
version=version,
Expand All @@ -420,7 +458,13 @@ async def aupdate(
version: str | None = None,
using_graph_uri: RequestDataValue = None,
using_named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> httpx.Response:
_parse: bool = self.parse if parse is None else parse

if _parse:
_parse_udpate_request(update_request=update_request)

params = UpdateOperationParameters(
update_request=update_request,
version=version,
Expand All @@ -443,9 +487,14 @@ def updates(
version: str | None = None,
using_graph_uri: RequestDataValue = None,
using_named_graph_uri: RequestDataValue = None,
parse: bool | None = None,
) -> Iterator[httpx.Response]:
_parse: bool = self.parse if parse is None else parse

update_component = SPARQLWrapper(
update_endpoint=self.update_endpoint, aclient=self._client_manager.aclient
update_endpoint=self.update_endpoint,
aclient=self._client_manager.aclient,
parse=_parse,
)

async def _runner() -> Iterator[httpx.Response]:
Expand Down
73 changes: 65 additions & 8 deletions src/sparqlx/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,80 @@
from typing import cast
from typing import TypeGuard, get_args

from rdflib.plugins.sparql import prepareQuery
from rdflib.plugins.sparql.parser import parseUpdate
from rdflib.plugins.sparql.sparql import Query
from sparqlx.types import SPARQLQuery, SPARQLQueryTypeLiteral, SPARQLResponseFormat
from sparqlx.types import (
AskQuery,
ConstructQuery,
DescribeQuery,
SPARQLQuery,
SPARQLQueryTypeLiteral,
SPARQLResponseFormat,
SelectQuery,
)
from sparqlx.utils.converters import _convert_ask, _convert_bindings, _convert_graph


class SPARQLParseException(Exception): ...


def _get_query_type(query: SPARQLQuery) -> SPARQLQueryTypeLiteral:
class QueryParseException(SPARQLParseException): ...


class UpdateParseException(SPARQLParseException): ...


def _parse_udpate_request(update_request: str) -> None:
try:
_prepared_query: Query = prepareQuery(query)
parseUpdate(update_request)
except Exception as exc:
raise SPARQLParseException(exc) from exc
else:
query_type = _prepared_query.algebra.name
raise UpdateParseException(exc) from exc


def _is_sparql_query_type_literal(value) -> TypeGuard[SPARQLQueryTypeLiteral]:
return value in get_args(SPARQLQueryTypeLiteral.__value__)


return cast(SPARQLQueryTypeLiteral, query_type)
def _get_query_type(query: SPARQLQuery, parse: bool) -> SPARQLQueryTypeLiteral:
def _from_typed_query(
query: SelectQuery | AskQuery | ConstructQuery | DescribeQuery,
) -> SPARQLQueryTypeLiteral:
match query:
case SelectQuery():
query_type = "SelectQuery"
case AskQuery():
query_type = "AskQuery"
case ConstructQuery():
query_type = "ConstructQuery"
case DescribeQuery():
query_type = "DescribeQuery"
case _: # pragma: no cover
assert False, "This should never happen."

return query_type

def _from_parsed_query(query: str) -> SPARQLQueryTypeLiteral:
try:
_prepared_query: Query = prepareQuery(query)
except Exception as exc:
raise QueryParseException(exc) from exc
else:
query_type = _prepared_query.algebra.name

assert _is_sparql_query_type_literal(query_type)
return query_type

is_typed_query: bool = isinstance(
query, SelectQuery | AskQuery | ConstructQuery | DescribeQuery
)

if not is_typed_query and not parse:
msg = "Query must be of type SelectQuery | AskQuery | ConstructQuery | DescribeQuery if parse=False."
raise ValueError(msg)
elif is_typed_query and not parse:
return _from_typed_query(query)
else:
return _from_parsed_query(query)


def _get_response_converter(
Expand Down
Loading