Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions obelisk/python/obelisk_py/core/control.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod

from obelisk_py.core.node import ObeliskNode
from obelisk_py.core.obelisk_typing import ObeliskControlMsg, ObeliskEstimatorMsg


class ObeliskController(ABC, ObeliskNode):
Expand Down Expand Up @@ -38,15 +37,15 @@ def __init__(self, node_name: str) -> None:
)

@abstractmethod
def update_x_hat(self, x_hat_msg: ObeliskEstimatorMsg) -> None:
def update_x_hat(self, x_hat_msg) -> None:
"""Update the state estimate.

Parameters:
x_hat_msg: The Obelisk message containing the state estimate.
"""

@abstractmethod
def compute_control(self) -> ObeliskControlMsg:
def compute_control(self):
"""Compute the control signal.

This is the control timer callback and is expected to call 'publisher_ctrl' internally. Note that the control
Expand Down
4 changes: 1 addition & 3 deletions obelisk/python/obelisk_py/core/estimation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from abc import ABC, abstractmethod
from typing import Union

import obelisk_sensor_msgs.msg as osm
from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn

from obelisk_py.core.node import ObeliskNode
from obelisk_py.core.obelisk_typing import ObeliskEstimatorMsg, ObeliskSensorMsg
from obelisk_py.core.utils.internal import get_classes_in_module


Expand Down Expand Up @@ -65,7 +63,7 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn:
return TransitionCallbackReturn.SUCCESS

@abstractmethod
def compute_state_estimate(self) -> Union[ObeliskEstimatorMsg, ObeliskSensorMsg]:
def compute_state_estimate(self):
"""Compute the state estimate.

This is the state estimate timer callback and is expected to call 'publisher_est' internally. Note that the
Expand Down
253 changes: 120 additions & 133 deletions obelisk/python/obelisk_py/core/node.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import rclpy
from rclpy._rclpy_pybind11 import RCLError
from rclpy.callback_groups import CallbackGroup, MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
from rclpy.lifecycle import LifecycleNode
from rclpy.lifecycle.node import LifecycleState, TransitionCallbackReturn
from rclpy.publisher import Publisher
from rclpy.qos import QoSProfile
from rclpy.qos_event import PublisherEventCallbacks, SubscriptionEventCallbacks
from rclpy.qos_overriding_options import QoSOverridingOptions
from rclpy.subscription import Subscription

from obelisk_py.core.exceptions import ObeliskMsgError
from obelisk_py.core.obelisk_typing import ObeliskAllowedMsg, ObeliskMsg, is_in_bound
from obelisk_py.core.utils.internal import check_and_get_obelisk_msg_type
# from obelisk_py.core.obelisk_typing import ObeliskAllowedMsg, ObeliskMsg, is_in_bound

MsgType = TypeVar("MsgType") # hack to denote any message type


class ObeliskNode(LifecycleNode):
"""A lifecycle node whose publishers and subscribers can only publish and subscribe to Obelisk messages.
"""A lifecycle node designed for use in Obelisk.

By convention, the initialization function should only declare ROS parameters and define stateful quantities.
Some guidelines for the on_configure, on_activate, and on_deactivate callbacks are provided below.
Expand Down Expand Up @@ -115,7 +107,7 @@ def register_obk_publisher(
def register_obk_subscription(
self,
ros_parameter: str,
callback: Callable[[Union[ObeliskAllowedMsg, MsgType]], None],
callback: Callable[[MsgType], None],
key: Optional[str] = None,
msg_type: Optional[Type] = None,
default_config_str: Optional[str] = None,
Expand Down Expand Up @@ -258,20 +250,21 @@ def _get_key_from_config_dict(config_dict: Dict) -> str:
assert isinstance(config_dict["key"], str), "The 'key' field must be a string!"
return config_dict["key"]

@staticmethod
def _get_msg_type_from_config_dict(config_dict: Dict) -> Type:
"""Get the message type from a configuration dictionary."""
assert config_dict.get("msg_type") is not None, "No message type supplied!"
assert isinstance(config_dict["msg_type"], str), "The 'msg_type' field must be a string!"

if "non_obelisk" in config_dict:
assert isinstance(config_dict["non_obelisk"], str), "The 'non_obelisk' field must be a string!"
assert config_dict["non_obelisk"].lower() != "true", "non_obelisk=True but no message type supplied!"
msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskMsg)
else:
msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskAllowedMsg)
# @staticmethod
# def _get_msg_type_from_config_dict(config_dict: Dict) -> Type:
# """Get the message type from a configuration dictionary."""
# assert config_dict.get("msg_type") is not None, "No message type supplied!"
# assert isinstance(config_dict["msg_type"], str), "The 'msg_type' field must be a string!"

