donkeycar-rl-autoresearch/agent/experiments/exp13_gentrack_v4.py

205 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Exp 13: Single track — generated_track, v4 reward, back to basics.
This is a DELIBERATE return to the setup that worked in Wave 4 Trial 9.
What Wave 4 used (from git history at commit 7534527):
- v4 reward: base × efficiency × speed_bonus
- Circles give ~0 reward naturally (efficiency → 0)
- No extra termination heuristics needed
- wrap_env: ThrottleClampWrapper + SpeedRewardWrapper ONLY
- No StuckTerminationWrapper in the gym wrapper chain
- Stuck detection was a PPO callback (HealthCheckCallback)
- throttle_min=0.2, lr=0.000725
- Single track
We have been overcomplicating this with efficiency gates, progress
terminators, CTE patience, wall-clock timeouts etc. Wave 4 Trial 9
drove generated_track 2000/2000 without any of that. Going back.
Stopping criterion: eval every 5k steps, stop when 3 laps achieved.
"""
import sys, os, time
sys.path.insert(0, '/home/paulh/projects/donkeycar-rl-autoresearch/agent')
from donkeycar_sb3_runner import ThrottleClampWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
import numpy as np
# ---- Experiment configuration ----
# Simulator server address; both values are forwarded to gym.make() via `conf`.
HOST = '10.0.0.55'
PORT = 9091
# Gym environment id that is trained on.
TRACK_ID = 'donkey-generated-track-v0'
# Short track label; only used in the start-up log banner.
TRACK_NAME = 'generated_track'
# Passed to ThrottleClampWrapper as `throttle_min`.
THROTTLE_MIN = 0.2
# Passed to V4RewardWrapper as `speed_scale` (weight of the speed term).
SPEED_SCALE = 0.1
# Learning rate handed to PPO.
LR = 0.000725
# Overall training budget that bounds the train/eval loop.
MAX_STEPS = 300000
# Number of timesteps requested from model.learn() between two evaluations.
EVAL_EVERY = 5000
LAP_STOP = 3 # stop when eval achieves this many laps
# Directory that receives the checkpoints, the latest model and the best model.
SAVE_DIR = '/home/paulh/projects/donkeycar-rl-autoresearch/agent/models/exp13-gentrack-v4'
# Created at import time so the later model.save() calls have somewhere to write.
os.makedirs(SAVE_DIR, exist_ok=True)
# ---- v4 reward (inline — same formula as Wave 4) ----
import gymnasium as gym_mod
from collections import deque
class V4RewardWrapper(gym_mod.Wrapper):
    """Replace the simulator reward with ``base * efficiency * speed_bonus``.

    * ``base`` falls linearly from 1 to 0 as ``|cte|`` grows to ``max_cte``.
    * ``efficiency`` is net displacement divided by path length over the
      most recent ``window_size`` moves, rescaled so that anything at or
      below ``min_efficiency`` contributes 0. A car driving in circles has
      little net displacement and therefore earns roughly nothing.
    * ``speed_bonus`` is ``1 + speed_scale * speed`` (negative speed counts
      as 0).

    A step that ends the episode always yields -1.0. The wrapper never ends
    an episode itself.
    """

    def __init__(self, env, speed_scale=0.1, window_size=60,
                 min_efficiency=0.05, max_cte=8.0):
        super().__init__(env)
        self.speed_scale = speed_scale
        self.min_efficiency = min_efficiency
        self.max_cte = max_cte
        # window_size moves are spanned by window_size + 1 positions.
        self._pos_history = deque(maxlen=window_size + 1)

    def reset(self, **kwargs):
        """Drop the position window and reset the wrapped env."""
        self._pos_history.clear()
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env and substitute the v4 reward.

        Accepts both the 5-tuple (terminated/truncated) and the legacy
        4-tuple (done) step formats and answers in the same format.
        """
        outcome = self.env.step(action)
        if len(outcome) == 5:
            obs, _sim_r, terminated, truncated, info = outcome
            shaped = self._compute_reward(terminated or truncated, info)
            return obs, shaped, terminated, truncated, info
        obs, _sim_r, done, info = outcome
        shaped = self._compute_reward(done, info)
        return obs, shaped, done, info

    @staticmethod
    def _read_float(info, key):
        """Return ``info[key]`` as a float; 0.0 if missing, falsy or unparsable."""
        try:
            return float(info.get(key, 0.0) or 0.0)
        except (TypeError, ValueError):
            return 0.0

    def _compute_reward(self, done, info):
        """Return the shaped reward for one step (-1.0 when the episode ended)."""
        if done:
            return -1.0

        location = info.get('pos', None)
        if location is not None:
            try:
                # Keep at most the first three coordinates of the position.
                self._pos_history.append(
                    np.array(list(location)[:3], dtype=np.float64))
            except (TypeError, ValueError):
                # Unusable position: leave the window as it was.
                pass

        cte = self._read_float(info, 'cte')
        base = 1.0 - min(abs(cte) / self.max_cte, 1.0)

        raw_efficiency = self._compute_efficiency()
        usable_range = 1.0 - self.min_efficiency
        eff = max(0.0, (raw_efficiency - self.min_efficiency) / usable_range)

        speed = max(0.0, self._read_float(info, 'speed'))
        return base * eff * (1.0 + self.speed_scale * speed)

    def _compute_efficiency(self):
        """Return net displacement / travelled distance over the window.

        Yields 1.0 while fewer than three positions are recorded or while
        the travelled distance is effectively zero.
        """
        if len(self._pos_history) < 3:
            return 1.0
        trail = list(self._pos_history)
        net = np.linalg.norm(trail[-1] - trail[0])
        total = sum(np.linalg.norm(nxt - cur)
                    for cur, nxt in zip(trail, trail[1:]))
        if total > 1e-6:
            return float(net / total)
        return 1.0
