Cart Pole
In this fourth tutorial, we’ll use machine learning (specifically reinforcement learning) to train a controller for a simulated cart pole. The pole is attached to the cart by an unactuated joint. The cart is controlled by a linear actuator that drives it left or right. The goal is to swing the pole up and balance it above the cart using motor control. Start by downloading the CAD geometry.
Solution
The completed model can be downloaded.
RL Problem
The reinforcement learning (RL) problem describes how an intelligent agent can take actions in an environment (a simulated world) to maximize its cumulative reward. The cart pole problem is described in the paper "Neuronlike adaptive elements that can solve difficult learning control problems" by Andrew Barto, Richard Sutton and Charles Anderson.
Markov Decision Process
The RL problem can be formalised as a Markov Decision Process (MDP). An MDP is a mathematical framework used to describe discrete-time stochastic systems in which we would like to model decision making. At each time step, the MDP is in some state s and the agent can select an action a. At the next time step, the MDP responds by moving into a new state and giving the agent a reward R(s, a).
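The MDP can be formally defined as a 5-tuple:

$$(S, A, P, R, \gamma)$$

where S is the set of states, A is the set of actions, P(s′ | s, a) is the state transition function giving the probability of moving to state s′ after taking action a in state s, R(s, a) is the reward function, and γ ∈ [0, 1] is the discount factor.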
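To make the 5-tuple concrete, here is a minimal sketch of a toy MDP in Python. The two states ("down", "up"), the two actions ("swing", "hold") and the rewards are hypothetical and chosen only for illustration. The transition function is deterministic, like the stepped simulation used later in this tutorial; a general MDP would instead store a probability for each possible next state.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

State = str
Action = str


@dataclass(frozen=True)
class DeterministicMDP:
    """A toy MDP (S, A, P, R, gamma) whose transition function P is deterministic."""
    states: Tuple[State, ...]
    actions: Tuple[Action, ...]
    transition: Dict[Tuple[State, Action], State]  # P: (s, a) -> next state s'
    reward: Dict[Tuple[State, Action], float]  # R: (s, a) -> reward
    gamma: float  # discount factor


def rollout(mdp: DeterministicMDP, policy: Dict[State, Action], start: State,
            steps: int) -> List[Tuple[State, Action, float]]:
    """Follow a deterministic policy and record (state, action, reward) at each time step."""
    s = start
    trajectory = []
    for _ in range(steps):
        a = policy[s]  # the agent selects an action
        trajectory.append((s, a, mdp.reward[(s, a)]))  # the MDP gives a reward R(s, a)
        s = mdp.transition[(s, a)]  # ... and moves into a new state
    return trajectory


# Hypothetical toy problem: reward is earned only while holding the pole up.
toy = DeterministicMDP(
    states=("down", "up"),
    actions=("swing", "hold"),
    transition={("down", "swing"): "up", ("down", "hold"): "down",
                ("up", "swing"): "down", ("up", "hold"): "up"},
    reward={("down", "swing"): 0.0, ("down", "hold"): 0.0,
            ("up", "swing"): 0.0, ("up", "hold"): 1.0},
    gamma=0.99,
)

print(rollout(toy, {"down": "swing", "up": "hold"}, start="down", steps=3))
# [('down', 'swing', 0.0), ('up', 'hold', 1.0), ('up', 'hold', 1.0)]
```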
The term Markov refers to the fact that the system obeys the Markov property: transitions depend only on the most recent state and action, not on any prior history of states and actions. The discount factor γ is used to trade off short-term rewards against long-term future rewards.
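The effect of the discount factor can be seen by computing the discounted return, the sum of γ^t · r_t over a reward sequence. The helper below is a minimal sketch, and the reward sequences are made up for illustration.

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """Sum of gamma**t * r_t over the reward sequence, starting at t = 0."""
    total = 0.0
    for t, r in enumerate(rewards):
        total += (gamma ** t) * r
    return total


# The same reward of 10 is worth less the later it arrives.
print(discounted_return([10.0, 0.0, 0.0], gamma=0.9))  # 10.0
print(discounted_return([0.0, 0.0, 10.0], gamma=0.9))  # about 8.1
```

A γ close to 1 (such as the 0.99 used later in this tutorial) makes the agent far-sighted: a reward 100 steps ahead is still weighted by 0.99^100 ≈ 0.37.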
Optimization Objective
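The goal is to find the optimal policy 𝜋* for the agent, which maximizes the expected discounted cumulative reward. It has the following objective function:

$$\pi^* = \arg\max_{\pi} \; \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^{t} R(s_t, a_t)\right]$$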
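For a tiny MDP with a handful of states and actions, this objective can be maximized directly by enumerating every deterministic policy and truncating the infinite sum at a long horizon. The sketch below uses the same kind of hypothetical "down"/"up" toy problem as an illustration. Exhaustive search is only feasible because there are just 2^2 = 4 policies; with continuous states and actions like the cart pole's, enumeration is impossible, which is where a learning algorithm such as PPO comes in.

```python
import itertools
from typing import Dict

STATES = ("down", "up")
ACTIONS = ("swing", "hold")
# Hypothetical deterministic toy problem: P maps (s, a) to s', R maps (s, a) to a reward
P = {("down", "swing"): "up", ("down", "hold"): "down",
     ("up", "swing"): "down", ("up", "hold"): "up"}
R = {("down", "swing"): 0.0, ("down", "hold"): 0.0,
     ("up", "swing"): 0.0, ("up", "hold"): 1.0}


def policy_value(policy: Dict[str, str], start: str, gamma: float, horizon: int) -> float:
    """Discounted return of a deterministic policy, with the infinite sum truncated at `horizon`."""
    s, total = start, 0.0
    for t in range(horizon):
        a = policy[s]
        total += gamma ** t * R[(s, a)]
        s = P[(s, a)]
    return total


def best_policy(start: str = "down", gamma: float = 0.99, horizon: int = 500) -> Dict[str, str]:
    """Exhaustively search all |A|^|S| deterministic policies for the highest value."""
    candidates = [dict(zip(STATES, choice))
                  for choice in itertools.product(ACTIONS, repeat=len(STATES))]
    return max(candidates, key=lambda pi: policy_value(pi, start, gamma, horizon))


print(best_policy())  # {'down': 'swing', 'up': 'hold'}
```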
ProtoTwin Model
In this tutorial, we train the cart pole inside ProtoTwin Connect. The state transition function P is deterministic and is modeled by stepping the simulation. The reward function R is deterministic. The policy function 𝜋 is deterministic (a = 𝜋(s)). The discount factor is γ = 0.99, and the episode is finite-horizon, with the total number of time steps T corresponding to a time limit of 10 seconds (1000 steps with the script's time step of 0.01 s). At each time step, the agent sees only a partial observation of the state of the virtual world. The goal is to maximize the cumulative reward:

$$\max_{\pi} \sum_{t=0}^{T-1} \gamma^{t} R(s_t, a_t), \qquad a_t = \pi(s_t)$$
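As a rough sense of scale, assume every step earned the largest per-step reward that the reward function defined later in this tutorial can produce, which equals its time step dt = 0.01 (upright pole, centered cart, zero force). With T = 1000 steps, the objective above is then a finite geometric series. The sketch below compares a direct sum with the closed form r_max · (1 − γ^T) / (1 − γ); the "every step is perfect" episode is a hypothetical upper bound, not a training result.

```python
R_MAX = 0.01  # largest per-step reward: (0.8 + 0.2) * dt with dt = 0.01
GAMMA = 0.99
T = 1000  # 10 s episode / 0.01 s time step


def return_loop(r_max: float, gamma: float, horizon: int) -> float:
    """Discounted return when every one of `horizon` steps earns r_max."""
    return sum(r_max * gamma ** t for t in range(horizon))


def return_closed_form(r_max: float, gamma: float, horizon: int) -> float:
    """Same quantity via the finite geometric series formula (requires gamma != 1)."""
    return r_max * (1 - gamma ** horizon) / (1 - gamma)


print(return_loop(R_MAX, GAMMA, T))         # just under 1.0
print(return_closed_form(R_MAX, GAMMA, T))  # agrees with the loop
print(return_loop(R_MAX, 1.0, T))           # undiscounted: about 10.0
```

Scaling each step's reward by dt makes the undiscounted return roughly equal to the number of seconds the pole is kept upright with the cart centered, which makes the episode rewards reported during training easier to interpret.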
Action Space
The action space is an ndarray with shape (1,) containing:
NUM | ACTION | MIN | MAX |
---|---|---|---|
0 | Cart Target Velocity | -1 | 1 |
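The policy's Gaussian output is unbounded, so a sampled action can fall outside [-1, 1]. During training, Stable-Baselines3 clips actions for Box spaces (without output squashing) to the space's bounds before passing them to the environment, but the policy's forward pass, which is what gets exported to ONNX later, does not clip. Deployment code therefore has to bound the action itself; the -1 and 1 arrays passed to loadModelFromFile in the inference component are presumably used for this. A minimal plain-Python sketch of the clipping, with the bounds taken from the table above:

```python
from typing import List, Sequence

ACTION_LOW = -1.0  # minimum cart target velocity (from the action space table)
ACTION_HIGH = 1.0  # maximum cart target velocity


def clip_action(raw_action: Sequence[float]) -> List[float]:
    """Clamp each component of a raw policy output into [ACTION_LOW, ACTION_HIGH]."""
    return [min(max(x, ACTION_LOW), ACTION_HIGH) for x in raw_action]


print(clip_action([1.7]))    # [1.0]
print(clip_action([-0.25]))  # [-0.25]
```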
Observation Space
The observation space is an ndarray with shape (4,) containing:
NUM | OBSERVATION | MIN | MAX |
---|---|---|---|
0 | Cart Position | -1 | 1 |
1 | Pole Angle | -1 | 1 |
2 | Cart Velocity (m/s) | -inf | inf |
3 | Pole Angular Velocity (rad/s) | -inf | inf |
The observation space is a subset of the state space.
- Cart Position is a measure of the cart’s distance from the center, where 0 is at the center and +/-1 is at the limit.
- Pole Angle is a measure of the pole’s angular distance from the upright position, where 0 is at the upright position and +/-1 is at the down position.
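A plain-Python sketch of how the training script turns raw signal values into these normalized observations. It mirrors the observations() method of the script below and, like the script, assumes that a pole motor position of π radians is upright and that the cart's limit is 0.65 m from the center. The function name is hypothetical.

```python
import math
from typing import List

X_THRESHOLD = 0.65  # distance (m) from the center at which the cart reaches its limit


def normalize_observation(cart_position: float, cart_velocity: float,
                          pole_angle: float, pole_angular_velocity: float) -> List[float]:
    """Build the 4-element observation in the same order and scale as the training script."""
    # Angular distance from upright; a pole motor position of pi rad is treated as upright
    pole_angular_distance = math.atan2(math.sin(pole_angle), math.cos(math.pi - pole_angle))
    return [cart_position / X_THRESHOLD,      # 0: cart position, +/-1 at the limits
            pole_angular_distance / math.pi,  # 1: pole angle, 0 upright, +/-1 hanging down
            cart_velocity,                    # 2: cart velocity (m/s), unscaled
            pole_angular_velocity]            # 3: pole angular velocity (rad/s), unscaled


print(normalize_observation(0.325, 0.0, 0.0, 0.0))  # [0.5, 1.0, 0.0, 0.0]: halfway to the limit, pole down
print(normalize_observation(0.0, 0.0, math.pi, 0.0)[1])  # ~0 (upright, up to rounding)
```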
Reward Function
Since we want to balance the pole upright for as long as possible, while keeping the cart near the center and avoiding excessive motor force, our reward function is defined as:
```python
def reward(self, obs):
    distance = 1 - math.fabs(obs[0])  # How close the cart is to the center
    angle = 1 - math.fabs(obs[1])  # How close the pole is to the upright position
    force = math.fabs(self.get(address_cart_force))  # How much force is being applied to drive the cart's motor
    reward = angle * 0.8 + distance * 0.2 - force * 0.004
    return max(reward * self.dt, 0)
```
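To see how the terms trade off, here is a stand-alone copy of the reward with the cart force passed in as an argument instead of being read from a signal. The function name step_reward is hypothetical, and the force values are arbitrary examples.

```python
import math

DT = 0.01  # time step used by the training script


def step_reward(cart_position_obs: float, pole_angle_obs: float, cart_force: float,
                dt: float = DT) -> float:
    """Same formula as reward() above, but with the cart force supplied directly."""
    distance = 1 - math.fabs(cart_position_obs)  # 1 at the center, 0 at the limit
    angle = 1 - math.fabs(pole_angle_obs)  # 1 when upright, 0 when hanging down
    force = math.fabs(cart_force)
    reward = angle * 0.8 + distance * 0.2 - force * 0.004
    return max(reward * dt, 0)


print(step_reward(0.0, 0.0, 0.0))    # ideal state: about 0.01, the maximum per step
print(step_reward(0.0, 1.0, 0.0))    # hanging down, centered: about 0.002
print(step_reward(0.0, 0.0, 50.0))   # upright but pushing hard: about 0.008
print(step_reward(0.0, 0.0, 500.0))  # force penalty exceeds the bonus: clipped to 0
```

The pole angle dominates (80% of the weight), the centering term breaks ties, and the small force penalty discourages needlessly aggressive motor commands. Because of the final max(..., 0), the agent is never given a negative reward, only a smaller one.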
Episode End
The episode ends if any one of the following conditions is met:
- Termination: The cart goes beyond the limits [-1, 1]. These limits correspond to a distance of ±0.65 m from the center.
- Truncation: The episode time is greater than 10 seconds.
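Gymnasium separates these two cases so that a learning algorithm can tell a genuine failure (termination) apart from an episode that was merely cut off by the time limit (truncation). A minimal sketch of the two checks, written in terms of raw values rather than the normalized observation (the function name and its arguments are hypothetical):

```python
from typing import Tuple

X_THRESHOLD = 0.65  # metres from the center at which the cart is out of bounds
TIME_LIMIT = 10.0  # seconds per episode


def episode_flags(cart_position: float, time: float) -> Tuple[bool, bool]:
    """Return (terminated, truncated) for a cart position in metres and an elapsed time in seconds."""
    terminated = abs(cart_position / X_THRESHOLD) > 1  # cart left the [-1, 1] normalized range
    truncated = time > TIME_LIMIT  # episode ran out of time
    return terminated, truncated


print(episode_flags(0.70, 3.0))   # (True, False)
print(episode_flags(0.20, 10.5))  # (False, True)
```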
Algorithm
To solve the cart pole problem, we will use an algorithm called Proximal Policy Optimization (PPO). PPO is a policy gradient method for training the agent’s policy neural network. The learned policy 𝜋 is a Multi-Layer Perceptron (MLP) that takes an observation as input and outputs a probability distribution over the actions. PPO is an actor-critic algorithm, meaning it uses MLPs to learn both the optimal policy function (actor network) and the value function (critic network). The action to take is the one with the highest probability:

$$a = \arg\max_{a'} \pi(a' \mid s)$$

For the continuous cart velocity action used here, the distribution is a Gaussian, so this is simply its mean.
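PPO's central idea is to limit how far each update can move the policy. It maximizes the clipped surrogate objective, the mean over samples of min(r·A, clip(r, 1 − ε, 1 + ε)·A), where r is the ratio of the new to the old policy's probability of the action taken and A is an advantage estimate derived from the critic. The plain-Python sketch below evaluates this objective for a small batch. The ratios and advantages are made-up numbers, and ε = 0.2 matches Stable-Baselines3's default clip_range. It illustrates the formula only; Stable-Baselines3 works on tensors and minimizes the negative of this quantity together with value and entropy terms.

```python
from typing import Sequence


def clipped_surrogate(ratios: Sequence[float], advantages: Sequence[float],
                      clip_range: float = 0.2) -> float:
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A) over a batch of samples."""
    terms = []
    for ratio, advantage in zip(ratios, advantages):
        unclipped = ratio * advantage
        clipped_ratio = min(max(ratio, 1 - clip_range), 1 + clip_range)
        terms.append(min(unclipped, clipped_ratio * advantage))  # pessimistic (lower) bound
    return sum(terms) / len(terms)


# Sample 1: good action (A > 0) whose probability already rose by 50%: the term is capped
#           at ratio 1.2 (2.4), so there is no incentive to raise it further.
# Sample 2: bad action (A < 0) whose probability already halved: the clipped term -0.8 is
#           selected, so there is no incentive to lower it further.
print(clipped_surrogate([1.5, 0.5], [2.0, -1.0]))  # about 0.8
```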
Signals
Signals represent I/O for components defined in ProtoTwin. Signals are either readable or writable. You can find the signals provided by each component inside ProtoTwin under the I/O dropdown menu. The I/O window lists the name, address and type of each signal along with its access (readable/writable). The signals used in this tutorial are:
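- The target velocity of the cart motor (set through address_cart_target_velocity in the script).
- The current position of the cart motor (read through address_cart_position).
- The current velocity of the cart motor (read through address_cart_velocity).
- The current force applied by the cart motor (read through address_cart_force).
- The current position of the pole motor (read through address_pole_angle).
- The current velocity of the pole motor (read through address_pole_angular_velocity).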
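The distinction between readable and writable signals can be sketched with a small in-memory stand-in. This is a hypothetical illustration, not the prototwin client API: the Signal class and the read_signal/write_signal functions are invented for this example, and the addresses are the ones hard-coded in the training script for this particular model.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Signal:
    """Hypothetical stand-in for one row of ProtoTwin's I/O window."""
    name: str
    address: int
    writable: bool
    value: float = 0.0


# Addresses as hard-coded in the training script for this model; other models will differ.
SIGNALS: Dict[int, Signal] = {s.address: s for s in [
    Signal("cart target velocity", 3, writable=True),
    Signal("cart position", 5, writable=False),
    Signal("cart velocity", 6, writable=False),
    Signal("cart force", 7, writable=False),
    Signal("pole position", 12, writable=False),
    Signal("pole velocity", 13, writable=False),
]}


def read_signal(address: int) -> float:
    """Return a signal's current value (this sketch assumes any signal can be read)."""
    return SIGNALS[address].value


def write_signal(address: int, value: float) -> None:
    """Set a signal's value, rejecting writes to signals that are not writable."""
    signal = SIGNALS[address]
    if not signal.writable:
        raise PermissionError(f"{signal.name} (address {address}) is read-only")
    signal.value = value


write_signal(3, 0.5)  # OK: the agent's action sets the cart's target velocity
print(read_signal(3))  # 0.5
try:
    write_signal(5, 0.1)  # the cart position is a measurement, not an input
except PermissionError as err:
    print(err)  # cart position (address 5) is read-only
```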
Packages
Make sure to install the following packages:
```shell
pip install prototwin
pip install prototwin-gymnasium
pip install stable-baselines3
pip install torch
pip install numpy
```
The prototwin package provides a client for starting and connecting to an instance of ProtoTwin Connect. Using this client, you can issue commands to load a model, step the simulation forward in time, and read and write signal values. The prototwin-gymnasium package provides base environments for using ProtoTwin models with Gymnasium in RL workflows. The Stable-Baselines3 package provides a reliable set of RL algorithm implementations in PyTorch. We also use NumPy when working with arrays, and asyncio for writing concurrent code using the async/await syntax. asyncio is part of the Python standard library, so it does not need to be installed with pip (the asyncio package on PyPI is an outdated backport).
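Before running the scripts, it can help to confirm that every package is importable from the active Python environment. A minimal sketch using the standard library's importlib.util.find_spec; note that import names use underscores where the pip names use hyphens, and that gymnasium is normally pulled in as a dependency of stable-baselines3 (version 2.0 and later).

```python
import importlib.util
from typing import Iterable, List

# Import names of the packages used in this tutorial
REQUIRED_MODULES = ("prototwin", "prototwin_gymnasium", "stable_baselines3",
                    "torch", "numpy", "gymnasium")


def missing_modules(modules: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the modules that cannot be found in the current Python environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_modules()
print("All packages found" if not missing else f"Missing: {', '.join(missing)}")
```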
Python Script
The complete Python script is provided below:
```python
# STEP 1: Import dependencies
import asyncio
import os
import torch
import numpy as np
import math
import gymnasium
import prototwin
import stable_baselines3.ppo
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor
from stable_baselines3.common.callbacks import CheckpointCallback
from prototwin_gymnasium import VecEnvInstance, VecEnv

# STEP 2: Define signal addresses (obtain these values from ProtoTwin)
address_cart_target_velocity = 3
address_cart_position = 5
address_cart_velocity = 6
address_cart_force = 7
address_pole_angle = 12
address_pole_angular_velocity = 13

# STEP 3: Create your vectorized instance environment by extending the base environment
class CartPoleEnv(VecEnvInstance):
    def __init__(self, client: prototwin.Client, instance: int) -> None:
        super().__init__(client, instance)
        self.dt = 0.01  # Time step
        self.x_threshold = 0.65  # Maximum cart distance

    def reward(self, obs):
        distance = 1 - math.fabs(obs[0])  # How close the cart is to the center
        angle = 1 - math.fabs(obs[1])  # How close the pole is to the upright position
        force = math.fabs(self.get(address_cart_force))  # How much force is being applied to drive the cart's motor
        reward = angle * 0.8 + distance * 0.2 - force * 0.004
        return max(reward * self.dt, 0)

    def observations(self):
        cart_position = self.get(address_cart_position)  # Read the current cart position
        cart_velocity = self.get(address_cart_velocity)  # Read the current cart velocity
        pole_angle = self.get(address_pole_angle)  # Read the current pole angle
        pole_angular_velocity = self.get(address_pole_angular_velocity)  # Read the current pole angular velocity
        pole_angular_distance = math.atan2(math.sin(pole_angle), math.cos(math.pi - pole_angle))  # Calculate the pole's angular distance from upright position
        return np.array([cart_position / self.x_threshold, pole_angular_distance / math.pi, cart_velocity, pole_angular_velocity])

    def reset(self, seed=None):
        super().reset(seed=seed)
        return self.observations(), {}

    def apply(self, action):
        self.set(address_cart_target_velocity, action[0])  # Apply action by setting the cart's target velocity

    def step(self):
        obs = self.observations()
        reward = self.reward(obs)  # Calculate reward
        done = abs(obs[0]) > 1  # Terminate if cart goes beyond limits
        truncated = self.time > 10  # Truncate after 10 seconds
        return obs, reward, done, truncated, {}

# STEP 4: Setup the training session
async def main():
    # Start ProtoTwin Connect
    client = await prototwin.start()

    # Load the ProtoTwin model
    filepath = os.path.join(os.path.dirname(__file__), "CartPole.ptm")
    await client.load(filepath)

    # Create the vectorized environment
    entity_name = "Main"
    num_envs = 64

    # The observation space contains:
    # 0. A measure of the cart's distance from the center, where 0 is at the center and +/-1 is at the limit
    # 1. A measure of the pole's angular distance from the upright position, where 0 is at the upright position and +/-1 is at the down position
    # 2. The cart's current velocity (m/s)
    # 3. The pole's angular velocity (rad/s)
    observation_high = np.array([1, 1, np.finfo(np.float32).max, np.finfo(np.float32).max], dtype=np.float32)
    observation_space = gymnasium.spaces.Box(-observation_high, observation_high, dtype=np.float32)

    # The action space contains only the cart's target velocity
    action_high = np.array([1.0], dtype=np.float32)
    action_space = gymnasium.spaces.Box(-action_high, action_high, dtype=np.float32)

    env = VecEnv(CartPoleEnv, client, entity_name, num_envs, observation_space, action_space)
    monitored = VecMonitor(env)  # Monitor the training progress

    # Create callback to regularly save the model
    save_freq = 10000  # Number of timesteps per instance
    checkpoint_callback = CheckpointCallback(save_freq=save_freq, save_path="./logs/checkpoints/",
                                             name_prefix="checkpoint", save_replay_buffer=True, save_vecnormalize=True)

    # Define learning rate schedule
    def lr_schedule(progress_remaining):
        initial_lr = 0.003
        return initial_lr * (progress_remaining ** 2)

    # Define the ML model (use the GPU if one is available, otherwise fall back to the CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PPO(stable_baselines3.ppo.MlpPolicy, monitored, device=device,
                verbose=1, batch_size=4096, n_steps=1000, learning_rate=lr_schedule, tensorboard_log="./tensorboard/")

    # Start training!
    model.learn(total_timesteps=10_000_000, callback=checkpoint_callback)

    # Save the final model to model.zip, which the ONNX export script loads
    model.save("model")

asyncio.run(main())
```
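The hyperparameters in main() determine how much data each PPO update sees and how often checkpoints are written. The sketch below works through that arithmetic, assuming Stable-Baselines3's usual semantics: each update collects n_steps steps from every environment instance, learning continues until total_timesteps is reached, and the checkpoint frequency counts steps per instance. The numbers are estimates derived from the settings, not output from an actual training run.

```python
import math

num_envs = 64  # parallel ProtoTwin instances
n_steps = 1000  # steps collected per instance before each PPO update
batch_size = 4096  # mini-batch size used during each update
total_timesteps = 10_000_000
save_freq = 10_000  # checkpoint interval, in steps per instance


def lr_schedule(progress_remaining: float) -> float:
    """Quadratic decay from 0.003 (start of training) to 0 (end), as in the script."""
    initial_lr = 0.003
    return initial_lr * (progress_remaining ** 2)


rollout_size = num_envs * n_steps  # transitions collected per update
num_updates = math.ceil(total_timesteps / rollout_size)  # learning stops once the total is reached
minibatches_per_epoch = math.ceil(rollout_size / batch_size)  # the last mini-batch is smaller
last_minibatch = rollout_size - (minibatches_per_epoch - 1) * batch_size
steps_per_instance = num_updates * n_steps
num_checkpoints = steps_per_instance // save_freq

print(rollout_size, num_updates, minibatches_per_epoch, last_minibatch, num_checkpoints)
# 64000 157 16 2560 15
print(lr_schedule(1.0), lr_schedule(0.5), lr_schedule(0.0))
# 0.003 0.00075 0.0
```

Because 64,000 is not a multiple of 4,096, Stable-Baselines3 may print a warning about the truncated final mini-batch of 2,560 samples; training still proceeds.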
Exporting to ONNX
It is possible to export trained models to the ONNX format, which can be used to embed trained agents into ProtoTwin models for inference. Please refer to the Stable Baselines3 exporting documentation for further details. The script below expects the trained model to have been saved as model.zip in the working directory. The complete Python script is provided below:
```python
import torch as th
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy

# Export to ONNX for embedding into ProtoTwin models using ONNX Runtime Web
def export():
    class OnnxableSB3Policy(th.nn.Module):
        def __init__(self, policy: BasePolicy):
            super().__init__()
            self.policy = policy

        def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
            # Returns (actions, values, log_prob); the first output holds the actions
            return self.policy(observation, deterministic=True)

    # Load the trained ML model
    model = PPO.load("model", device="cpu")

    # Create the Onnx policy
    onnx_policy = OnnxableSB3Policy(model.policy)
    observation_size = model.observation_space.shape
    dummy_input = th.randn(1, *observation_size)
    th.onnx.export(onnx_policy, dummy_input, "CartPole.onnx", opset_version=17, input_names=["input"], output_names=["output"])

export()
```
Inference in ProtoTwin
It is possible to embed trained agents into ProtoTwin models. To do this, you must create a scripted component that loads the ONNX model, feeds observations into the model, and applies the output actions. The observations must be computed in the same order and with the same normalization as in the training environment. This example assumes the ONNX file has been included in the model by dragging the file into the script editor’s file explorer. Alternatively, the ONNX file can be loaded from a URL. The complete source code for the inference component is provided below:
```typescript
import { type Entity, type Handle, InferenceComponent, MotorComponent, Util } from "prototwin";

export class CartPole extends InferenceComponent {
    public cartMotor: Handle<MotorComponent>;
    public poleMotor: Handle<MotorComponent>;

    constructor(entity: Entity) {
        super(entity);
        this.cartMotor = this.handle(MotorComponent);
        this.poleMotor = this.handle(MotorComponent);
    }

    public override async initializeAsync() {
        // Load the ONNX model from the local filesystem.
        this.loadModelFromFile("CartPole.onnx", 4, new Float32Array([-1]), new Float32Array([1]));
    }

    public override async updateAsync() {
        const cartMotor = this.cartMotor.value;
        const poleMotor = this.poleMotor.value;
        const observations = this.observations;
        if (cartMotor === null || poleMotor === null || observations === null) { return; }

        // Populate observation array
        const cartPosition = cartMotor.currentPosition;
        const cartVelocity = cartMotor.currentVelocity;
        const poleAngularDistance = Util.signedAngularDifference(poleMotor.currentPosition, Math.PI);
        const poleAngularVelocity = poleMotor.currentVelocity;
        observations[0] = cartPosition / 0.65;
        observations[1] = poleAngularDistance / Math.PI;
        observations[2] = cartVelocity;
        observations[3] = poleAngularVelocity;

        // Apply the actions
        const actions = await this.run();
        if (actions !== null) {
            cartMotor.targetVelocity = actions[0];
        }
    }
}
```
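The inference component must reproduce the training observations exactly, including the sign of the pole angle. A quick plain-Python check shows that the expression used in the training script, atan2(sin θ, cos(π − θ)), equals (π − θ) wrapped into [−π, π]. Whether Util.signedAngularDifference(poleMotor.currentPosition, Math.PI) follows the same sign convention is an assumption worth verifying against the ProtoTwin API documentation; if it returned θ − π instead, observation 1 would be sign-flipped relative to training. The helper names below are hypothetical.

```python
import math


def training_angle(theta: float) -> float:
    """Pole angular distance exactly as computed in the training script's observations()."""
    return math.atan2(math.sin(theta), math.cos(math.pi - theta))


def wrapped_difference(theta: float) -> float:
    """(pi - theta) wrapped into [-pi, pi] using the IEEE remainder."""
    return math.remainder(math.pi - theta, 2 * math.pi)


for theta in (0.3, 1.0, 2.5, 4.0, 5.9, -1.2):
    assert math.isclose(training_angle(theta), wrapped_difference(theta), abs_tol=1e-12)
print("training angle == wrap(pi - theta) for all sampled angles")
```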