Source code for axros.axros_tf

from __future__ import annotations

import asyncio
import bisect
import itertools
import math
from typing import TYPE_CHECKING, List

import genpy
import numpy
from geometry_msgs.msg import Pose
from geometry_msgs.msg import Transform as TransformMsg
from geometry_msgs.msg import TransformStamped
from tf import transformations  # XXX
from tf2_msgs.msg import TFMessage

from . import util

if TYPE_CHECKING:
    from .nodehandle import NodeHandle
    from .subscriber import Subscriber


def _sinc(x):
    return numpy.sinc(x / numpy.pi)


def _quat_from_rotvec(r):
    """
    Converts a rotation vector (axis * angle) into a quaternion ``[x, y, z, w]``.

    Args:
        r: Three-element rotation vector; its norm is the rotation angle.

    Returns:
        numpy.ndarray: The four-element quaternion.
    """
    rotvec = numpy.array(r)
    theta = numpy.linalg.norm(rotvec)
    # sin(theta/2) * rotvec/|rotvec| rewritten as sinc(theta/2)/2 * rotvec,
    # which stays finite when theta == 0.
    xyz = _sinc(theta / 2) / 2 * rotvec
    w = math.cos(theta / 2)
    return numpy.concatenate([xyz, [w]])


def _rotvec_from_quat(q):
    """
    Converts a quaternion ``[x, y, z, w]`` into a rotation vector (axis * angle).

    The quaternion is normalized first. ``q`` and ``-q`` describe the same
    rotation, so the sign is chosen to make ``w`` non-negative.

    Returns:
        numpy.ndarray: Three-element rotation vector.
    """
    unit = transformations.unit_vector(q)
    if unit[3] < 0:
        unit = -unit
    # Clamp guards acos against w drifting slightly above 1 numerically.
    half_angle = math.acos(min(1, unit[3]))
    # 2 * half_angle / sin(half_angle) * xyz == angle * axis, finite at 0.
    return 2 / _sinc(half_angle) * unit[:3]


