donkeycar-rl-autoresearch/agent/reward_wrapper.py

"""
Speed-Aware Reward Wrapper for DonkeyCar RL
============================================
Replaces the default CTE-only reward with:
reward = speed * (1.0 - min(abs(cte) / max_cte, 1.0))
Falls back to original reward if speed/cte not available in info dict.
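
Worked example (illustrative values, with max_cte = 8.0):
    speed = 4.0, cte = 2.0  ->  reward = 4.0 * (1.0 - 2.0 / 8.0) = 3.0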
"""
import gymnasium as gym
import numpy as np


class SpeedRewardWrapper(gym.Wrapper):
    """
    Replace DonkeyCar's default reward with a speed-aware version.

    Reward = speed * (1 - |cte| / max_cte)

    - Maximum when the car is fast AND centred on the track
    - Zero when the car is at or beyond the maximum cross-track error
    - A fixed crash penalty is added when the episode ends with a negative
      original reward (i.e. the simulator reported a failure)
    """

    def __init__(self, env, max_cte: float = 8.0, crash_penalty: float = -10.0):
        super().__init__(env)
        self.max_cte = max_cte
        self.crash_penalty = crash_penalty

    def step(self, action):
        result = self.env.step(action)

        # Handle both 4-tuple (legacy gym) and 5-tuple (gymnasium) step APIs
        if len(result) == 5:
            obs, reward, terminated, truncated, info = result
            done = terminated or truncated
        elif len(result) == 4:
            obs, reward, done, info = result
            terminated = done
            truncated = False
        else:
            raise ValueError(f'Unexpected step() result length: {len(result)}')

        # Shape the reward using speed and CTE from the info dict
        shaped = self._shape_reward(reward, done, info)

        if len(result) == 5:
            return obs, shaped, terminated, truncated, info
        return obs, shaped, done, info

    def _shape_reward(self, original_reward: float, done: bool, info: dict) -> float:
        """Compute the speed-aware reward, falling back to the original if info is unavailable."""
        speed = info.get('speed')
        cte = info.get('cte')
        if speed is None or cte is None:
            # info dict doesn't carry speed/cte -- fall back gracefully
            return original_reward
        try:
            speed = float(speed)
            cte = float(cte)
        except (TypeError, ValueError):
            # Non-numeric values -- fall back gracefully
            return original_reward

        # Positive driving reward: fast + centred
        shaped = speed * (1.0 - min(abs(cte) / self.max_cte, 1.0))

        # Add the crash penalty (DonkeyCar's original reward is negative on a crash)
        if done and original_reward < 0:
            shaped += self.crash_penalty
        return shaped
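

# ---------------------------------------------------------------------------
# Usage sketch (illustrative). With a real DonkeyCar env the wrapper shapes
# the reward from info['speed'] and info['cte']; with any other env those
# keys are missing and the original reward is returned unchanged. The quick
# check below uses a standard Gymnasium env purely to exercise the fallback
# path; swapping in a gym_donkeycar env (e.g. "donkey-generated-track-v0",
# with the simulator running) exercises the shaped reward instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = SpeedRewardWrapper(gym.make("CartPole-v1"), max_cte=8.0)

    obs, info = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        action = env.action_space.sample()  # random policy, just to step the env
        obs, reward, terminated, truncated, info = env.step(action)
        # CartPole's info has no 'speed'/'cte', so the original reward passes through
        print(f"reward={reward:.3f}")
    env.close()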