"""
DonkeyCar RL Runner — Real Training Edition
============================================

Trains a PPO or DQN model using Stable-Baselines3, evaluates with evaluate_policy(),
saves the model to disk, and exits cleanly.

Usage:
    python3 donkeycar_sb3_runner.py \
        --agent ppo \
        --env donkey-generated-roads-v0 \
        --timesteps 10000 \
        --eval-episodes 5 \
        --learning-rate 0.0003 \
        --save-dir agent/models/trial-0001 \
        --n-steer 7 \
        --n-throttle 3 \
        --reward-shaping \
        --seed 42

Exit codes:
    0   — success, model saved, evaluation complete
    100 — failed to connect to simulator
    101 — training failed
    102 — evaluation failed
"""
|
|
|
|
# Standard library
import argparse
import os
import sys
import time

import numpy as np

import gymnasium as gym
# Never referenced directly — presumably imported for its side effect of
# registering the donkey-* environment IDs with gym; confirm before removing.
import gym_donkeycar

from stable_baselines3 import PPO, DQN
from stable_baselines3.common.evaluation import evaluate_policy

# Project-local: discrete action wrapper used for the DQN agent.
from discretize_action import DiscretizedActionWrapper

# Optional reward shaping — imported only if available
try:
    from reward_wrapper import SpeedRewardWrapper
    REWARD_WRAPPER_AVAILABLE = True
except ImportError:
    REWARD_WRAPPER_AVAILABLE = False
|
|
|
|
|
|
class ThrottleClampWrapper(gym.ActionWrapper):
    """
    Clamps the throttle dimension of a continuous action to [throttle_min, 1.0].

    Prevents PPO's random initial policy from outputting zero throttle
    and leaving the car stationary.

    Action format expected: [steer, throttle] where steer ∈ [-1,1], throttle ∈ [0,1].
    """

    def __init__(self, env, throttle_min=0.2):
        """
        :param env: wrapped gym environment with a [steer, throttle] Box action space.
        :param throttle_min: lower bound enforced on the throttle dimension.
        """
        super().__init__(env)
        self.throttle_min = throttle_min
        # Advertise the clamped bounds so SB3 samples/scales actions against
        # the real range. (numpy is already imported at module level; the
        # previous function-local `import numpy as np` was redundant.)
        low = np.array([-1.0, throttle_min], dtype=np.float32)
        high = np.array([1.0, 1.0], dtype=np.float32)
        self.action_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def action(self, action):
        """Return a float32 copy of `action` with throttle clipped to [throttle_min, 1.0]."""
        action = np.array(action, dtype=np.float32)
        action[1] = float(np.clip(action[1], self.throttle_min, 1.0))
        return action
|
|
|
|
|
|
def log(msg):
    """Emit one runner log line immediately, bypassing stdout buffering."""
    print(msg, flush=True)
|
|
|
|
|
|
def make_env(env_id, agent, n_steer, n_throttle, reward_shaping):
    """Build the gym environment and apply the action/reward wrappers for `agent`."""
    wrapped = gym.make(env_id)

    if agent == 'dqn':
        # DQN requires a discrete action space.
        wrapped = DiscretizedActionWrapper(wrapped, n_steer=n_steer, n_throttle=n_throttle)
        log(f'[SB3 Runner][MONITOR] Action discretization: steer={n_steer}, throttle={n_throttle}. {time.ctime()}')
    else:
        # PPO uses continuous actions. Clip throttle to [0.2, 1.0] so the car always moves.
        # Without this, PPO's random initial policy outputs throttle~0 and the car sits still.
        log(f'[SB3 Runner][MONITOR] PPO continuous actions. Throttle clamped to [0.2, 1.0]. {time.ctime()}')
        wrapped = ThrottleClampWrapper(wrapped, throttle_min=0.2)

    if reward_shaping and REWARD_WRAPPER_AVAILABLE:
        wrapped = SpeedRewardWrapper(wrapped)
        log(f'[SB3 Runner][MONITOR] Speed reward shaping ENABLED. {time.ctime()}')
    elif reward_shaping:
        log(f'[SB3 Runner][MONITOR] WARNING: reward_wrapper.py not found — reward shaping disabled. {time.ctime()}')

    return wrapped
