"""
|
|
eval_on_track.py — Evaluate a saved model on any track, zero-shot.
|
|
|
|
Usage:
|
|
python3 eval_on_track.py --model models/wave4-champion/model.zip \
|
|
--track donkey-generated-roads-v0 \
|
|
--episodes 3 --max-steps 3000
|
|
|
|
This is the proper zero-shot evaluation: load a trained model, connect
|
|
to a track it has never seen, run N episodes, report reward and steps.
|
|
"""
|
|
import argparse, os, sys, time
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
|
|
import gymnasium as gym
|
|
|
|
from reward_wrapper import SpeedRewardWrapper
|
|
from donkeycar_sb3_runner import ThrottleClampWrapper
|
|
from multitrack_runner import StuckTerminationWrapper
|
|
|
|
THROTTLE_MIN = 0.2
|
|
SPEED_SCALE = 0.1
|
|
|
|
def make_env(env_id):
|
|
raw = gym.make(env_id)
|
|
env = ThrottleClampWrapper(raw, throttle_min=THROTTLE_MIN)
|
|
env = StuckTerminationWrapper(env, stuck_steps=80, min_displacement=0.5)
|
|
env = SpeedRewardWrapper(env, speed_scale=SPEED_SCALE)
|
|
return env
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--model', required=True, help='Path to model.zip')
|
|
parser.add_argument('--track', default='donkey-generated-roads-v0')
|
|
parser.add_argument('--episodes', type=int, default=3)
|
|
parser.add_argument('--max-steps', type=int, default=3000)
|
|
args = parser.parse_args()
|
|
|
|
print(f'\n=== Zero-Shot Eval ===')
|
|
print(f'Model : {args.model}')
|
|
print(f'Track : {args.track}')
|
|
print(f'Episodes: {args.episodes} x max {args.max_steps} steps\n')
|
|
|
|
raw_env = make_env(args.track)
|
|
env = VecTransposeImage(DummyVecEnv([lambda: raw_env]))
|
|
|
|
model = PPO.load(args.model, env=env, device='auto')
|
|
|
|
all_rewards, all_steps = [], []
|
|
for ep in range(args.episodes):
|
|
obs = env.reset()
|
|
total_reward, steps, done = 0.0, 0, False
|
|
pos_samples = []
|
|
|
|
while not done and steps < args.max_steps:
|
|
action, _ = model.predict(obs, deterministic=True)
|
|
result = env.step(action)
|
|
if len(result) == 5:
|
|
obs, reward, terminated, truncated, info = result
|
|
done = bool(terminated[0] or truncated[0])
|
|
else:
|
|
obs, reward, done_arr, info = result
|
|
done = bool(done_arr[0])
|
|
total_reward += float(reward[0])
|
|
steps += 1
|
|
if steps % 100 == 0:
|
|
raw_info = info[0] if isinstance(info, (list,tuple)) else info
|
|
pos = raw_info.get('pos') if isinstance(raw_info, dict) else None
|
|
if pos is not None:
|
|
pos_samples.append(np.array(list(pos)[:3]))
|
|
|
|
# Shuttle detection
|
|
note = ''
|
|
if len(pos_samples) >= 3:
|
|
net = np.linalg.norm(pos_samples[-1] - pos_samples[0])
|
|
tot = sum(np.linalg.norm(pos_samples[i+1]-pos_samples[i])
|
|
for i in range(len(pos_samples)-1))
|
|
eff = net/tot if tot > 0.1 else 1.0
|
|
if eff < 0.3 and steps >= 500:
|
|
note = f' ⚠️ SHUTTLE EXPLOIT? macro_eff={eff:.2f}'
|
|
|
|
status = '✅ RAN FULL EVAL' if steps >= args.max_steps else '❌ CRASHED'
|
|
print(f' ep{ep+1}: {total_reward:.1f} reward / {steps} steps '
|
|
f'({total_reward/max(steps,1):.2f}/step) {status}{note}')
|
|
all_rewards.append(total_reward)
|
|
all_steps.append(steps)
|
|
time.sleep(0.5)
|
|
|
|
print(f'\n Mean reward : {np.mean(all_rewards):.1f}')
|
|
print(f' Mean steps : {np.mean(all_steps):.0f}')
|
|
print(f' {"✅ DRIVES" if np.mean(all_steps) > 500 else "❌ CRASHES"}')
|
|
|
|
env.close()
|
|
time.sleep(2)
|
|
|
|
if __name__ == '__main__':
|
|
main()
|