diff --git a/doc/source/index.rst b/doc/source/index.rst index 59d73aa..fe63302 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -251,6 +251,9 @@ or within python code:: >>> with config.set(message_version="v1.01"): ... +Note that if the message version is not set explicitly, either with the above configuration or when creating a message +object, it will be set to the lowest compatible version, that is v1.01 for messages not +encoding a datetime object, and v1.2 otherwise. API --- diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 6a0b86d..5c3b767 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -10,7 +10,7 @@ from posttroll import config from posttroll.backends.zmq import get_tcp_keepalive_options from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket -from posttroll.message import MESSAGE_VERSION +from posttroll.message import CURRENT_MESSAGE_VERSION LOGGER = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def add(self, address: dict[str, str], topics=None): """ with self._lock: addr = ensure_address_is_dict(address) - if addr.get("supported_message_version", MESSAGE_VERSION) > MESSAGE_VERSION: + if addr.get("supported_message_version", CURRENT_MESSAGE_VERSION) > CURRENT_MESSAGE_VERSION: LOGGER.warning(f"Will not connect to {str(addr)}, message version mismatch") return if addr["URI"] in self.address_keys: @@ -102,9 +102,8 @@ def add_hook_sub(self, address, topics, callback): specified subscription. Good for operations, which is required to be done in the same thread as - the main recieve loop (e.q operations on the underlying sockets). + the main receive loop (e.g. operations on the underlying sockets). 
""" - topics = topics LOGGER.info("Subscriber adding SUB hook %s for topics %s", str(address), str(topics)) socket = self._add_sub_socket(address, topics) @@ -242,5 +241,5 @@ def uri_keys(addresses) -> list[str]: def add_subscriptions(socket, topics): """Add subscriptions to a socket.""" - for t__ in topics: - socket.setsockopt_string(SUBSCRIBE, str(t__)) + for topic in topics: + socket.setsockopt_string(SUBSCRIBE, str(topic)) diff --git a/posttroll/message.py b/posttroll/message.py index 13c09d7..5c66250 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -17,11 +17,12 @@ import json import re from functools import partial +from typing import Any, Callable from posttroll import config _MAGICK : str = "pytroll:/" -MESSAGE_VERSION : str = config.get("message_version", "v1.2") +CURRENT_MESSAGE_VERSION : str = "v1.2" class MessageError(Exception): @@ -66,12 +67,12 @@ def is_valid_sender(obj: object) -> bool: return _is_valid_nonempty_string(obj) -def is_valid_data(obj:object, version:str = MESSAGE_VERSION): +def is_valid_data(obj:object, version:str|None, binary:bool): """Check if data is JSON serializable.""" if obj: - encoder = create_datetime_json_encoder_for_version(version) + version = render_version(version, obj, binary) try: - _ = json.dumps(obj, default=encoder) + _ = _encode_data(obj, binary, version) except (TypeError, UnicodeDecodeError): return False return True @@ -95,8 +96,8 @@ class Message: - It will make a Message pickleable. 
""" - def __init__(self, subject:str="", atype:str="", data="", binary:bool=False, - rawstr:str|None=None, version:str=MESSAGE_VERSION): + def __init__(self, subject:str="", atype:str="", data:str|dict[str, Any]="", binary:bool=False, + rawstr:str|bytes|None=None, version:str|None=None): """Initialize a Message from a subject, type and data, or from a raw string.""" if rawstr: self.__dict__ = _decode(rawstr) @@ -105,9 +106,9 @@ def __init__(self, subject:str="", atype:str="", data="", binary:bool=False, self.type:str = atype self.sender:str = _getsender() self.time = dt.datetime.now(dt.timezone.utc) - self.data = data + self.data:str|dict[str, Any] = data self.binary:bool = binary - self.version:str = version + self.version:str|None = version self._validate() @property @@ -133,11 +134,11 @@ def head(self): return _encode(self, head=True) @staticmethod - def decode(rawstr): + def decode(rawstr:str|bytes): """Decode a raw string into a Message.""" return Message(rawstr=rawstr) - def encode(self): + def encode(self) -> str: """Encode a Message to a raw string.""" self._validate() return _encode(self, binary=self.binary) @@ -162,7 +163,7 @@ def _validate(self): raise MessageError("Invalid type: '%s'" % self.type) if not is_valid_sender(self.sender): raise MessageError("Invalid sender: '%s'" % self.sender) - if not self.binary and not is_valid_data(self.data, self.version): + if not self.binary and not is_valid_data(self.data, self.version, self.binary): raise MessageError("Invalid data: data is not JSON serializable: %s" % str(self.data)) @@ -185,7 +186,7 @@ def __setstate__(self, state): def _is_valid_version(version): """Check version.""" - return version <= MESSAGE_VERSION + return version <= CURRENT_MESSAGE_VERSION def datetime_decoder(dct): @@ -210,7 +211,7 @@ def datetime_decoder(dct): return dict(result) -def _decode(rawstr): +def _decode(rawstr:str|bytes) -> dict[str, Any]: """Convert a raw string to a Message.""" rawstr = _check_for_magic_word(rawstr) @@ 
-268,7 +269,7 @@ def _check_for_element_count(rawstr): return raw -def _check_for_magic_word(rawstr: str | bytes): +def _check_for_magic_word(rawstr: str | bytes) -> str|bytes: """Check for the magick word.""" try: rawstr = rawstr.decode("utf-8") @@ -293,15 +294,15 @@ def datetime_encoder(obj, encoder): raise TypeError(repr(obj) + " is not JSON serializable") -def _encode_dt(obj): +def _encode_dt(obj: dt.datetime): return obj.isoformat() -def _encode_dt_no_timezone(obj): +def _encode_dt_no_timezone(obj: dt.datetime): return obj.replace(tzinfo=None).isoformat() -def create_datetime_encoder_for_version(version=MESSAGE_VERSION): +def create_datetime_encoder_for_version(version:str): """Create a datetime encoder depending on the message protocol version.""" if version <= "v1.01": dt_coder = _encode_dt_no_timezone @@ -310,33 +311,59 @@ def create_datetime_encoder_for_version(version=MESSAGE_VERSION): return dt_coder -def create_datetime_json_encoder_for_version(version=MESSAGE_VERSION): +def create_datetime_json_encoder_for_version(version:str) -> Callable[[Any], str]: """Create a datetime json encoder depending on the message protocol version.""" return partial(datetime_encoder, encoder=create_datetime_encoder_for_version(version)) -def _encode(msg, head=False, binary=False): +def _encode(msg:Message, head:bool=False, binary:bool=False) -> str: """Convert a Message to a raw string.""" - json_dt_encoder = create_datetime_json_encoder_for_version(msg.version) - dt_encoder = create_datetime_encoder_for_version(msg.version) + version = render_version(msg.version, msg.data, binary) rawstr = str(_MAGICK) + u"{0:s} {1:s} {2:s} {3:s} {4:s}".format( - msg.subject, msg.type, msg.sender, dt_encoder(msg.time), msg.version) + msg.subject, msg.type, msg.sender, msg.time.isoformat(), version) if not head and msg.data: - - if not binary and isinstance(msg.data, str): - return (rawstr + " " + - "text/ascii" + " " + msg.data) - elif not binary: - return (rawstr + " " + - 
"application/json" + " " + - json.dumps(msg.data, default=json_dt_encoder)) - else: - return (rawstr + " " + - "binary/octet-stream" + " " + msg.data) + mimetype, data = _encode_data(msg.data, binary, version) + return " ".join((rawstr, mimetype, data)) return rawstr +def render_version(version: str|None, data:str|bytes|dict[str, Any], binary:bool) -> str: + """Make the version a string.""" + configured_version : str = config.get("message_version", None) + return version or configured_version or version_needed(data, binary) + + +def version_needed(data:str|bytes|dict[str,Any], binary:bool) -> str: + """Check the data to see what in the minimal message version needed.""" + if binary: + return "v1.01" + if _contains_datetime(data): + return CURRENT_MESSAGE_VERSION + return "v1.01" + + +def _contains_datetime(data: object) -> bool: + if isinstance(data, dt.datetime): + return True + elif isinstance(data, dict): + return any(_contains_datetime(value) for value in data.values()) + elif isinstance(data, (list, tuple)): + return any(_contains_datetime(item) for item in data) + return False + + +def _encode_data(data:str|bytes|dict[str,Any], binary:bool, version:str): + json_dt_encoder = create_datetime_json_encoder_for_version(version) + if not binary: + if isinstance(data, (str, bytes)): + return "text/ascii", data + else: + return "application/json", json.dumps(data, default=json_dt_encoder) + else: + if not isinstance(data, (str, bytes)): + raise TypeError("Message binary data should be a string or bytes") + return "binary/octet-stream", data # ----------------------------------------------------------------------------- # diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 0443024..562fff6 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -135,7 +135,7 @@ def __init__(self, name, address, data_type: str, interval: int = 2, nameservers msg = message.Message("/address/%s" % name, "info", {"URI": 
address, "service": data_type, - "supported_message_version": message.MESSAGE_VERSION, + "supported_message_version": message.CURRENT_MESSAGE_VERSION, "backend": config["backend"]}).encode() MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) diff --git a/posttroll/ns.py b/posttroll/ns.py index f3c1251..3ad058c 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -12,7 +12,7 @@ from posttroll import config from posttroll.address_receiver import AddressReceiver -from posttroll.message import MESSAGE_VERSION, Message +from posttroll.message import Message # pylint: enable=E0611 @@ -81,7 +81,7 @@ def get_pub_address(name:str, timeout:float|int=10, nameserver:str="localhost"): # Server part. -def get_active_address(name, arec, message_version=MESSAGE_VERSION): +def get_active_address(name, arec, message_version:str): """Get the addresses of the active modules for a given publisher *name*.""" addrs = arec.get(name) if addrs: diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 99b3139..a1aa6b2 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -265,7 +265,7 @@ def _to_list(obj): class _AddressListener: """Listener for new addresses of interest.""" - def __init__(self, subscriber, services="", nameserver="localhost"): + def __init__(self, subscriber: Subscriber, services: str|list[str] ="", nameserver: str|None ="localhost"): """Initialize address listener.""" if isinstance(services, str): services = [services, ] diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index ab45f6b..2cbdaa5 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -7,6 +7,7 @@ import pytest +from posttroll import config from posttroll.message import _MAGICK, Message HOME = os.path.dirname(__file__) or "." 
@@ -56,7 +57,7 @@ def test_encode(): msg1 = Message(subject, atype, data=data) sender = "%s@%s" % (msg1.user, msg1.host) full_message = (_MAGICK + subject + " " + atype + " " + sender + " " + - str(msg1.time.isoformat()) + " " + msg1.version + " " + "text/ascii" + " " + data) + str(msg1.time.isoformat()) + " " + "v1.01" + " " + "text/ascii" + " " + data) assert full_message == msg1.encode() @@ -152,9 +153,9 @@ def test_message_can_generate_v1_01(): data=dict(start_time=dt.datetime.now(dt.timezone.utc)), version=version) rawmsg = str(msg) - assert "+00:00" not in rawmsg + assert "+00:00" not in rawmsg.split(" ", 6)[-1] msg = Message(rawstr=rawmsg) - assert "+00:00" not in str(msg) + assert "+00:00" not in rawmsg.split(" ", 6)[-1] assert str(msg) == rawmsg @@ -168,3 +169,22 @@ def test_message_has_timezone_by_default(): assert "+00:00" in str(msg) assert str(msg) == rawmsg + +def test_message_encoding_can_choose_version_automatically(): + """Make sure the version number can be chosen automatically.""" + msg1 = Message("/test/whatup/doc", "info", data=dict(time=dt.datetime.now())) + + msg2 = Message.decode(msg1.encode()) + assert msg2.version == "v1.2" + + msg1 = Message("/test/whatup/doc", "info", data=dict(sting="Hi, Bugs")) + + msg2 = Message.decode(msg1.encode()) + assert msg2.version == "v1.01" + +def test_message_version_does_not_change_if_set(): + with config.set(message_version="v1.2"): + msg1 = Message("/test/whatup/doc", "info", data=dict(sting="Hi, Bugs")) + + msg2 = Message.decode(msg1.encode()) + assert msg2.version == "v1.2" diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index b8d92f8..f03fa07 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -12,7 +12,7 @@ from posttroll import config from posttroll.backends.zmq.ns import create_unsecure_zmq_nameserver_address -from posttroll.message import MESSAGE_VERSION, Message +from posttroll.message import CURRENT_MESSAGE_VERSION, 
Message from posttroll.ns import ( NameServer, get_configured_unsecure_zmq_nameserver_port, @@ -122,7 +122,7 @@ def test_pub_addresses(multicast_enabled): u"service": [u"data_provider", u"this_data"], u"name": u"address", "backend": "unsecure_zmq", - "supported_message_version": MESSAGE_VERSION} + "supported_message_version": CURRENT_MESSAGE_VERSION} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] @@ -133,7 +133,7 @@ def test_pub_addresses(multicast_enabled): u"service": [u"data_provider", u"this_data"], u"name": u"address", "backend": "unsecure_zmq", - "supported_message_version": MESSAGE_VERSION} + "supported_message_version": CURRENT_MESSAGE_VERSION} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] diff --git a/pyproject.toml b/pyproject.toml index 5cf3169..2dc6224 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,3 +76,9 @@ convention = "google" [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-reraise>=2.1.2", +]