72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
"""
|
|
Speed-Aware Reward Wrapper for DonkeyCar RL
|
|
============================================
|
|
Replaces the default CTE-only reward with:
|
|
reward = speed * (1.0 - min(abs(cte) / max_cte, 1.0))
|
|
|
|
Falls back to original reward if speed/cte not available in info dict.
|
|
"""
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
|
|
|
|
class SpeedRewardWrapper(gym.Wrapper):
    """Replace the wrapped environment's reward with a speed-aware version.

    The shaped reward is::

        speed * (1 - min(|cte| / max_cte, 1))

    - Largest when the car is fast AND centred on the track.
    - Zero once the cross-track error reaches ``max_cte`` (the ratio is clamped).
    - ``crash_penalty`` is added on top when the episode is finished and the
      wrapped environment's own reward for that step was negative.

    If ``speed`` or ``cte`` is missing from (or not numeric in) the ``info``
    dict, the wrapped environment's reward is passed through unchanged.
    """

    def __init__(self, env, max_cte: float = 8.0, crash_penalty: float = -10.0):
        """Wrap *env*.

        Args:
            env: Environment to wrap.
            max_cte: Cross-track error at which the driving reward reaches
                zero. Must be strictly positive.
            crash_penalty: Value added to the shaped reward on a failed
                episode end (normally negative).

        Raises:
            ValueError: If ``max_cte`` is not strictly positive.
        """
        super().__init__(env)
        # Reject 0 / negative / NaN up front: 0 would raise ZeroDivisionError on
        # the first shaped step, and a negative value would invert the shaping.
        if not max_cte > 0:
            raise ValueError(f'max_cte must be > 0, got {max_cte!r}')
        self.max_cte = max_cte
        self.crash_penalty = crash_penalty

    def step(self, action):
        """Step the wrapped environment and substitute the shaped reward.

        Supports both the 5-tuple ``(obs, reward, terminated, truncated, info)``
        and the 4-tuple ``(obs, reward, done, info)`` step results; the returned
        tuple has the same arity as the one produced by the wrapped environment.

        Raises:
            ValueError: If the wrapped ``step()`` returns neither 4 nor 5 items.
        """
        result = self.env.step(action)

        if len(result) == 5:
            obs, reward, terminated, truncated, info = result
            shaped = self._shape_reward(reward, terminated or truncated, info)
            return obs, shaped, terminated, truncated, info

        if len(result) == 4:
            obs, reward, done, info = result
            shaped = self._shape_reward(reward, done, info)
            return obs, shaped, done, info

        raise ValueError(f'Unexpected step() result length: {len(result)}')

    def _shape_reward(self, original_reward: float, done: bool, info: dict) -> float:
        """Compute the speed-aware reward, falling back to *original_reward*.

        Args:
            original_reward: Reward produced by the wrapped environment.
            done: Whether the episode finished on this step.
            info: Step info dict, expected to carry ``'speed'`` and ``'cte'``.

        Returns:
            The shaped reward, or ``original_reward`` unchanged when ``info`` is
            empty/None or ``speed``/``cte`` is missing or not convertible to
            float.
        """
        if not info:
            return original_reward

        raw_speed = info.get('speed')
        raw_cte = info.get('cte')
        # Check for missing keys BEFORE converting: float(None) raises, so a
        # None-check placed after the conversion could never be reached.
        if raw_speed is None or raw_cte is None:
            return original_reward

        # Keep the try narrow so only a failed conversion triggers the
        # fallback; errors in the arithmetic below are not silently hidden.
        try:
            speed = float(raw_speed)
            cte = float(raw_cte)
        except (TypeError, ValueError):
            return original_reward

        # Positive driving reward: fast + centred.
        shaped = speed * (1.0 - min(abs(cte) / self.max_cte, 1.0))

        # A negative original reward at episode end is treated as a crash
        # (the original author notes DonkeyCar reports -1 on crash).
        if done and original_reward < 0:
            shaped += self.crash_penalty

        return shaped
|