Source code for lqrrt.tree

Class for an lqrrt tree.
Each node in the tree has:

- ID: integer (never specified by user)
- pID: parent ID integer

- state: array of state values
- lqr: tuple of local cost-to-go and policy arrays (S, K)

- x_seq: list of state arrays from parent state to node state
- u_seq: list of effort arrays from parent state to node state

The ID corresponds to when the node was added (e.g. the third node
added after the seed has an ID of 3). The seed node always has ID 0 and pID -1.
The ID of the last node added to the tree is always mytree.size-1,
where mytree.size is the number of nodes in the tree.

To get the value of some feature of a node, index that feature by the node's ID, i.e. mytree.feature[ID].
For example, mytree.state[3] will return the state of node ID3 in
mytree. Or, mytree.x_seq[6][3] will return the fourth state vector on
the edge connecting node ID6 to its parent (starting from the parent).


################################################# DEPENDENCIES

from __future__ import division
import numpy as np

################################################# PRIMARY CLASS

class Tree:
    """
    Tree of nodes stored as parallel, ID-indexed feature containers.

    Node ID i owns row i of the `state` array and element i of each of the
    `pID`, `lqr`, `x_seq` and `u_seq` lists. The seed node has ID 0 and
    pID -1, and `size` is the number of nodes currently stored.

    To initialize, provide...

    seed_state: An array of the state of the seed node.

    seed_lqr: The local LQR-optimal cost-to-go array S and policy array K
              as a tuple (S, K). That is, S solves the local Riccati
              equation and K = (R^-1)*(B^T)*(S) for effort jacobian B.

    If the number of controls cannot be read from seed_lqr, a warning is
    printed and ncontrols falls back to 1.

    """
    def __init__(self, seed_state, seed_lqr):
        # Store number of states
        self.nstates = len(seed_state)

        # Store number of controls, taken as the first dimension of K.
        # Only the failures this lookup can produce are treated as an
        # "inconsistent seed_lqr"; anything else (e.g. KeyboardInterrupt)
        # propagates instead of being silently swallowed.
        try:
            self.ncontrols = seed_lqr[1].shape[0]
        except (AttributeError, IndexError, KeyError, TypeError):
            print("\nThe given seed_lqr is not consistent.")
            print(
                "Continuing, assuming you don't care about the lqr or effort features...\n"
            )
            self.ncontrols = 1

        # Initialize state array with one row per node
        self.state = np.array(seed_state, dtype=np.float64).reshape(1, self.nstates)

        # Initialize all other feature lists (index == node ID)
        self.pID = [-1]
        self.lqr = [seed_lqr]
        self.x_seq = [[seed_state]]
        self.u_seq = [[np.zeros(self.ncontrols)]]

        # Initialize number of nodes
        self.size = 1

    #################################################

    def add_node(self, pID, state, lqr, x_seq, u_seq):
        """
        Adds a node to the tree with the given features.
        The new node receives the ID equal to the tree size before the call.

        pID: ID of the parent node, which must already be in the tree.
        state: State of the new node, appended as a new row of self.state.
        lqr: Local (S, K) tuple for the new node.
        x_seq: List of state arrays from parent state to node state.
        u_seq: List of effort arrays from parent state to node state.

        Raises ValueError if pID does not refer to an existing node.

        """
        # Make sure the desired parent exists
        if pID >= self.size or pID < 0:
            raise ValueError("The given parent ID, {}, doesn't exist.".format(pID))

        # Update state array (a new array is built with the extra row)
        self.state = np.vstack((self.state, state))

        # Append all other feature lists
        self.pID.append(pID)
        self.lqr.append(lqr)
        self.x_seq.append(x_seq)
        self.u_seq.append(u_seq)

        # Increment node count
        self.size += 1

    def climb(self, ID):
        """
        Returns a list of node IDs that connect the seed to the node with
        the given ID. The first element in the list is always 0 (seed) and
        the last is always ID (the given climb destination).

        Raises ValueError if ID does not refer to an existing node.

        """
        # Make sure the desired end node exists
        if ID >= self.size or ID < 0:
            raise ValueError("The given ID, {}, doesn't exist.".format(ID))

        # Follow parents backward until the seed's pID of -1 is reached
        IDs = []
        while ID != -1:
            IDs.append(ID)
            ID = self.pID[ID]

        # Collected destination-first, so reverse to get seed-first order
        return IDs[::-1]

    def trajectory(self, IDs):
        """
        Given a list of node IDs, the full sequence of states and efforts
        to go from IDs[0] to IDs[-1] are returned as a tuple
        (x_seq_full, u_seq_full).

        The stored x_seq and u_seq of every listed node, including IDs[0],
        are concatenated in the order given.

        """
        x_seq_full = []
        u_seq_full = []
        for ID in IDs:
            x_seq_full.extend(self.x_seq[ID])
            u_seq_full.extend(self.u_seq[ID])
        return (x_seq_full, u_seq_full)

    def visualize(self, dx, dy, node_seq=None):
        """
        Plots the (dx,dy)-cross-section of the current tree, and highlights
        the path given by the list, node_seq. For example, dx=0, dy=1 plots
        the states #0 and #1.

        Edges of nodes listed in node_seq are drawn in red above the grey
        edges of all other nodes. The seed state is marked in blue and the
        state of node_seq[-1], if any, is marked in red.

        """
        print("\nplotting...")

        # Imported here so matplotlib is only needed when plotting
        from matplotlib import pyplot as plt

        fig = plt.figure()
        fig.suptitle("Tree")
        ax = fig.add_subplot(1, 1, 1)
        # Request equal scaling on the axes that is actually drawn on
        ax.axis("equal")
        ax.set_xlabel("- State {} +".format(dx))
        ax.set_ylabel("- State {} +".format(dy))
        ax.grid(True)

        if node_seq is None:
            node_seq = []

        # Build the lookup once so each membership test below is O(1)
        highlighted = set(node_seq)

        if self.size > 1:
            for ID in range(self.size):
                x_seq = np.array(self.x_seq[ID])
                if ID in highlighted:
                    ax.plot((x_seq[:, dx]), (x_seq[:, dy]), color="r", zorder=2)
                else:
                    ax.plot((x_seq[:, dx]), (x_seq[:, dy]), color="0.75", zorder=1)

        # Mark the seed node
        ax.scatter(self.state[0, dx], self.state[0, dy], color="b", s=48)

        # Mark the final node of the highlighted path
        # (len() is used because node_seq may be an array, not just a list)
        if len(node_seq):
            ax.scatter(
                self.state[node_seq[-1], dx],
                self.state[node_seq[-1], dy],
                color="r",
                s=48,
            )

        # NOTE(review): the message below suggests a blocking window, but no
        # plt.show() call is visible in this method; confirm that the caller
        # is responsible for displaying the figure.
        print("Done! Close window to continue.\n")