|
|
|
|
|
|
class SimHealthCallback:
    """
    Stable-Baselines3 compatible health monitor for the simulator connection.

    Detects two failure modes:
      * STUCK  — the reported speed stays below `min_speed` for
                 `max_stuck_steps` consecutive steps.
      * FROZEN — the observation is bit-identical for 30 consecutive steps
                 (frozen frame = connection lost).
    """

    def __init__(self, max_stuck_steps=100, min_speed=0.05):
        """
        :param max_stuck_steps: consecutive low-speed steps tolerated before failing.
        :param min_speed: speeds below this threshold count as "not moving".
        """
        self.max_stuck_steps = max_stuck_steps
        self.min_speed = min_speed
        self._stuck_count = 0    # consecutive steps with speed < min_speed
        self._last_obs = None    # snapshot of the previous observation
        self._frozen_count = 0   # consecutive steps with an unchanged observation

    def on_step(self, obs, reward, done, info):
        """Call after each env.step(). Returns False if sim appears dead."""
        # --- Stuck detection: speed is read from the info dict, if present ---
        speed = info.get('speed', None) if isinstance(info, dict) else None
        if speed is not None:
            if float(speed) < self.min_speed:
                self._stuck_count += 1
            else:
                self._stuck_count = 0
            if self._stuck_count >= self.max_stuck_steps:
                log(f'[SB3 Runner][MONITOR ALERT] Sim appears STUCK: speed<{self.min_speed} for {self._stuck_count} steps. {time.ctime()}')
                return False

        # --- Frozen detection: observation unchanged (connection lost) ---
        if obs is not None and self._last_obs is not None:
            if np.array_equal(obs, self._last_obs):
                self._frozen_count += 1
            else:
                self._frozen_count = 0
            if self._frozen_count >= 30:
                log(f'[SB3 Runner][MONITOR ALERT] Sim appears FROZEN: observation unchanged for {self._frozen_count} steps. {time.ctime()}')
                return False
        # BUGFIX: store a snapshot, not a reference. Vectorized/SB3 envs may
        # update the observation buffer in place, in which case comparing the
        # live `obs` against a stored reference compares the array with itself
        # and always matches — producing spurious FROZEN alerts.
        self._last_obs = None if obs is None else np.array(obs, copy=True)
        return True
|
|
|
|
|
|
def train_model(agent, env, learning_rate, timesteps, seed):
    """
    Train a PPO or DQN model on `env` and return it.

    :param agent: 'ppo' or 'dqn'.
    :param env: wrapped gym environment.
    :param learning_rate: optimizer learning rate passed to SB3.
    :param timesteps: total environment steps to train for.
    :param seed: random seed forwarded to SB3 (may be None).
    :raises ValueError: if `agent` is not 'ppo' or 'dqn'.
    """
    from stable_baselines3.common.callbacks import BaseCallback

    # Single source of truth for supported algorithms; validate before doing
    # any other work so a bad agent name fails fast.
    algo_classes = {'ppo': PPO, 'dqn': DQN}
    if agent not in algo_classes:
        raise ValueError(f'Unknown agent: {agent}. Use ppo or dqn.')

    class HealthCheckCallback(BaseCallback):
        """SB3 callback that checks sim health each step and stops training if stuck."""

        def __init__(self, max_stuck_steps=100, min_speed=0.05):
            super().__init__(verbose=0)
            self.health = SimHealthCallback(max_stuck_steps=max_stuck_steps, min_speed=min_speed)

        def _on_step(self):
            # SB3 exposes the latest transition via self.locals. Envs are
            # vectorized, so take index 0 (this runner drives a single env).
            infos = self.locals.get('infos', [{}])
            obs = self.locals.get('new_obs', None)
            info = infos[0] if infos else {}
            obs_arr = obs[0] if obs is not None and len(obs) > 0 else None
            healthy = self.health.on_step(obs_arr, None, None, info)
            if not healthy:
                log(f'[SB3 Runner][MONITOR ALERT] Health check failed — stopping training early. {time.ctime()}')
                return False  # Stops SB3 training
            return True

    # The PPO/DQN constructors take identical arguments here, so build the
    # model through the dispatch table instead of duplicated branches.
    model = algo_classes[agent](
        'CnnPolicy',
        env,
        learning_rate=learning_rate,
        verbose=1,
        seed=seed,
    )

    log(f'[SB3 Runner][MONITOR] Starting training: agent={agent} timesteps={timesteps} lr={learning_rate} {time.ctime()}')
    start = time.time()
    health_cb = HealthCheckCallback(max_stuck_steps=100, min_speed=0.02)
    model.learn(total_timesteps=timesteps, callback=health_cb)
    elapsed = time.time() - start
    log(f'[SB3 Runner][MONITOR] Training complete in {elapsed:.1f}s. {time.ctime()}')
    return model
|
|
|
|
|
|
def evaluate_model(model, env, eval_episodes):
    """Evaluate the model using SB3 evaluate_policy and print per-episode detail."""
    log(f'[SB3 Runner][MONITOR] Evaluating model for {eval_episodes} episodes. {time.ctime()}')
    # Deterministic rollouts: report the learned policy, not sampling noise.
    scores = evaluate_policy(
        model,
        env,
        n_eval_episodes=eval_episodes,
        return_episode_rewards=False,
        deterministic=True,
    )
    mean_reward, std_reward = scores
    log(f'[SB3 Runner][TEST] mean_reward={mean_reward:.4f}')
    log(f'[SB3 Runner][TEST] std_reward={std_reward:.4f}')
    return mean_reward, std_reward