def log(msg):
    """Print *msg* to stdout, prefixed with the current time as [HH:MM:SS]."""
    from datetime import datetime
    stamp = datetime.now().strftime("%H:%M:%S")
    print(f'[{stamp}] {msg}', flush=True)
def make_env():
    """Return a zero-argument factory that builds the wrapped training env.

    The factory connects to the simulator at HOST:PORT for TRACK_ID, then
    applies ThrottleClampWrapper followed by V4RewardWrapper. No other
    wrappers are added.
    """
    def _build():
        simulator = gym.make(TRACK_ID, conf={'host': HOST, 'port': PORT})
        clamped = ThrottleClampWrapper(simulator, throttle_min=THROTTLE_MIN)
        return V4RewardWrapper(clamped, speed_scale=SPEED_SCALE)
    return _build
# ---- Training driver: train in segments, evaluate after every segment ----
EVAL_MAX_STEPS = 2000  # hard cap on the length of one evaluation episode


def _run_eval_episode(model, env, max_steps=EVAL_MAX_STEPS):
    """Roll out one deterministic episode on the vectorised env.

    Returns ``(episode_reward, episode_steps, laps)``. ``laps`` is the
    highest ``lap_count`` reported in the step info during the episode.
    The rollout ends at the first done flag or after ``max_steps`` steps.
    """
    obs = env.reset()
    ep_r = 0.0
    ep_steps = 0
    laps = 0
    for _ in range(max_steps):
        action, _state = model.predict(obs, deterministic=True)
        obs, r, d, info = env.step(action)
        ep_r += float(r[0])
        ep_steps += 1
        try:
            step_info = info[0] if isinstance(info, (list, tuple)) else info
            lc = int(step_info.get('lap_count', 0) or 0)
        except (AttributeError, IndexError, TypeError, ValueError):
            # Missing or malformed lap info must not abort the evaluation.
            lc = 0
        laps = max(laps, lc)
        if bool(d[0]):
            break
    return ep_r, ep_steps, laps


log('='*60)
log(f'Exp 13: {TRACK_NAME}, v4 reward, back to basics')
log(f' Host: {HOST}:{PORT}')
log(f' throttle_min={THROTTLE_MIN}, lr={LR}')
log(f' Reward: v4 (base × efficiency × speed) — same as Wave 4')
log(f' Wrappers: ThrottleClamp + V4Reward ONLY (no extra terminators)')
log(f' Stop: eval every {EVAL_EVERY:,} steps, stop at {LAP_STOP} laps')
log('='*60)
env = VecTransposeImage(DummyVecEnv([make_env()]))
best_reward = float('-inf')
best_laps = 0
steps_done = 0
try:
    model = PPO('CnnPolicy', env, learning_rate=LR, verbose=1, device='cpu')
    log('PPO created. Training...')
    while steps_done < MAX_STEPS:
        seg = min(EVAL_EVERY, MAX_STEPS - steps_done)
        model.learn(total_timesteps=seg, reset_num_timesteps=False)
        # PPO gathers whole rollouts, so learn() may run past the requested
        # segment. Read the real counter back instead of assuming `seg`,
        # otherwise logs, checkpoint names and the MAX_STEPS budget drift.
        steps_done = int(model.num_timesteps)
        ckpt = os.path.join(SAVE_DIR, f'checkpoint_{steps_done:07d}')
        model.save(ckpt)
        model.save(os.path.join(SAVE_DIR, 'model'))
        # Eval: one deterministic episode, count laps.
        # Best-effort: a failed evaluation is logged and training goes on.
        try:
            ep_r, ep_steps, laps = _run_eval_episode(model, env)
            status = '' if ep_steps >= EVAL_MAX_STEPS else f'❌@{ep_steps}'
            log(f'[{steps_done:,}] reward={ep_r:.1f} steps={ep_steps} '
                f'laps={laps} {status}')
            if ep_r > best_reward:
                best_reward = ep_r
                model.save(os.path.join(SAVE_DIR, 'best_model'))
                log(f' ⭐ NEW BEST: {best_reward:.1f}')
            if laps > best_laps:
                best_laps = laps
                log(f' 🏆 BEST LAPS: {best_laps}')
            if laps >= LAP_STOP:
                log(f' 🎯 {laps} laps achieved at {steps_done:,} steps — STOPPING')
                break
        except Exception as e:
            log(f' Eval error: {e}')
        # The evaluation stepped the training env behind the model's back.
        # Re-attaching the env forces a reset at the start of the next
        # learn() call, so the next rollout does not begin from the stale
        # observation cached before the evaluation.
        model.set_env(env)
finally:
    # Always release the simulator connection, even when training fails.
    env.close()
    time.sleep(3)
log(f'\nDone. best_laps={best_laps} best_reward={best_reward:.1f}')
log(f'Best model: {SAVE_DIR}/best_model.zip')
log('=== Exp 13 COMPLETE ===')