# donkeycar-rl-autoresearch/agent/multitrack_runner.py
#
# (repository viewer header: 769 lines, 33 KiB, Python)

"""
Wave 3 Multi-Track Runner
=========================
Trains PPO across multiple DonkeyCar tracks by round-robin switching between
training segments. After training, evaluates on the zero-shot test track(s)
listed in TEST_TRACKS (currently mini_monaco) to measure cross-track generalization.
Track classification (from visual analysis):
    TRAINING  : generated_track, mountain_track (Wave 4 — no generated_road, no warm-start)
                (outdoor, same road markings — yellow centre + white edge)
    TEST/EVAL : mini_monaco
                (never seen during training — generalization benchmark;
                warren is no longer evaluated, see TEST_TRACKS)
    SKIPPED   : warehouse, robo_racing_league, waveshare, circuit_launch
                (fully indoor — different domain entirely)
                avc_sparkfun (outdoor but orange markings — too different)
Track switching strategy:
    viewer.exit_scene() on the open connection → env.close() → wait 2s + 4s
    → gym.make(next_track)
    exit_scene is sent through the env's existing viewer connection rather than
    a new raw socket, because the sim creates one vehicle per TCP connection.
Key invariants (ADR-005, ADR-006):
- model is always defined before model.save()
- env.close() + time.sleep(2) before every track switch
- Results appended to JSONL, never overwritten
Output lines parsed by wave3_controller.py:
[W3 Runner][TRAIN] track=<name> segment_reward=<float>
[W3 Runner][TEST] track=<name> mean_reward=<float> mean_steps=<float>
[W3 Runner][TEST] combined_test_score=<float>
Usage:
python3 multitrack_runner.py \\
--total-timesteps 200000 \\
--steps-per-switch 10000 \\
--learning-rate 0.000225 \\
--warm-start models/champion/model.zip \\
--save-dir models/wave3/trial-0001 \\
--eval-episodes 3
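Exit codes:
    0   — run completed (model-save and evaluation failures are logged but do
          not change the exit code)
    100 — failed to connect to simulator on initial track
    101 — model creation or training error
    102 — evaluation error (documented, but not currently emitted by main())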
"""
import argparse
import os
import sys
import time
import json
import numpy as np
from collections import deque
from datetime import datetime
import gymnasium as gym
import gym_donkeycar
from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_schedule_fn
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
# ---- Project paths ----
AGENT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, AGENT_DIR)
from donkeycar_sb3_runner import ThrottleClampWrapper, SimHealthCallback
from reward_wrapper import SpeedRewardWrapper
# ---- Track catalogue ----
# Every catalogue entry is a (short name, gym env ID) pair.

# Wave 4 training tracks.
# generated_road removed: it is visually too similar to generated_track
# and its Phase-2-champion warm-start caused catastrophic forgetting that
# prevented generalisation. generated_track + mountain_track have
# meaningfully different backgrounds, forcing the model to learn
# track-marking features rather than scene-specific shortcuts.
TRAINING_TRACKS = [
    ('generated_track', 'donkey-generated-track-v0'),
    ('mountain_track', 'donkey-mountain-track-v0'),
]

# Zero-shot generalization test tracks — never seen during training.
# Warren was removed: its episode-done condition does not fire when the car
# crosses the INSIDE edge (CTE stays small), so the car can drive among
# chairs indefinitely and scores are meaningless.
TEST_TRACKS = [
    ('mini_monaco', 'donkey-minimonaco-track-v0'),
]

# ---- Tuning constants ----
# How many steps to sample before deciding the segment reward (shorter than segment).
SEGMENT_EVAL_STEPS = 500
# Seconds to wait after exit_scene for the sim to reach its menu.
EXIT_SCENE_WAIT = 4.0
# Minimum throttle (prevents a stationary car).
THROTTLE_MIN = 0.2
# SpeedRewardWrapper coefficient.
SPEED_SCALE = 0.1
# ---- Logging ----
def log(msg):
    """Print ``msg`` to stdout, prefixed with the local time as ``[HH:MM:SS]``.

    Output is flushed immediately so lines appear promptly when stdout is piped.
    """
    stamp = datetime.now().strftime('%H:%M:%S')
    line = f'[{stamp}] {msg}'
    print(line, flush=True)
# ---- Health check callback ----
class HealthCheckCallback(BaseCallback):
    """Training callback that ends the current segment when the sim looks stuck or frozen.

    Every step is forwarded to a ``SimHealthCallback``; as soon as that helper
    reports the sim as unhealthy, ``_on_step`` returns False.
    """

    def __init__(self, max_stuck_steps=150, min_speed=0.02):
        super().__init__(verbose=0)
        self.health = SimHealthCallback(
            max_stuck_steps=max_stuck_steps,
            min_speed=min_speed,
        )

    def _on_step(self):
        """Check the first env's latest observation/info; return False to stop."""
        step_infos = self.locals.get('infos', [{}])
        batch_obs = self.locals.get('new_obs', None)

        # Only the first entry of the batch is inspected.
        first_info = step_infos[0] if step_infos else {}
        if batch_obs is not None and len(batch_obs) > 0:
            first_obs = batch_obs[0]
        else:
            first_obs = None

        if self.health.on_step(first_obs, None, None, first_info):
            return True

        log('[W3 Runner][HEALTH] Sim stuck/frozen — stopping segment early.')
        return False
class StuckTerminationWrapper(gym.Wrapper):
    """
    Terminates the episode when the car hasn't made meaningful positional
    progress over `stuck_steps` consecutive steps OR `max_stuck_seconds`
    wall-clock seconds, whichever comes first.

    The wall-clock timeout is critical for DummyVecEnv: when both cars are
    simultaneously stuck against a wall, Unity's physics engine slows to
    1-2 FPS (heavy collision computation). At that rate, stuck_steps=40
    can take 1+ minutes of wall-clock time. The wall-clock timeout catches
    this case regardless of sim speed.

    Handles three cases the sim misses:
      1. Car pressed slowly against a barrier — Unity's OnCollisionEnter fires
         once then resets; Python never sees sustained contact. Speed-based check
         terminates after max_low_speed_seconds at speed < low_speed_threshold.
      2. Car sliding laterally along a barrier — position displacement > 0.5m
         keeps resetting the wall-clock timer; speed stays ≈0. Speed-based check
         catches this; position-based check cannot.
      3. Car circling off the start/finish line — efficiency→0 gives zero reward
         but the episode never ends, wasting training steps with no signal.

    Checks run in this order on every step, and the first one that fires wins:
    wall-clock no-progress, explicit hit, low speed, high CTE, hard episode
    time limit, step-count displacement.

    When a check fires, ``terminated`` is set to True and ``info`` gains
    ``stuck_termination=True`` plus a ``stuck_reason`` string.
    (Per the original author, SpeedRewardWrapper then returns -1.0; that
    wrapper is defined elsewhere — verify there.)

    Args:
        env: environment to wrap.
        stuck_steps: window length (in steps) for the step-count check.
        min_displacement: minimum movement counted as progress (same units
            as ``info['pos']``).
        max_stuck_seconds: wall-clock seconds allowed without progress.
        max_episode_seconds: hard wall-clock limit for a whole episode.
        low_speed_threshold: speeds below this count as "low speed".
        max_low_speed_seconds: wall-clock seconds allowed at low speed.
        max_cte: absolute cross-track error above which the car is off-road.
        max_high_cte_seconds: wall-clock seconds allowed above ``max_cte``.
    """

    def __init__(self, env, stuck_steps: int = 80, min_displacement: float = 0.5,
                 max_stuck_seconds: float = 12.0, max_episode_seconds: float = 30.0,
                 low_speed_threshold: float = 0.5, max_low_speed_seconds: float = 3.0,
                 max_cte: float = 5.0, max_high_cte_seconds: float = 1.0):
        super().__init__(env)
        self.stuck_steps = stuck_steps
        self.min_displacement = min_displacement
        self.max_stuck_seconds = max_stuck_seconds
        self.max_episode_seconds = max_episode_seconds
        self.low_speed_threshold = low_speed_threshold
        self.max_low_speed_seconds = max_low_speed_seconds
        self.max_cte = max_cte
        self.max_high_cte_seconds = max_high_cte_seconds
        # Sliding window of the most recent positions (step-count check).
        self._pos_buf: deque = deque(maxlen=stuck_steps)
        # Position/time of the last "meaningful progress" (wall-clock check).
        self._last_progress_pos = None
        self._last_progress_t = None
        # Timer anchors; None means "timer not running".
        self._episode_start_t = None
        self._low_speed_start_t = None
        self._high_cte_start_t = None

    def reset(self, **kwargs):
        """Clear all per-episode state, start the episode clock, reset the inner env."""
        self._pos_buf.clear()
        self._last_progress_pos = None
        self._last_progress_t = None
        self._episode_start_t = time.time()
        self._low_speed_start_t = None
        self._high_cte_start_t = None
        return self.env.reset(**kwargs)

    @staticmethod
    def _mark_stuck(info, reason):
        """Record in ``info`` why this wrapper terminated the episode."""
        info['stuck_termination'] = True
        info['stuck_reason'] = reason

    @staticmethod
    def _info_float(info, key, default):
        """Read ``info[key]`` as a float; use ``default`` when missing, None or unparsable.

        A legitimate 0.0 is returned as 0.0. (The previous ``value or default``
        idiom replaced a zero speed with the 999.0 sentinel, so a completely
        stationary car never started the low-speed timer.)
        """
        value = info.get(key)
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def step(self, action):
        """Step the inner env, then apply the stuck/off-road termination checks.

        Accepts both 5-tuple (terminated/truncated) and legacy 4-tuple (done)
        results from the inner env and returns the same arity it received.
        """
        result = self.env.step(action)
        five_tuple = len(result) == 5
        if five_tuple:
            obs, reward, terminated, truncated, info = result
        else:
            obs, reward, done, info = result
            terminated, truncated = done, False

        pos = info.get('pos', None)
        now = time.time()

        if pos is not None:
            try:
                pos_arr = np.array(list(pos)[:3], dtype=np.float64)
                self._pos_buf.append(pos_arr)
                # Wall-clock stuck detection: reset timer whenever car moves > min_displacement
                if self._last_progress_pos is None:
                    self._last_progress_pos = pos_arr
                    self._last_progress_t = now
                else:
                    moved = float(np.linalg.norm(pos_arr - self._last_progress_pos))
                    if moved >= self.min_displacement:
                        # Made meaningful progress — reset wall-clock timer
                        self._last_progress_pos = pos_arr
                        self._last_progress_t = now
                    elif (now - self._last_progress_t) > self.max_stuck_seconds:
                        # Wall-clock timeout — terminate regardless of step count
                        if not terminated:
                            terminated = True
                            self._mark_stuck(info, 'wall_clock_timeout')
            except (TypeError, ValueError):
                # Malformed position payload: skip the position bookkeeping
                # for this step; the other checks below still run.
                pass

        # Explicit hit check: gym_donkeycar sets done=True for hit!="none" but timing
        # gaps (script execution order, 1-frame delay) can let it slip through.
        # This is a zero-latency Python-side backstop — fires on the same step as the hit.
        if not terminated:
            hit = info.get('hit', 'none')
            if hit and hit != 'none':
                terminated = True
                self._mark_stuck(info, f'hit_{hit}')

        # Speed-based stuck detection: catches car pinned against a barrier.
        # A car pressed against a wall has speed≈0 even while sliding laterally
        # (accumulating displacement that resets the position-based timer above).
        if not terminated:
            # 999.0 = "speed unknown": never counts as low speed.
            speed = self._info_float(info, 'speed', 999.0)
            if speed < self.low_speed_threshold:
                if self._low_speed_start_t is None:
                    self._low_speed_start_t = now
                elif (now - self._low_speed_start_t) > self.max_low_speed_seconds:
                    terminated = True
                    self._mark_stuck(info, 'low_speed_timeout')
            else:
                self._low_speed_start_t = None

        # CTE-based termination: catches car stuck at/past the road edge barrier.
        # A car pressed sideways against a barrier has high CTE even while the
        # speed and position checks are fooled by wheel oscillation.
        if not terminated:
            cte = self._info_float(info, 'cte', 0.0)
            if abs(cte) > self.max_cte:
                if self._high_cte_start_t is None:
                    self._high_cte_start_t = now
                elif (now - self._high_cte_start_t) > self.max_high_cte_seconds:
                    terminated = True
                    self._mark_stuck(info, 'high_cte_timeout')
            else:
                self._high_cte_start_t = None

        # Hard episode wall-clock limit — fires regardless of car position or sim fps.
        # Catches cars sliding slowly along barriers that keep resetting the
        # max_stuck_seconds timer by drifting 0.5m at a time.
        if not terminated and self._episode_start_t is not None:
            if (now - self._episode_start_t) > self.max_episode_seconds:
                terminated = True
                self._mark_stuck(info, 'episode_timeout')

        # Step-count stuck detection (original logic): net displacement across
        # a full window of positions must reach min_displacement.
        if not terminated and len(self._pos_buf) >= self.stuck_steps:
            displacement = float(np.linalg.norm(
                self._pos_buf[-1] - self._pos_buf[0]
            ))
            if displacement < self.min_displacement:
                terminated = True
                self._mark_stuck(info, 'step_count')

        if five_tuple:
            return obs, reward, terminated, truncated, info
        return obs, reward, terminated or truncated, info
def wrap_env(raw_env):
    """Wrap ``raw_env`` with the standard stack and return the outermost wrapper.

    Applied innermost to outermost: throttle clamp, stuck detection, speed reward.
    """
    clamped = ThrottleClampWrapper(raw_env, throttle_min=THROTTLE_MIN)
    guarded = StuckTerminationWrapper(clamped, stuck_steps=40, min_displacement=0.5)
    return SpeedRewardWrapper(guarded, speed_scale=SPEED_SCALE)
# ---- Track switching ----
def _send_exit_scene(env, verbose=True):
"""
Send exit_scene through the EXISTING connection on env.
Critical: the DonkeyCar sim creates one vehicle per TCP connection.
Sending exit_scene via a NEW raw socket creates a second vehicle and
the sim ignores it for the real training session. We must use the
existing viewer connection that env already holds.
"""
try:
base = env.unwrapped # strips all gym.Wrapper layers
if hasattr(base, 'viewer') and base.viewer is not None:
base.viewer.exit_scene() # sends {'msg_type': 'exit_scene'} on existing TCP
time.sleep(0.5) # let the message flush before closing socket
if verbose:
log('[W3 Runner] exit_scene sent on existing viewer connection.')
return True
else:
if verbose:
log('[W3 Runner] Warning: no viewer found on unwrapped env.')
return False
except Exception as e:
if verbose:
log(f'[W3 Runner] Warning: viewer.exit_scene() raised: {e}')
return False
def close_and_switch(current_env, next_env_id, verbose=True):
    """
    Shut down the current env (if any) and open a wrapped env on another track.

    Order of operations (IMPORTANT):
      1. viewer.exit_scene() on the existing connection  — tells sim to go to menu
      2. env.close() + sleep(2)                          — disconnect (ADR-006)
      3. sleep(EXIT_SCENE_WAIT)                          — wait for sim menu
      4. gym.make(next_env_id) + wrap_env                — connect to new track

    Steps 1-2 are skipped when ``current_env`` is None. Failures while closing
    are logged and ignored; failures while connecting propagate to the caller.

    Returns:
        The new env, wrapped by ``wrap_env``.
    """
    def note(message):
        # Progress messages are optional; warnings below are always logged.
        if verbose:
            log(message)

    if current_env is not None:
        # The sim must be told to leave the scene BEFORE the connection closes.
        note('[W3 Runner] Sending exit_scene via existing viewer connection...')
        _send_exit_scene(current_env, verbose=verbose)

        note('[W3 Runner] Closing current env...')
        try:
            current_env.close()
        except Exception as close_err:
            log(f'[W3 Runner] Warning: env.close() raised: {close_err}')
        time.sleep(2)  # ADR-006

    # NOTE(review): the menu wait runs even when current_env is None —
    # confirm this matches the intended behaviour.
    note(f'[W3 Runner] Waiting {EXIT_SCENE_WAIT}s for sim to reach main menu...')
    time.sleep(EXIT_SCENE_WAIT)

    note(f'[W3 Runner] Connecting to {next_env_id}...')
    wrapped = wrap_env(gym.make(next_env_id))
    note(f'[W3 Runner] ✅ Connected to {next_env_id}')
    return wrapped
# ---- Model creation / warm-start ----
def create_or_load_model(env, learning_rate, warm_start_path=None, seed=None):
    """
    Return a PPO model bound to ``env``: warm-started if possible, fresh otherwise.

    When ``warm_start_path`` names an existing file, the model is loaded from it
    and its learning rate is overridden with ``learning_rate``. If the path is
    missing, or loading/overriding raises for any reason, a fresh ``CnnPolicy``
    PPO model is created instead (``seed`` is only used for the fresh model).
    """
    have_checkpoint = bool(warm_start_path) and os.path.exists(warm_start_path)
    if have_checkpoint:
        log(f'[W3 Runner] Loading warm-start model from {warm_start_path}')
        try:
            loaded = PPO.load(warm_start_path, env=env, device='auto')
            # Three-part LR override required after PPO.load():
            #   1. learning_rate attribute — used to recreate lr_schedule
            #   2. lr_schedule             — consulted by _update_learning_rate()
            #      on every train() call; without it the optimizer would be
            #      reverted to the saved LR on the first gradient step
            #   3. optimizer param_groups  — immediate effect before first train()
            loaded.learning_rate = learning_rate
            loaded.lr_schedule = get_schedule_fn(learning_rate)
            param_groups = loaded.policy.optimizer.param_groups
            for group in param_groups:
                group['lr'] = learning_rate
            log(f'[W3 Runner] ✅ Warm start loaded. LR overridden to {learning_rate:.6f} '
                f'(model + lr_schedule + {len(param_groups)} optimizer param group(s))')
            return loaded
        except Exception as load_err:
            log(f'[W3 Runner] ⚠️ Warm start failed ({load_err}), training from scratch.')

    log(f'[W3 Runner] Creating fresh PPO model (lr={learning_rate:.6f})')
    return PPO(
        'CnnPolicy',
        env,
        learning_rate=learning_rate,
        verbose=1,  # show rollout stats so training progress is visible in log
        seed=seed,
    )
# ---- Training loop ----
def train_multitrack(model, first_env, total_timesteps, steps_per_switch,
                     save_dir=None):
    """
    Train PPO across training tracks by round-robin switching every steps_per_switch steps.

    After every segment:
      * a numbered checkpoint and the rolling ``model`` file are written to
        ``save_dir`` (when given);
      * one deterministic episode (capped at 3000 steps) is run on the model's
        own vectorised env to estimate the segment reward;
      * if that reward is a new high, the weights are saved to
        ``save_dir/best_model.zip``.

    At the end the best saved weights are loaded back INTO ``model`` in place
    (``model.set_parameters``), so the caller's model object holds the best
    policy seen rather than the last one. The final weights remain available
    in the last numbered checkpoint.

    Track-switch failures are retried once; if the retry fails too, the runner
    reconnects to the current track (its env was already closed by the switch
    attempt). If even that fails, training stops early and whatever has been
    learned so far is kept.

    Args:
        model: PPO model already bound to ``first_env``; mutated in place.
        first_env: env already connected to ``TRAINING_TRACKS[0]``.
        total_timesteps: total training budget across all tracks.
        steps_per_switch: steps per segment before switching track (> 0).
        save_dir: directory for checkpoints / best model; None disables saving.

    Returns:
        (env, segment_rewards): the env that is current when training ends, and
        a list of ``(track_name, reward)`` tuples, one per segment (reward is
        0.0 when the segment evaluation failed).

    Raises:
        ValueError: if ``steps_per_switch`` is not positive.
    """
    if steps_per_switch <= 0:
        # A non-positive segment length would never advance steps_done.
        raise ValueError(
            f'steps_per_switch must be a positive integer, got {steps_per_switch!r}')

    env = first_env
    steps_done = 0
    track_idx = 0
    segment_rewards = []
    health_cb = HealthCheckCallback()
    best_segment_reward = float('-inf')
    best_model_path = os.path.join(save_dir, 'best_model') if save_dir else None
    # Cap on the quick per-segment evaluation episode, to prevent non-terminating
    # episodes (e.g. car driving forever on wide generated_track) inflating the metric.
    max_eval_steps = 3000

    log(f'[W3 Runner] Starting multi-track training:')
    log(f' Total timesteps : {total_timesteps:,}')
    log(f' Steps per switch: {steps_per_switch:,}')
    log(f' Training tracks : {[t[0] for t in TRAINING_TRACKS]}')
    if best_model_path:
        log(f' Best model saved: {best_model_path}.zip')
    else:
        log(f' Best model saved: (disabled — no save_dir)')
    log(f' Rotations : ~{total_timesteps // (steps_per_switch * len(TRAINING_TRACKS))} full cycles')

    while steps_done < total_timesteps:
        track_name, track_env_id = TRAINING_TRACKS[track_idx]
        segment_steps = min(steps_per_switch, total_timesteps - steps_done)
        log(f'\n[W3 Runner] === Segment: {track_name} | '
            f'{steps_done:,}/{total_timesteps:,} steps done | '
            f'segment={segment_steps:,} steps ===')

        # Train segment
        model.learn(
            total_timesteps=segment_steps,
            reset_num_timesteps=False,
            callback=health_cb,
        )
        steps_done += segment_steps

        # --- Save latest checkpoint (crash recovery) ---
        if save_dir:
            try:
                os.makedirs(save_dir, exist_ok=True)
                # Save numbered checkpoint — NEVER overwrite previous checkpoints.
                # Every segment is preserved so we can return to any point in training.
                ckpt_name = f'checkpoint_{steps_done:07d}'
                model.save(os.path.join(save_dir, ckpt_name))
                # Also overwrite 'model' for crash-recovery (latest weights)
                model.save(os.path.join(save_dir, 'model'))
                log(f'[W3 Runner] Checkpoint saved: {ckpt_name} (step {steps_done:,})')
            except Exception as e:
                log(f'[W3 Runner] WARNING: checkpoint save failed: {e}')

        # Quick segment reward estimate — one deterministic episode.
        # The loop uses the VecEnv API (reset() -> obs, step() -> 4-tuple), so it
        # runs on the model's own vectorised env: after a track switch `env` is a
        # plain gymnasium env whose reset() returns (obs, info).
        try:
            eval_env = model.get_env()
            if eval_env is None:
                eval_env = env
            obs = eval_env.reset()
            ep_reward = 0.0
            for _ in range(max_eval_steps):
                action, _state = model.predict(obs, deterministic=True)
                obs, reward, done, _info = eval_env.step(action)
                ep_reward += float(reward[0] if hasattr(reward, '__len__') else reward)
                done_flag = done[0] if hasattr(done, '__len__') else done
                if done_flag:
                    break
            seg_reward = ep_reward
            log(f'[W3 Runner][TRAIN] track={track_name} segment_reward={seg_reward:.2f}')
            segment_rewards.append((track_name, float(seg_reward)))

            # --- Save BEST model if this segment beat the previous best ---
            if seg_reward > best_segment_reward:
                best_segment_reward = seg_reward
                if best_model_path:
                    try:
                        model.save(best_model_path)
                        log(f'[W3 Runner] ⭐ NEW BEST model saved! '
                            f'step={steps_done:,} reward={seg_reward:.2f} '
                            f'track={track_name}')
                    except Exception as e:
                        log(f'[W3 Runner] WARNING: best model save failed: {e}')
        except Exception as e:
            log(f'[W3 Runner][TRAIN] Segment eval failed: {e}')
            segment_rewards.append((track_name, 0.0))

        if steps_done >= total_timesteps:
            break

        # Switch to next training track
        next_track_idx = (track_idx + 1) % len(TRAINING_TRACKS)
        next_track_name, next_env_id = TRAINING_TRACKS[next_track_idx]
        log(f'[W3 Runner] Switching: {track_name} → {next_track_name}')
        try:
            new_env = close_and_switch(env, next_env_id)
            model.set_env(new_env)
            env = new_env
            track_idx = next_track_idx
        except Exception as e:
            log(f'[W3 Runner] ⚠️ Track switch failed: {e}. Retrying in 5s...')
            time.sleep(5)
            try:
                # The old env was already closed by the first attempt.
                new_env = close_and_switch(None, next_env_id)
                model.set_env(new_env)
                env = new_env
                track_idx = next_track_idx
            except Exception as e2:
                log(f'[W3 Runner] ❌ Track switch retry failed: {e2}. Continuing on current track.')
                # "Current track" needs a fresh connection: the previous env
                # was closed during the first switch attempt.
                try:
                    new_env = close_and_switch(None, track_env_id)
                    model.set_env(new_env)
                    env = new_env
                except Exception as e3:
                    log(f'[W3 Runner] ❌ Could not reconnect to {track_name}: {e3}. '
                        f'Stopping training early.')
                    break

    log(f'\n[W3 Runner] Training complete: {steps_done:,} total steps across '
        f'{len(segment_rewards)} segments.')
    log(f'[W3 Runner] Best segment reward during training: {best_segment_reward:.2f}')

    # --- Reload the BEST model weights before returning ---
    # The final model may have drifted from the best policy found mid-training.
    # Loading in place (instead of rebinding a local to PPO.load(...)) is what
    # makes the caller's model object carry the best weights.
    if best_model_path and os.path.exists(best_model_path + '.zip'):
        try:
            log(f'[W3 Runner] Reloading best model from {best_model_path}.zip')
            model.set_parameters(best_model_path + '.zip', exact_match=True, device='auto')
            log(f'[W3 Runner] ✅ Best model reloaded (reward={best_segment_reward:.2f})')
        except Exception as e:
            log(f'[W3 Runner] WARNING: could not reload best model: {e}. Using final model.')

    return env, segment_rewards
# ---- Zero-shot evaluation on test tracks ----
def _shuttle_warning(pos_samples, steps):
    """Return a warning suffix if sampled positions oscillate instead of progressing.

    Compares net displacement (first to last sample) with the summed distance
    between consecutive samples; a low ratio over a long episode suggests the
    car is shuttling back and forth. Returns '' when there is nothing to flag.
    """
    if len(pos_samples) < 3:
        return ''
    net_dist = float(np.linalg.norm(pos_samples[-1] - pos_samples[0]))
    total_sampled = sum(
        float(np.linalg.norm(later - earlier))
        for earlier, later in zip(pos_samples, pos_samples[1:])
    )
    macro_eff = net_dist / total_sampled if total_sampled > 0.1 else 1.0
    if macro_eff < 0.3 and steps >= 500:
        return f' ⚠️ SHUTTLE EXPLOIT? macro_efficiency={macro_eff:.2f}'
    return ''


def evaluate_test_tracks(model, current_env, eval_episodes):
    """
    Evaluate the trained model on each test track (zero-shot generalization).

    Switches to each track in TEST_TRACKS, runs ``eval_episodes`` deterministic
    episodes (each capped at 2000 steps) and records mean_reward and mean_steps.
    A track that cannot be connected to is recorded with zero scores and an
    ``'error'`` entry, and evaluation moves on.

    Args:
        model: trained model used for ``predict``.
        current_env: env currently open; it is closed when switching tracks.
        eval_episodes: episodes per test track. With 0 (or fewer) episodes the
            track is recorded with zero means instead of NaN.

    Returns:
        (test_results, combined_test_score, env):
          test_results: {track_name: {'mean_reward': float, 'mean_steps': float,
                         'drove_far': bool}} (or the error form described above)
          combined_test_score: sum of mean_rewards across test tracks
          env: the last env opened here (still open); the caller must close it
    """
    log(f'\n[W3 Runner] ===== ZERO-SHOT EVALUATION on TEST tracks =====')
    log(f' Test tracks : {[t[0] for t in TEST_TRACKS]}')
    log(f' Eval episodes : {eval_episodes}')

    test_results = {}
    env = current_env
    for track_name, track_env_id in TEST_TRACKS:
        log(f'\n[W3 Runner] Switching to TEST track: {track_name}')
        try:
            env = close_and_switch(env, track_env_id)
        except Exception as e:
            log(f'[W3 Runner] ❌ Cannot connect to test track {track_name}: {e}')
            test_results[track_name] = {'mean_reward': 0.0, 'mean_steps': 0.0, 'error': str(e)}
            continue

        # Run episodes manually to capture step count
        all_rewards = []
        all_steps = []
        for ep in range(eval_episodes):
            obs, info = env.reset()
            total_reward = 0.0
            steps = 0
            done = False
            pos_samples = []  # sample position every 100 steps to detect shuttling
            while not done and steps < 2000:
                action, _ = model.predict(obs, deterministic=True)
                result = env.step(action)
                if len(result) == 5:
                    obs, reward, terminated, truncated, info = result
                    done = terminated or truncated
                else:
                    obs, reward, done, info = result
                total_reward += float(reward)
                steps += 1
                # Sample position every 100 steps for shuttle-exploit detection
                if steps % 100 == 0:
                    raw_info = info[0] if isinstance(info, (list, tuple)) else info
                    pos = raw_info.get('pos', None) if isinstance(raw_info, dict) else None
                    if pos is not None:
                        pos_samples.append(np.array(list(pos)[:3], dtype=np.float64))

            shuttle_warning = _shuttle_warning(pos_samples, steps)
            all_rewards.append(total_reward)
            all_steps.append(steps)
            log(f'[W3 Runner] {track_name} ep{ep+1}: reward={total_reward:.1f} steps={steps}'
                f' ({total_reward/max(steps,1):.2f}/step){shuttle_warning}')
            time.sleep(0.5)

        # np.mean([]) is NaN; an empty run scores zero instead.
        mean_reward = float(np.mean(all_rewards)) if all_rewards else 0.0
        mean_steps = float(np.mean(all_steps)) if all_steps else 0.0
        drove_far = mean_steps > 200
        test_results[track_name] = {
            'mean_reward': mean_reward,
            'mean_steps': mean_steps,
            'drove_far': drove_far,
        }
        verdict = '✅ DRIVES' if drove_far else '❌ CRASHES'
        log(f'[W3 Runner][TEST] track={track_name} mean_reward={mean_reward:.2f} '
            f'mean_steps={mean_steps:.1f} {verdict}')

    # Combined score = sum of mean_rewards on test tracks
    combined = sum(r['mean_reward'] for r in test_results.values())
    log(f'\n[W3 Runner][TEST] combined_test_score={combined:.4f}')
    log(f'[W3 Runner][TEST] mini_monaco_reward='
        f'{test_results.get("mini_monaco", {}).get("mean_reward", 0.0):.4f}')
    # warren is not in TEST_TRACKS, so this reports 0.0; the line is kept so
    # the emitted output format stays unchanged for whatever parses it.
    log(f'[W3 Runner][TEST] warren_reward='
        f'{test_results.get("warren", {}).get("mean_reward", 0.0):.4f}')
    return test_results, combined, env
# ---- Main ----
def main():
    """Command-line entry point: connect, train across tracks, save, evaluate, tear down.

    Steps:
      1. Connect to the first training track (exit 100 on failure).
      2. Create or warm-start the PPO model (exit 101 on failure).
      3. Run multi-track training (exit 101 on failure).
      4. Save the model to ``<save_dir>/model.zip`` (failure is logged only).
      5. Zero-shot evaluation on the test tracks unless ``--skip-eval``
         (failure is logged only; the combined score stays 0.0).
      6. Print training and test summaries.
      7. Close the final env.
    """
    parser = argparse.ArgumentParser(description='Wave 3 Multi-Track PPO Trainer.')
    parser.add_argument('--total-timesteps', type=int, default=200000,
                        help='Total training timesteps across all tracks (default: 200000)')
    parser.add_argument('--steps-per-switch', type=int, default=10000,
                        help='Steps on each track before switching (default: 10000)')
    parser.add_argument('--learning-rate', type=float, default=0.000225,
                        help='PPO learning rate (default: 0.000225 = Phase 2 champion)')
    parser.add_argument('--warm-start', type=str, default=None,
                        help='Path to .zip model for warm start '
                             '(default: none — train from scratch)')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='Directory to save trained model')
    parser.add_argument('--eval-episodes', type=int, default=3,
                        help='Episodes per test track for zero-shot evaluation')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed')
    parser.add_argument('--skip-eval', action='store_true',
                        help='Skip zero-shot evaluation (training only)')
    args = parser.parse_args()

    # Wave 4: never auto-detect a warm start. Training always begins from
    # random weights so the CNN is not biased toward any single track.
    warm_start = args.warm_start  # None unless caller explicitly passes one
    save_dir = args.save_dir or os.path.join(AGENT_DIR, 'models', 'wave4',
                                             f'trial-{int(time.time())}')

    log(f'[W3 Runner] === Wave 4 Multi-Track Training (scratch, no warm-start) ===')
    log(f'[W3 Runner] total_timesteps ={args.total_timesteps:,}')
    log(f'[W3 Runner] steps_per_switch={args.steps_per_switch:,}')
    log(f'[W3 Runner] learning_rate ={args.learning_rate:.6f}')
    log(f'[W3 Runner] warm_start ={warm_start}')
    log(f'[W3 Runner] save_dir ={save_dir}')
    log(f'[W3 Runner] eval_episodes ={args.eval_episodes}')

    # ---- 1. Connect to first training track ----
    # Assume sim is already at the main menu (user-started, or previous run exited cleanly).
    # gym.make() on the first track will load it directly from the menu.
    first_track_name, first_env_id = TRAINING_TRACKS[0]
    log(f'\n[W3 Runner] Connecting to first training track: {first_track_name} ({first_env_id})')
    env = None
    try:
        # Exactly ONE gym.make() here: the sim creates one vehicle per TCP
        # connection, so an extra throwaway env would put a second, unmanaged
        # car into the scene and leak its connection.
        env = VecTransposeImage(DummyVecEnv([lambda: wrap_env(gym.make(first_env_id))]))
        log(f'[W3 Runner] ✅ Connected to {first_env_id}')
    except Exception as e:
        log(f'[W3 Runner] ❌ Failed to connect to first training track: {e}')
        sys.exit(100)

    # ---- 2. Create or load model ----
    model = None
    try:
        model = create_or_load_model(env, args.learning_rate, warm_start, args.seed)
    except Exception as e:
        log(f'[W3 Runner] ❌ Model creation failed: {e}')
        try:
            env.close()
            time.sleep(2)
        except Exception:
            pass  # best-effort cleanup; we are exiting with an error anyway
        sys.exit(101)

    # ---- 3. Multi-track training ----
    try:
        env, segment_rewards = train_multitrack(
            model, env,
            total_timesteps=args.total_timesteps,
            steps_per_switch=args.steps_per_switch,
            save_dir=save_dir,
        )
    except Exception as e:
        log(f'[W3 Runner] ❌ Training failed: {e}')
        try:
            env.close()
            time.sleep(2)
        except Exception:
            pass  # best-effort cleanup; we are exiting with an error anyway
        sys.exit(101)

    # ---- 4. Save model ----
    # ADR-005: model is always defined before model.save()
    try:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, 'model')
        model.save(save_path)
        log(f'[W3 Runner] ✅ Model saved to {save_path}.zip')
    except Exception as e:
        log(f'[W3 Runner] ⚠️ Model save failed: {e}')

    # ---- 5. Zero-shot evaluation on test tracks ----
    combined_test_score = 0.0
    test_results = {}
    if not args.skip_eval:
        try:
            test_results, combined_test_score, env = evaluate_test_tracks(
                model, env, args.eval_episodes
            )
        except Exception as e:
            log(f'[W3 Runner] ❌ Test evaluation failed: {e}')
            # Ensure combined_test_score = 0 is recorded (trial still valid)

    # ---- 6. Print training summary ----
    log(f'\n[W3 Runner] ===== TRAINING SUMMARY =====')
    if segment_rewards:
        by_track = {}
        for tname, rew in segment_rewards:
            by_track.setdefault(tname, []).append(rew)
        for tname, rewards in by_track.items():
            log(f'[W3 Runner][TRAIN] {tname}: '
                f'mean={np.mean(rewards):.1f} over {len(rewards)} segments')

    log(f'\n[W3 Runner] ===== TEST SUMMARY (zero-shot generalization) =====')
    for tname, metrics in test_results.items():
        verdict = '✅ DRIVES' if metrics.get('drove_far') else '❌ CRASHES'
        log(f'[W3 Runner][TEST] {tname}: '
            f'reward={metrics.get("mean_reward", 0):.1f} '
            f'steps={metrics.get("mean_steps", 0):.0f} {verdict}')
    log(f'[W3 Runner][TEST] combined_test_score={combined_test_score:.4f}')

    # ---- 7. Teardown ----
    log(f'[W3 Runner] Closing final env...')
    try:
        env.close()
        log(f'[W3 Runner] env.close() complete.')
    except Exception as e:
        log(f'[W3 Runner] Warning: env.close() raised: {e}')
    time.sleep(2)  # ADR-006
    log(f'[W3 Runner] ✅ Multi-track runner complete. Exiting.')
# Run the trainer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()