Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ or within python code::
>>> with config.set(message_version="v1.01"):
...

Note that if the message version is not set explicitly with the above configuration, or when creating a message
object, the message version will be set to the lowest compatible version, that is v1.01 for messages not
encoding a datetime object, v1.2 otherwise.

API
---
Expand Down
11 changes: 5 additions & 6 deletions posttroll/backends/zmq/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from posttroll import config
from posttroll.backends.zmq import get_tcp_keepalive_options
from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket
from posttroll.message import MESSAGE_VERSION
from posttroll.message import CURRENT_MESSAGE_VERSION

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def add(self, address: dict[str, str], topics=None):
"""
with self._lock:
addr = ensure_address_is_dict(address)
if addr.get("supported_message_version", MESSAGE_VERSION) > MESSAGE_VERSION:
if addr.get("supported_message_version", CURRENT_MESSAGE_VERSION) > CURRENT_MESSAGE_VERSION:
LOGGER.warning(f"Will not connect to {str(addr)}, message version mismatch")
return
if addr["URI"] in self.address_keys:
Expand Down Expand Up @@ -102,9 +102,8 @@ def add_hook_sub(self, address, topics, callback):
specified subscription.

Good for operations, which is required to be done in the same thread as
the main recieve loop (e.q operations on the underlying sockets).
the main receive loop (e.q operations on the underlying sockets).
"""
topics = topics
LOGGER.info("Subscriber adding SUB hook %s for topics %s",
str(address), str(topics))
socket = self._add_sub_socket(address, topics)
Expand Down Expand Up @@ -242,5 +241,5 @@ def uri_keys(addresses) -> list[str]:

def add_subscriptions(socket, topics):
"""Add subscriptions to a socket."""
for t__ in topics:
socket.setsockopt_string(SUBSCRIBE, str(t__))
for topic in topics:
socket.setsockopt_string(SUBSCRIBE, str(topic))
93 changes: 60 additions & 33 deletions posttroll/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import json
import re
from functools import partial
from typing import Any, Callable

from posttroll import config

_MAGICK : str = "pytroll:/"
MESSAGE_VERSION : str = config.get("message_version", "v1.2")
CURRENT_MESSAGE_VERSION : str = "v1.2"


class MessageError(Exception):
Expand Down Expand Up @@ -66,12 +67,12 @@ def is_valid_sender(obj: object) -> bool:
return _is_valid_nonempty_string(obj)