return msg_type
# # TODO: Remove
# if "non_obelisk" in config_dict:
# assert isinstance(config_dict["non_obelisk"], str), "The 'non_obelisk' field must be a string!"
# assert config_dict["non_obelisk"].lower() != "true", "non_obelisk=True but no message type supplied!"
# msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskMsg)
# else:
# msg_type = check_and_get_obelisk_msg_type(config_dict["msg_type"], ObeliskAllowedMsg)

# return msg_type

@staticmethod
def _create_callback_groups_from_config_str(config_str: str) -> Dict[str, CallbackGroup]:
Expand Down Expand Up @@ -364,7 +357,7 @@ def _create_publisher_from_config_str(
# parse and check the configuration string
field_names, value_names = ObeliskNode._parse_config_str(config_str)
required_field_names = ["topic"]
optional_field_names = ["key", "msg_type", "history_depth", "callback_group", "non_obelisk"]
optional_field_names = ["key", "msg_type", "history_depth", "callback_group"]
ObeliskNode._check_fields(field_names, required_field_names, optional_field_names)
config_dict = dict(zip(field_names, value_names))

Expand Down Expand Up @@ -392,16 +385,13 @@ def _create_publisher_from_config_str(

# run type assertions and create the publisher
history_depth = config_dict.get("history_depth", 10)
non_obelisk_field = config_dict.get("non_obelisk", "False")
assert isinstance(config_dict["topic"], str), "The 'topic' field must be a string!"
assert isinstance(history_depth, int), "The 'history_depth' field must be an int!"
assert isinstance(non_obelisk_field, str), "The 'non_obelisk' field must be a str!"
self.obk_publishers[key] = self.create_publisher(
msg_type=msg_type,
topic=config_dict["topic"],
qos_profile=history_depth,
callback_group=callback_group,
non_obelisk=non_obelisk_field.lower() == "true",
)
assert not hasattr(self, key), f"Attribute {key} already exists in the node!"
setattr(self, key + "_key", self.obk_publishers[key]) # create key attribute for publisher
Expand All @@ -410,7 +400,7 @@ def _create_publisher_from_config_str(
def _create_subscription_from_config_str(
self,
config_str: str,
callback: Callable[[Union[ObeliskAllowedMsg, MsgType]], None],
callback: Callable[[MsgType], None],
key: Optional[str] = None,
msg_type: Optional[Type] = None,
) -> Tuple[str, Type]:
Expand Down Expand Up @@ -441,7 +431,7 @@ def _create_subscription_from_config_str(
# parse and check the configuration string
field_names, value_names = ObeliskNode._parse_config_str(config_str)
required_field_names = ["topic"]
optional_field_names = ["key", "msg_type", "history_depth", "callback_group", "non_obelisk"]
optional_field_names = ["key", "msg_type", "history_depth", "callback_group"]
ObeliskNode._check_fields(field_names, required_field_names, optional_field_names)
config_dict = dict(zip(field_names, value_names))

Expand Down Expand Up @@ -469,18 +459,15 @@ def _create_subscription_from_config_str(

# run type assertions and return the subscription
history_depth = config_dict.get("history_depth", 10)
non_obelisk_field = config_dict.get("non_obelisk", "False")
assert isinstance(config_dict["topic"], str), "The 'topic' field must be a string!"
assert isinstance(history_depth, int), "The 'history_depth' field must be an int!"
assert isinstance(non_obelisk_field, str), "The 'non_obelisk' field must be a str!"

self.obk_subscriptions[key] = self.create_subscription(
msg_type=msg_type,
topic=config_dict["topic"],
callback=callback, # type: ignore
qos_profile=history_depth,
callback_group=callback_group,
non_obelisk=non_obelisk_field.lower() == "true",
)
assert not hasattr(self, key), f"Attribute {key} already exists in the node!"
setattr(self, key + "_key", self.obk_subscriptions[key]) # create key attribute for subscription
Expand Down Expand Up @@ -548,105 +535,105 @@ def _create_timer_from_config_str(
# PUB/SUB CREATION #
# ################ #

def create_publisher(
self,
msg_type: ObeliskAllowedMsg,
topic: str,
qos_profile: Union[QoSProfile, int],
*,
callback_group: Optional[CallbackGroup] = None,
event_callbacks: Optional[PublisherEventCallbacks] = None,
qos_overriding_options: Optional[QoSOverridingOptions] = None,
publisher_class: Type[Publisher] = Publisher,
non_obelisk: bool = False,
) -> Publisher:
"""Create a new publisher that can only publish Obelisk messages.

See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1242

Parameters:
non_obelisk: If True, the publisher can publish non-Obelisk messages. Default is False.

Raises:
ObeliskMsgError: If the message type is not an Obelisk message.
"""
if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg):
if get_origin(ObeliskAllowedMsg.__bound__) is Union:
valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)]
else:
valid_msg_types = [ObeliskAllowedMsg.__name__]
raise ObeliskMsgError(
f"msg_type must be one of {valid_msg_types}. "
"Got {msg_type.__name__}. If you are sure that the message type is correct, "
"set non_obelisk=True. Note that this may cause certain API incompatibilies."
)

try:
return super().create_publisher(
msg_type=msg_type,
topic=topic,
qos_profile=qos_profile,
callback_group=callback_group,
event_callbacks=event_callbacks,
qos_overriding_options=qos_overriding_options,
publisher_class=publisher_class,
)
except RCLError as e:
self.get_logger().error(
"Failed to create publisher: verify that you haven't declared the same topic twice!"
)
raise e

def create_subscription(
self,
msg_type: ObeliskAllowedMsg,
topic: str,
callback: Callable[[ObeliskAllowedMsg], None],
qos_profile: Union[QoSProfile, int],
*,
callback_group: Optional[CallbackGroup] = None,
event_callbacks: Optional[SubscriptionEventCallbacks] = None,
qos_overriding_options: Optional[QoSOverridingOptions] = None,
raw: bool = False,
non_obelisk: bool = False,
) -> Subscription:
"""Create a new subscription that can only subscribe to Obelisk messages.

See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1316

Parameters:
non_obelisk: If True, the subscriber can receive non-Obelisk messages. Default is False.

Raises:
ObeliskMsgError: If the message type is not an Obelisk message.
"""
if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg):
if get_origin(ObeliskAllowedMsg.__bound__) is Union:
valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)]
else:
valid_msg_types = [ObeliskAllowedMsg.__name__]
raise ObeliskMsgError(
f"msg_type must be one of {valid_msg_types}. "
"Got {msg_type.__name__}. If you are sure that the message type is correct, "
"set non_obelisk=True. Note that this may cause certain API incompatibilies."
)

