# donkeycar-rl-autoresearch/agent/eval_on_track.py
"""
eval_on_track.py — Evaluate a saved model on any track, zero-shot.
Usage:
python3 eval_on_track.py --model models/wave4-champion/model.zip \
--track donkey-generated-roads-v0 \
--episodes 3 --max-steps 3000
This is the proper zero-shot evaluation: load a trained model, connect
to a track it has never seen, run N episodes, report reward and steps.
"""
import argparse, os, sys, time
import numpy as np
sys.path.insert(0, os.path.dirname(__file__))
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
import gymnasium as gym
from reward_wrapper import SpeedRewardWrapper
from donkeycar_sb3_runner import ThrottleClampWrapper
from multitrack_runner import StuckTerminationWrapper
# Minimum forward throttle enforced by ThrottleClampWrapper in make_env().
THROTTLE_MIN = 0.2
# Speed-reward coefficient passed to SpeedRewardWrapper in make_env().
SPEED_SCALE = 0.1
def make_env(env_id):
    """Build one Donkey sim env wrapped with the training-time wrapper stack.

    Wrapper order (inner to outer): throttle clamp -> stuck-episode
    termination -> speed-shaped reward.
    """
    base = gym.make(env_id)
    clamped = ThrottleClampWrapper(base, throttle_min=THROTTLE_MIN)
    guarded = StuckTerminationWrapper(clamped, stuck_steps=80, min_displacement=0.5)
    return SpeedRewardWrapper(guarded, speed_scale=SPEED_SCALE)
def _shuttle_note(pos_samples, steps):
    """Return a warning string if the car appears to shuttle back and forth.

    Compares net displacement against total path length over the sampled
    positions; a low ratio with many steps suggests a reward exploit
    (driving back and forth instead of making progress).
    """
    if len(pos_samples) < 3:
        return ''
    net = np.linalg.norm(pos_samples[-1] - pos_samples[0])
    tot = sum(np.linalg.norm(pos_samples[i + 1] - pos_samples[i])
              for i in range(len(pos_samples) - 1))
    # Guard against near-zero path length; treat as fully efficient.
    eff = net / tot if tot > 0.1 else 1.0
    if eff < 0.3 and steps >= 500:
        return f' ⚠️ SHUTTLE EXPLOIT? macro_eff={eff:.2f}'
    return ''


def main():
    """Load a saved PPO model and run N evaluation episodes on a track.

    Reports per-episode reward/steps and overall means. Side effects only
    (prints to stdout, connects to the Donkey simulator).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True, help='Path to model.zip')
    parser.add_argument('--track', default='donkey-generated-roads-v0')
    parser.add_argument('--episodes', type=int, default=3)
    parser.add_argument('--max-steps', type=int, default=3000)
    args = parser.parse_args()

    print(f'\n=== Zero-Shot Eval ===')
    print(f'Model : {args.model}')
    print(f'Track : {args.track}')
    print(f'Episodes: {args.episodes} x max {args.max_steps} steps\n')

    raw_env = make_env(args.track)
    # VecTransposeImage converts HWC observations to CHW as the CNN policy expects.
    env = VecTransposeImage(DummyVecEnv([lambda: raw_env]))
    model = PPO.load(args.model, env=env, device='auto')

    all_rewards, all_steps = [], []
    for ep in range(args.episodes):
        # BUGFIX: this is an SB3 VecEnv, not a raw Gymnasium env.
        # VecEnv.reset() returns obs only, and VecEnv.step() returns a
        # 4-tuple (obs, rewards, dones, infos), each batched per sub-env.
        # The previous Gymnasium-style 2-/5-tuple unpacking raised ValueError.
        obs = env.reset()
        total_reward, steps, done = 0.0, 0, False
        pos_samples = []
        while not done and steps < args.max_steps:
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, infos = env.step(action)
            total_reward += float(rewards[0])
            steps += 1
            done = bool(dones[0])
            # Sample position every 100 steps for shuttle detection.
            if steps % 100 == 0:
                raw_info = infos[0] if isinstance(infos, (list, tuple)) else infos
                pos = raw_info.get('pos') if isinstance(raw_info, dict) else None
                if pos is not None:
                    pos_samples.append(np.array(list(pos)[:3]))
        note = _shuttle_note(pos_samples, steps)
        status = '✅ RAN FULL EVAL' if steps >= args.max_steps else '❌ CRASHED'
        print(f' ep{ep+1}: {total_reward:.1f} reward / {steps} steps '
              f'({total_reward/max(steps,1):.2f}/step) {status}{note}')
        all_rewards.append(total_reward)
        all_steps.append(steps)
        time.sleep(0.5)

    print(f'\n Mean reward : {np.mean(all_rewards):.1f}')
    print(f' Mean steps : {np.mean(all_steps):.0f}')
    print(f' {"✅ DRIVES" if np.mean(all_steps) > 500 else "❌ CRASHES"}')
    env.close()
    # Give the simulator time to tear down the session cleanly.
    time.sleep(2)
if __name__ == '__main__':
main()