|
|
|
|
|
|
def save_model(model, save_dir):
    """Save the model to save_dir/model.zip and return that path."""
    os.makedirs(save_dir, exist_ok=True)
    # SB3's model.save() appends the .zip extension itself.
    base_path = os.path.join(save_dir, 'model')
    model.save(base_path)
    log(f'[SB3 Runner][MONITOR] Model saved to {base_path}.zip {time.ctime()}')
    return f'{base_path}.zip'
|
|
|
|
|
|
def teardown(env):
    """Close environment cleanly with race avoidance sleep.

    Best-effort: an exception from env.close() is logged but never re-raised,
    so teardown cannot mask the caller's intended exit code.

    :param env: the (possibly wrapped) gym environment to close.
    """
    log(f'[SB3 Runner][MONITOR] Calling env.close() at {time.ctime()}')
    try:
        env.close()
        log(f'[SB3 Runner][MONITOR] env.close() complete. {time.ctime()}')
    except Exception as e:
        log(f'[SB3 Runner][MONITOR ALERT] Exception during env.close(): {e} {time.ctime()}')
    # Pause before process exit — presumably gives the simulator connection
    # time to finish its shutdown handshake; confirm against the sim client
    # before shortening or removing.
    log(f'[SB3 Runner][MONITOR] Waiting 2s before process exit to avoid race. {time.ctime()}')
    time.sleep(2)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train and evaluate an RL agent on DonkeyCar.')
    parser.add_argument('--agent', type=str, default='ppo', choices=['ppo', 'dqn'], help='RL agent type')
    parser.add_argument('--env', type=str, default='donkey-generated-roads-v0', help='Gym env ID')
    parser.add_argument('--timesteps', type=int, default=10000, help='Training timesteps')
    parser.add_argument('--eval-episodes', type=int, default=5, help='Evaluation episodes')
    parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
    parser.add_argument('--save-dir', type=str, default=None, help='Directory to save model')
    parser.add_argument('--n-steer', type=int, default=7, help='Steer bins (DQN only)')
    parser.add_argument('--n-throttle', type=int, default=3, help='Throttle bins (DQN only)')
    parser.add_argument('--reward-shaping', action='store_true', help='Enable speed reward shaping')
    parser.add_argument('--seed', type=int, default=None, help='Random seed')
    args = parser.parse_args()

    log(f'[SB3 Runner] Starting: agent={args.agent} timesteps={args.timesteps} lr={args.learning_rate} {time.ctime()}')

    # --- 1. Connect to simulator ---
    # Exit code 100 signals a connection failure to the harness (see module docstring).
    env = None
    try:
        env = make_env(args.env, args.agent, args.n_steer, args.n_throttle, args.reward_shaping)
        log(f'[SB3 Runner][MONITOR] Connected to gym env. {time.ctime()}')
    except Exception as e:
        log(f'[SB3 Runner][MONITOR ALERT] Failed to connect to sim: {e}')
        sys.exit(100)

    # --- 2. Train model ---
    # Any failure here tears the env down first so the sim is not left hanging.
    model = None
    try:
        model = train_model(args.agent, env, args.learning_rate, args.timesteps, args.seed)
    except Exception as e:
        log(f'[SB3 Runner][MONITOR ALERT] Training failed: {e} {time.ctime()}')
        teardown(env)
        sys.exit(101)

    # --- 3. Save model ---
    # Save BEFORE evaluating so a crash during evaluation can't lose the model.
    # NOTE(review): a save failure reuses exit code 101 ("training failed") —
    # the docstring defines no separate code for it; confirm that is intended.
    save_dir = args.save_dir or f'/tmp/donkeycar-trial-{int(time.time())}'
    try:
        saved_path = save_model(model, save_dir)
    except Exception as e:
        log(f'[SB3 Runner][MONITOR ALERT] Model save failed: {e} {time.ctime()}')
        teardown(env)
        sys.exit(101)

    # --- 4. Evaluate trained policy ---
    # evaluate_model logs the mean/std lines the harness parses; the returned
    # values are not otherwise used here.
    try:
        mean_reward, std_reward = evaluate_model(model, env, args.eval_episodes)
    except Exception as e:
        log(f'[SB3 Runner][MONITOR ALERT] Evaluation failed: {e} {time.ctime()}')
        teardown(env)
        sys.exit(102)

    # --- 5. Teardown ---
    # Falling off the end yields exit code 0 (success).
    teardown(env)
    log(f'[SB3 Runner][MONITOR] Exiting RL runner at {time.ctime()}')