try:
return super().create_subscription(
msg_type=msg_type,
topic=topic,
callback=callback,
qos_profile=qos_profile,
callback_group=callback_group,
event_callbacks=event_callbacks,
qos_overriding_options=qos_overriding_options,
raw=raw,
)
except RCLError as e:
self.get_logger().error(
"Failed to create subscription: verify that you haven't declared the same topic twice!"
)
raise e
# def create_publisher(
# self,
# msg_type: ObeliskAllowedMsg,
# topic: str,
# qos_profile: Union[QoSProfile, int],
# *,
# callback_group: Optional[CallbackGroup] = None,
# event_callbacks: Optional[PublisherEventCallbacks] = None,
# qos_overriding_options: Optional[QoSOverridingOptions] = None,
# publisher_class: Type[Publisher] = Publisher,
# non_obelisk: bool = False,
# ) -> Publisher:
# """Create a new publisher that can only publish Obelisk messages.

# See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1242

# Parameters:
# non_obelisk: If True, the publisher can publish non-Obelisk messages. Default is False.

# Raises:
# ObeliskMsgError: If the message type is not an Obelisk message.
# """
# if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg):
# if get_origin(ObeliskAllowedMsg.__bound__) is Union:
# valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)]
# else:
# valid_msg_types = [ObeliskAllowedMsg.__name__]
# raise ObeliskMsgError(
# f"msg_type must be one of {valid_msg_types}. "
# "Got {msg_type.__name__}. If you are sure that the message type is correct, "
# "set non_obelisk=True. Note that this may cause certain API incompatibilies."
# )