def is_valid_data(obj:object, version:str = MESSAGE_VERSION):
def is_valid_data(obj:object, version:str|None, binary:bool):
"""Check if data is JSON serializable."""
if obj:
encoder = create_datetime_json_encoder_for_version(version)
version = render_version(version, obj, binary)
try:
_ = json.dumps(obj, default=encoder)
_ = _encode_data(obj, binary, version)
except (TypeError, UnicodeDecodeError):
return False
return True
Expand All @@ -95,8 +96,8 @@ class Message:
- It will make a Message pickleable.
"""

def __init__(self, subject:str="", atype:str="", data="", binary:bool=False,
rawstr:str|None=None, version:str=MESSAGE_VERSION):
def __init__(self, subject:str="", atype:str="", data:str|dict[str, Any]="", binary:bool=False,
rawstr:str|bytes|None=None, version:str|None=None):
"""Initialize a Message from a subject, type and data, or from a raw string."""
if rawstr:
self.__dict__ = _decode(rawstr)
Expand All @@ -105,9 +106,9 @@ def __init__(self, subject:str="", atype:str="", data="", binary:bool=False,
self.type:str = atype
self.sender:str = _getsender()
self.time = dt.datetime.now(dt.timezone.utc)
self.data = data
self.data:str|dict[str, Any] = data
self.binary:bool = binary
self.version:str = version
self.version:str|None = version
self._validate()

@property
Expand All @@ -133,11 +134,11 @@ def head(self):
return _encode(self, head=True)

@staticmethod
def decode(rawstr):
def decode(rawstr:str|bytes):
"""Decode a raw string into a Message."""
return Message(rawstr=rawstr)

def encode(self):
def encode(self) -> str:
"""Encode a Message to a raw string."""
self._validate()
return _encode(self, binary=self.binary)
Expand All @@ -162,7 +163,7 @@ def _validate(self):
raise MessageError("Invalid type: '%s'" % self.type)
if not is_valid_sender(self.sender):
raise MessageError("Invalid sender: '%s'" % self.sender)
if not self.binary and not is_valid_data(self.data, self.version):
if not self.binary and not is_valid_data(self.data, self.version, self.binary):
raise MessageError("Invalid data: data is not JSON serializable: %s"
% str(self.data))

Expand All @@ -185,7 +186,7 @@ def __setstate__(self, state):

def _is_valid_version(version):
"""Check version."""
return version <= MESSAGE_VERSION
return version <= CURRENT_MESSAGE_VERSION


def datetime_decoder(dct):
Expand All @@ -210,7 +211,7 @@ def datetime_decoder(dct):
return dict(result)


def _decode(rawstr):
def _decode(rawstr:str|bytes) -> dict[str, Any]:
"""Convert a raw string to a Message."""
rawstr = _check_for_magic_word(rawstr)

Expand Down Expand Up @@ -268,7 +269,7 @@ def _check_for_element_count(rawstr):
return raw


def _check_for_magic_word(rawstr: str | bytes):
def _check_for_magic_word(rawstr: str | bytes) -> str|bytes:
"""Check for the magick word."""
try:
rawstr = rawstr.decode("utf-8")
Expand All @@ -293,15 +294,15 @@ def datetime_encoder(obj, encoder):
raise TypeError(repr(obj) + " is not JSON serializable")


def _encode_dt(obj):
def _encode_dt(obj: dt.datetime):
return obj.isoformat()


def _encode_dt_no_timezone(obj):
def _encode_dt_no_timezone(obj: dt.datetime):
return obj.replace(tzinfo=None).isoformat()


def create_datetime_encoder_for_version(version=MESSAGE_VERSION):
def create_datetime_encoder_for_version(version:str):
"""Create a datetime encoder depending on the message protocol version."""
if version <= "v1.01":
dt_coder = _encode_dt_no_timezone
Expand All @@ -310,33 +311,59 @@ def create_datetime_encoder_for_version(version=MESSAGE_VERSION):
return dt_coder


def create_datetime_json_encoder_for_version(version=MESSAGE_VERSION):
def create_datetime_json_encoder_for_version(version:str) -> Callable[[Any], str]:
"""Create a datetime json encoder depending on the message protocol version."""
return partial(datetime_encoder,
encoder=create_datetime_encoder_for_version(version))


def _encode(msg, head=False, binary=False):
def _encode(msg:Message, head:bool=False, binary:bool=False) -> str:
"""Convert a Message to a raw string."""
json_dt_encoder = create_datetime_json_encoder_for_version(msg.version)
dt_encoder = create_datetime_encoder_for_version(msg.version)
version = render_version(msg.version, msg.data, binary)

rawstr = str(_MAGICK) + u"{0:s} {1:s} {2:s} {3:s} {4:s}".format(
msg.subject, msg.type, msg.sender, dt_encoder(msg.time), msg.version)
msg.subject, msg.type, msg.sender, msg.time.isoformat(), version)
if not head and msg.data:

if not binary and isinstance(msg.data, str):
return (rawstr + " " +
"text/ascii" + " " + msg.data)
elif not binary:
return (rawstr + " " +
"application/json" + " " +
json.dumps(msg.data, default=json_dt_encoder))
else:
return (rawstr + " " +
"binary/octet-stream" + " " + msg.data)
mimetype, data = _encode_data(msg.data, binary, version)
return " ".join((rawstr, mimetype, data))
return rawstr

def render_version(version: str|None, data:str|bytes|dict[str, Any], binary:bool) -> str:
"""Make the version a string."""
configured_version : str = config.get("message_version", None)
return version or configured_version or version_needed(data, binary)


def version_needed(data:str|bytes|dict[str,Any], binary:bool) -> str:
"""Check the data to see what in the minimal message version needed."""
if binary:
return "v1.01"
if _contains_datetime(data):
return CURRENT_MESSAGE_VERSION
return "v1.01"


def _contains_datetime(data: object) -> bool:
if isinstance(data, dt.datetime):
return True
elif isinstance(data, dict):
return any(_contains_datetime(value) for value in data.values())
elif isinstance(data, (list, tuple)):
return any(_contains_datetime(item) for item in data)
return False


def _encode_data(data:str|bytes|dict[str,Any], binary:bool, version:str):
json_dt_encoder = create_datetime_json_encoder_for_version(version)
if not binary:
if isinstance(data, (str, bytes)):
return "text/ascii", data
else:
return "application/json", json.dumps(data, default=json_dt_encoder)
else:
if not isinstance(data, (str, bytes)):
raise TypeError("Message binary data should be a string or bytes")
return "binary/octet-stream", data

# -----------------------------------------------------------------------------
#
Expand Down
2 changes: 1 addition & 1 deletion posttroll/message_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, name, address, data_type: str, interval: int = 2, nameservers
msg = message.Message("/address/%s" % name, "info",
{"URI": address,
"service": data_type,
"supported_message_version": message.MESSAGE_VERSION,
"supported_message_version": message.CURRENT_MESSAGE_VERSION,
"backend": config["backend"]}).encode()
MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval,
nameservers)
Expand Down
4 changes: 2 additions & 2 deletions posttroll/ns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from posttroll import config
from posttroll.address_receiver import AddressReceiver
from posttroll.message import MESSAGE_VERSION, Message
from posttroll.message import Message

# pylint: enable=E0611

Expand Down Expand Up @@ -81,7 +81,7 @@ def get_pub_address(name:str, timeout:float|int=10, nameserver:str="localhost"):
# Server part.


def get_active_address(name, arec, message_version=MESSAGE_VERSION):
def get_active_address(name, arec, message_version:str):
"""Get the addresses of the active modules for a given publisher *name*."""
addrs = arec.get(name)
if addrs:
Expand Down
2 changes: 1 addition & 1 deletion posttroll/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _to_list(obj):
class _AddressListener:
"""Listener for new addresses of interest."""

def __init__(self, subscriber, services="", nameserver="localhost"):
def __init__(self, subscriber: Subscriber, services: str|list[str] ="", nameserver: str|None ="localhost"):
"""Initialize address listener."""
if isinstance(services, str):
services = [services, ]
Expand Down
26 changes: 23 additions & 3 deletions posttroll/tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

from posttroll import config
from posttroll.message import _MAGICK, Message

HOME = os.path.dirname(__file__) or "."
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_encode():
msg1 = Message(subject, atype, data=data)
sender = "%s@%s" % (msg1.user, msg1.host)
full_message = (_MAGICK + subject + " " + atype + " " + sender + " " +
str(msg1.time.isoformat()) + " " + msg1.version + " " + "text/ascii" + " " + data)
str(msg1.time.isoformat()) + " " + "v1.01" + " " + "text/ascii" + " " + data)
assert full_message == msg1.encode()


Expand Down Expand Up @@ -152,9 +153,9 @@ def test_message_can_generate_v1_01():
data=dict(start_time=dt.datetime.now(dt.timezone.utc)),
version=version)
rawmsg = str(msg)
assert "+00:00" not in rawmsg
assert "+00:00" not in rawmsg.split(" ", 6)[-1]
msg = Message(rawstr=rawmsg)
assert "+00:00" not in str(msg)
assert "+00:00" not in rawmsg.split(" ", 6)[-1]
assert str(msg) == rawmsg


Expand All @@ -168,3 +169,22 @@ def test_message_has_timezone_by_default():
assert "+00:00" in str(msg)
assert str(msg) == rawmsg


def test_message_encoding_can_choose_version_automatically():
"""Make sure the version number can be chosen automatically."""
msg1 = Message("/test/whatup/doc", "info", data=dict(time=dt.datetime.now()))

msg2 = Message.decode(msg1.encode())
assert msg2.version == "v1.2"

msg1 = Message("/test/whatup/doc", "info", data=dict(sting="Hi, Bugs"))

msg2 = Message.decode(msg1.encode())
assert msg2.version == "v1.01"

def test_message_version_does_not_change_if_set():
with config.set(message_version="v1.2"):
msg1 = Message("/test/whatup/doc", "info", data=dict(sting="Hi, Bugs"))

msg2 = Message.decode(msg1.encode())
assert msg2.version == "v1.2"
6 changes: 3 additions & 3 deletions posttroll/tests/test_nameserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from posttroll import config
from posttroll.backends.zmq.ns import create_unsecure_zmq_nameserver_address
from posttroll.message import MESSAGE_VERSION, Message
from posttroll.message import CURRENT_MESSAGE_VERSION, Message
from posttroll.ns import (
NameServer,
get_configured_unsecure_zmq_nameserver_port,
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_pub_addresses(multicast_enabled):
u"service": [u"data_provider", u"this_data"],
u"name": u"address",
"backend": "unsecure_zmq",
"supported_message_version": MESSAGE_VERSION}
"supported_message_version": CURRENT_MESSAGE_VERSION}
for key, val in expected.items():
assert res[0][key] == val
assert "receive_time" in res[0]
Expand All @@ -133,7 +133,7 @@ def test_pub_addresses(multicast_enabled):
u"service": [u"data_provider", u"this_data"],
u"name": u"address",
"backend": "unsecure_zmq",
"supported_message_version": MESSAGE_VERSION}
"supported_message_version": CURRENT_MESSAGE_VERSION}
for key, val in expected.items():
assert res[0][key] == val
assert "receive_time" in res[0]
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,9 @@ convention = "google"
[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10

[dependency-groups]
dev = [
"pytest>=9.0.2",
"pytest-reraise>=2.1.2",
]
Loading