class Transform:
    """
    Represents a tf transform in the axros suite: a translation ``p``
    (``[x, y, z]``) and a rotation quaternion ``q`` (``[x, y, z, w]``).

    .. container:: operations

        .. describe:: x + y

            Applies a 6-element delta ``y`` to ``x`` and returns a new
            :class:`Transform`. ``y[:3]`` is added to the translation, and
            ``y[3:]`` is a rotation vector whose rotation is pre-multiplied
            onto ``x``'s quaternion. ``y`` is typically the result of a
            subtraction: ``b + (a - b)`` yields the same pose as ``a``.

        .. describe:: x - y

            Returns a 6-element :class:`numpy.ndarray`: the translation
            difference followed by the rotation vector of
            ``x.q * inverse(y.q)``.

        .. describe:: x * y

            Composes the transforms. Applying the result is equivalent to
            applying ``y`` first and then ``x``.

        .. describe:: str(x)

            Prints the constructor of the class in a string-formatted fashion.

        .. describe:: repr(x)

            Prints the constructor of the class in a string-formatted fashion.
            Equivalent to ``str(x)``.
    """

    @classmethod
    def identity(cls):
        """
        Creates an identity transform.

        The x, y, and z coordinates are all set to zero. The quaternion is set
        to ``[0, 0, 0, 1]``.
        """
        return cls([0, 0, 0], [0, 0, 0, 1])

    @classmethod
    def from_Transform_message(cls, msg: TransformMsg):
        """
        Constructs a transform from a :class:`geometry_msgs.msg.Transform` message.
        """
        return cls(
            [msg.translation.x, msg.translation.y, msg.translation.z],
            [msg.rotation.x, msg.rotation.y, msg.rotation.z, msg.rotation.w],
        )

    @classmethod
    def from_Pose_message(cls, msg: Pose):
        """
        Constructs a transform from a :class:`geometry_msgs.msg.Pose` message.
        """
        return cls(
            [msg.position.x, msg.position.y, msg.position.z],
            [
                msg.orientation.x,
                msg.orientation.y,
                msg.orientation.z,
                msg.orientation.w,
            ],
        )

    def __init__(self, p: list[float], q: list[float]):
        """
        Args:
            p: Translation as ``[x, y, z]``.
            q: Rotation quaternion as ``[x, y, z, w]``.
        """
        self._p = numpy.array(p)
        self._q = numpy.array(q)
        # Cached 3x3 rotation matrix, reused by ``*``, inverse() and the
        # transform_* helpers.
        self._q_mat = transformations.quaternion_matrix(self._q)[:3, :3]

    def __sub__(self, other: Transform) -> numpy.ndarray:
        # [translation difference, rotation vector of self.q * other.q^-1]
        return numpy.concatenate(
            [
                self._p - other._p,
                _rotvec_from_quat(
                    transformations.quaternion_multiply(
                        self._q,
                        transformations.quaternion_inverse(other._q),
                    )
                ),
            ]
        )

    def __add__(self, other: numpy.ndarray) -> Transform:
        # ``other`` is a 6-element delta (as returned by ``__sub__``), not a
        # Transform: [dx, dy, dz, rotation vector].
        return Transform(
            self._p + other[:3],
            transformations.quaternion_multiply(
                _quat_from_rotvec(other[3:]),
                self._q,
            ),
        )

    def __mul__(self, other: Transform) -> Transform:
        # Composition: rotate other's translation into this frame, then offset.
        return Transform(
            self._p + self._q_mat.dot(other._p),
            transformations.quaternion_multiply(self._q, other._q),
        )

    def inverse(self) -> Transform:
        """
        Constructs an inverse transform.

        Returns:
            Transform: The inverse transform.
        """
        return Transform(
            -self._q_mat.T.dot(self._p),
            transformations.quaternion_inverse(self._q),
        )

    def as_matrix(self):
        """
        Returns the 4x4 homogeneous matrix: translation times rotation.
        """
        return transformations.translation_matrix(self._p).dot(
            transformations.quaternion_matrix(self._q)
        )

    def __str__(self):
        return f"Transform({self._p!r}, {self._q!r})"

    __repr__ = __str__

    def transform_point(self, point):
        """
        Rotates ``point`` and then adds the translation.
        """
        return self._p + self._q_mat.dot(point)

    def transform_vector(self, vector):
        """
        Rotates ``vector`` only; translation does not apply to directions.
        """
        return self._q_mat.dot(vector)

    def transform_quaternion(self, quaternion):
        """
        Pre-multiplies ``quaternion`` by this transform's rotation.
        """
        return transformations.quaternion_multiply(self._q, quaternion)