# try:
# return super().create_publisher(
# msg_type=msg_type,
# topic=topic,
# qos_profile=qos_profile,
# callback_group=callback_group,
# event_callbacks=event_callbacks,
# qos_overriding_options=qos_overriding_options,
# publisher_class=publisher_class,
# )
# except RCLError as e:
# self.get_logger().error(
# "Failed to create publisher: verify that you haven't declared the same topic twice!"
# )
# raise e

# def create_subscription(
# self,
# msg_type: ObeliskAllowedMsg,
# topic: str,
# callback: Callable[[ObeliskAllowedMsg], None],
# qos_profile: Union[QoSProfile, int],
# *,
# callback_group: Optional[CallbackGroup] = None,
# event_callbacks: Optional[SubscriptionEventCallbacks] = None,
# qos_overriding_options: Optional[QoSOverridingOptions] = None,
# raw: bool = False,
# non_obelisk: bool = False,
# ) -> Subscription:
# """Create a new subscription that can only subscribe to Obelisk messages.

# See: github.com/ros2/rclpy/blob/e4042398d6f0403df2fafdadbdfc90b6f6678d13/rclpy/rclpy/node.py#L1316

# Parameters:
# non_obelisk: If True, the subscriber can receive non-Obelisk messages. Default is False.

# Raises:
# ObeliskMsgError: If the message type is not an Obelisk message.
# """
# if not non_obelisk and not is_in_bound(msg_type, ObeliskAllowedMsg):
# if get_origin(ObeliskAllowedMsg.__bound__) is Union:
# valid_msg_types = [a.__name__ for a in get_args(ObeliskAllowedMsg.__bound__)]
# else:
# valid_msg_types = [ObeliskAllowedMsg.__name__]
# raise ObeliskMsgError(
# f"msg_type must be one of {valid_msg_types}. "
# "Got {msg_type.__name__}. If you are sure that the message type is correct, "
# "set non_obelisk=True. Note that this may cause certain API incompatibilies."
# )

# try:
# return super().create_subscription(
# msg_type=msg_type,
# topic=topic,
# callback=callback,
# qos_profile=qos_profile,
# callback_group=callback_group,
# event_callbacks=event_callbacks,
# qos_overriding_options=qos_overriding_options,
# raw=raw,
# )
# except RCLError as e:
# self.get_logger().error(
# "Failed to create subscription: verify that you haven't declared the same topic twice!"
# )
# raise e

# ################### #
# LIFECYCLE CALLBACKS #
Expand Down
Loading