def _make_absolute(frame_id: str) -> str: if not frame_id.startswith("/"): return "/" + frame_id return frame_id
class TransformListener:
    """
    Listens to the ``/tf`` topic and answers transform queries between frames.

    Received transforms are buffered per child frame as time-sorted samples.
    Old samples are pruned lazily based on ``history_length``, so that
    :meth:`get_transform` can interpolate between samples at a given time.
    """

    _node_handle: NodeHandle
    _history_length: genpy.Duration
    _id_counter: itertools.count
    _tfs: dict[str, list[tuple[genpy.Time, str, Transform]]]
    _futs: dict[str, dict[int, asyncio.Future]]
    _tf_subscriber: Subscriber

    def __init__(
        self,
        node_handle: NodeHandle,
        history_length: genpy.Duration = genpy.Duration(10),
    ):
        """
        Args:
            node_handle (NodeHandle): The node handle used to subscribe to ``/tf``.
            history_length (genpy.Duration): How much message time of samples
                to keep per frame.
        """
        self._node_handle = node_handle
        self._history_length = history_length
        self._id_counter = itertools.count()
        # child_frame_id -> list of (stamp, parent frame_id, Transform), sorted by stamp
        self._tfs = {}
        # frame id -> {waiter id: future resolved when a new tf for that frame arrives}
        self._futs = {}
        self._tf_subscriber = self._node_handle.subscribe(
            "/tf", TFMessage, self._got_tfs
        )

    async def setup(self) -> None:
        """
        Awaits setup of the underlying ``/tf`` subscriber.
        """
        await self._tf_subscriber.setup()

    async def shutdown(self) -> None:
        """
        Shuts the transform listener down.
        """
        await self._tf_subscriber.shutdown()

    def _got_tfs(self, msg: TFMessage) -> None:
        """
        Subscriber callback for ``/tf``.

        Appends each transform to the history of its child frame, prunes old
        history, then resolves the futures of any :meth:`get_transform` call
        waiting on a frame mentioned in ``msg``.
        """
        for transform in msg.transforms:
            frame_id = _make_absolute(transform.header.frame_id)
            child_frame_id = _make_absolute(transform.child_frame_id)
            stamp = transform.header.stamp

            history = self._tfs.setdefault(child_frame_id, [])
            # The buffer must stay sorted by stamp; a sample older than the
            # newest one invalidates everything buffered so far.
            if history and stamp < history[-1][0]:
                del history[:]
            history.append(
                (stamp, frame_id, Transform.from_Transform_message(transform.transform))
            )

            # Prune lazily: only once the buffer spans 2 * history_length,
            # drop samples older than history_length. The newest sample is
            # always kept, so a zero history_length cannot run off the end.
            if history[0][0] + self._history_length * 2 <= stamp:
                pos = 0
                last = len(history) - 1
                while pos < last and history[pos][0] + self._history_length <= stamp:
                    pos += 1
                del history[:pos]

        # Waiters are keyed by the frame whose lookup in ``self._tfs`` failed,
        # i.e. a *child* frame, so they must be woken on ``child_frame_id``.
        # Parent frames are woken as well; a spurious wakeup only costs a retry.
        for transform in msg.transforms:
            for frame in (
                _make_absolute(transform.child_frame_id),
                _make_absolute(transform.header.frame_id),
            ):
                for fut in self._futs.pop(frame, {}).values():
                    if not fut.done():
                        fut.set_result(None)

    def _wait_for_new(self, child_frame_id: str) -> tuple[asyncio.Future, int]:
        """
        Registers a future resolved when a new tf mentioning ``child_frame_id``
        arrives. Returns the future and its id (needed for cleanup).
        """
        id_ = next(self._id_counter)
        fut = asyncio.get_running_loop().create_future()
        self._futs.setdefault(child_frame_id, {})[id_] = fut
        return fut, id_

    def _discard_waiter(self, frame_id: str, id_: int) -> None:
        """
        Removes waiter ``id_`` for ``frame_id`` if still registered.

        The id may already be gone when ``_got_tfs`` popped the frame's dict
        and another call re-registered the frame since.
        """
        waiters = self._futs.get(frame_id)
        if waiters is None:
            return
        waiters.pop(id_, None)
        if not waiters:
            del self._futs[frame_id]

    def _climb(
        self,
        tfs: dict[str, Transform],
        pos: str,
        other_tfs: dict[str, Transform],
        time: genpy.Time | None,
    ) -> tuple[str, str | None]:
        """
        Walks up the tree from ``pos`` as far as buffered data allows.

        Records in ``tfs`` the transform from the chain's start frame to each
        visited ancestor. Returns ``(last reached frame, meeting frame)``,
        where the meeting frame is the first ancestor also present in
        ``other_tfs``, or ``None`` if the chains have not met yet.

        Raises:
            TooPastError: ``time`` precedes the buffered history of a link.
        """
        while True:
            try:
                parent, t = self._interpolate(self._tfs.get(pos, []), time)
            except _TooFutureError:
                return pos, None
            assert parent not in tfs  # a revisit means the tf graph has a cycle
            tfs[parent] = t * tfs[pos]
            if parent in other_tfs:
                return parent, parent
            pos = parent

    async def get_transform(
        self, to_frame: str, from_frame: str, time: genpy.Time | None = None
    ) -> Transform:
        """
        Gets the transform from ``from_frame`` to ``to_frame``.

        Following ROS tf conventions, the result maps points expressed in
        ``from_frame`` into ``to_frame``. If the frames are not (yet)
        connected by buffered data, this waits for new ``/tf`` messages and
        retries. It may wait indefinitely if the frames never become connected.

        Args:
            to_frame (str): Target frame; a leading ``/`` is added if missing.
            from_frame (str): Source frame; a leading ``/`` is added if missing.
            time (genpy.Time | None): Time at which to evaluate the transform.
                If ``None``, the newest sample of each link is used.

        Raises:
            TooPastError: ``time`` is older than the oldest buffered sample of
                a link on the path, so the transform can never be computed.

        Returns:
            Transform: The transform from ``from_frame`` to ``to_frame``.
        """
        to_frame = _make_absolute(to_frame)
        from_frame = _make_absolute(from_frame)
        assert time is None or isinstance(time, genpy.Time)

        # Without this, a root frame queried against itself would wait forever.
        if to_frame == from_frame:
            return Transform.identity()

        to_tfs = {to_frame: Transform.identity()}  # x -> Transform from to_frame to x
        to_pos = to_frame
        from_tfs = {from_frame: Transform.identity()}  # x -> Transform from from_frame to x
        from_pos = from_frame

        while True:
            to_pos, common = self._climb(to_tfs, to_pos, from_tfs, time)
            if common is not None:
                return to_tfs[common].inverse() * from_tfs[common]
            from_pos, common = self._climb(from_tfs, from_pos, to_tfs, time)
            if common is not None:
                return to_tfs[common].inverse() * from_tfs[common]

            # Both chains are stuck: wait for new data on either frontier frame.
            to_fut, to_id = self._wait_for_new(to_pos)
            from_fut, from_id = self._wait_for_new(from_pos)
            try:
                await asyncio.wait(
                    [to_fut, from_fut], return_when=asyncio.FIRST_COMPLETED
                )
            finally:
                for frame, id_, fut in (
                    (to_pos, to_id, to_fut),
                    (from_pos, from_id, from_fut),
                ):
                    fut.cancel()
                    self._discard_waiter(frame, id_)

    def _interpolate(
        self,
        sorted_list: list[tuple[genpy.Time, str, Transform]],
        time: genpy.Time | None,
    ) -> tuple[str, Transform]:
        """
        Evaluates one link at ``time``.

        Returns ``(parent frame id, Transform)``. Between two samples the
        translation and rotation vector are interpolated linearly. With
        ``time=None`` the newest sample is returned.

        Raises:
            _TooFutureError: No sample exists, or ``time`` is newer than all samples.
            TooPastError: ``time`` is older than the oldest sample.
        """
        if time is None:
            if sorted_list:
                return sorted_list[-1][1], sorted_list[-1][2]
            raise _TooFutureError()
        if not sorted_list or time > sorted_list[-1][0]:
            raise _TooFutureError()
        if time < sorted_list[0][0]:
            raise TooPastError()

        # (time,) sorts before any (time, ...) sample, so pos is the first
        # sample with stamp >= time.
        pos = bisect.bisect_left(sorted_list, (time,))
        right = sorted_list[pos]
        left = sorted_list[max(pos - 1, 0)]
        assert left[0] <= time <= right[0]

        span = (right[0] - left[0]).to_sec()
        if span == 0:
            # time coincides with the first sample (left is right): nothing to
            # interpolate, and dividing would raise ZeroDivisionError.
            return right[1], right[2]
        x = (time - left[0]).to_sec() / span
        return left[1], left[2] + x * (right[2] - left[2])
class _TooFutureError(Exception): """This is an internal exception; it should never escape to the user"""
class TooPastError(Exception):
    """
    Raised when a transform is requested for a time earlier than the oldest
    sample kept in the history buffer, so it can never become known.

    Inherits from :class:`Exception`.
    """
class TransformBroadcaster:
    """
    Broadcasts transforms onto the ``/tf`` topic.
    """

    def __init__(self, node_handle: NodeHandle):
        """
        Args:
            node_handle (NodeHandle): The node handle used to communicate with
                the ROS master server.
        """
        self._node_handle = node_handle
        self._tf_publisher = node_handle.advertise("/tf", TFMessage)

    async def setup(self) -> None:
        """
        Awaits setup of the underlying ``/tf`` publisher.
        """
        await self._tf_publisher.setup()

    def send_transform(self, transform: TransformStamped):
        """
        Publishes a single stamped transform, wrapped in a ``TFMessage``.

        Args:
            transform (TransformStamped): The transform to publish onto the topic.

        Raises:
            TypeError: ``transform`` is not a ``TransformStamped`` instance.
        """
        if isinstance(transform, TransformStamped):
            message = TFMessage(transforms=[transform])
            self._tf_publisher.publish(message)
            return
        raise TypeError("expected TransformStamped")