fix(stuck): add CTE-based termination + tighten speed check
StuckTerminationWrapper: new max_cte / max_high_cte_seconds params. If |cte| stays above max_cte for longer than max_high_cte_seconds, the episode is terminated with stuck_reason='high_cte_timeout'. This catches a car pressed sideways against a barrier, where lateral drift keeps the existing speed/position checks from ever firing. exp25 params: MAX_CTE_TERMINATION=3.0 m (roughly the road half-width), MAX_HIGH_CTE_SECONDS=1.0 s, LOW_SPEED_THRESHOLD=1.0 (was 0.5 — now catches slow lateral drift), MAX_LOW_SPEED_SECONDS=1.5 s (was 2.0 s). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bb889ab4a1
commit
dbc09d12d1
|
|
@ -64,8 +64,10 @@ PROGRESS_PATIENCE = 100
|
||||||
|
|
||||||
MAX_STUCK_SECONDS = 5.0
|
MAX_STUCK_SECONDS = 5.0
|
||||||
MAX_EPISODE_SECONDS = 30.0
|
MAX_EPISODE_SECONDS = 30.0
|
||||||
LOW_SPEED_THRESHOLD = 0.5
|
LOW_SPEED_THRESHOLD = 1.0 # raised from 0.5 — catches slow lateral drift along barrier
|
||||||
MAX_LOW_SPEED_SECONDS = 2.0
|
MAX_LOW_SPEED_SECONDS = 1.5 # reduced from 2.0 — faster termination when slow
|
||||||
|
MAX_CTE_TERMINATION = 3.0 # terminate if |cte| > 3m for 1s (road half-width ~3-4m)
|
||||||
|
MAX_HIGH_CTE_SECONDS = 1.0
|
||||||
|
|
||||||
TRACK_ID = 'donkey-generated-roads-v0'
|
TRACK_ID = 'donkey-generated-roads-v0'
|
||||||
PORT = 9091
|
PORT = 9091
|
||||||
|
|
@ -88,6 +90,8 @@ def make_env(track_id, port):
|
||||||
max_episode_seconds=MAX_EPISODE_SECONDS,
|
max_episode_seconds=MAX_EPISODE_SECONDS,
|
||||||
low_speed_threshold=LOW_SPEED_THRESHOLD,
|
low_speed_threshold=LOW_SPEED_THRESHOLD,
|
||||||
max_low_speed_seconds=MAX_LOW_SPEED_SECONDS,
|
max_low_speed_seconds=MAX_LOW_SPEED_SECONDS,
|
||||||
|
max_cte=MAX_CTE_TERMINATION,
|
||||||
|
max_high_cte_seconds=MAX_HIGH_CTE_SECONDS,
|
||||||
)
|
)
|
||||||
env = SpeedRewardWrapper(
|
env = SpeedRewardWrapper(
|
||||||
env,
|
env,
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,8 @@ class StuckTerminationWrapper(gym.Wrapper):
|
||||||
"""
|
"""
|
||||||
def __init__(self, env, stuck_steps: int = 80, min_displacement: float = 0.5,
|
def __init__(self, env, stuck_steps: int = 80, min_displacement: float = 0.5,
|
||||||
max_stuck_seconds: float = 12.0, max_episode_seconds: float = 30.0,
|
max_stuck_seconds: float = 12.0, max_episode_seconds: float = 30.0,
|
||||||
low_speed_threshold: float = 0.5, max_low_speed_seconds: float = 3.0):
|
low_speed_threshold: float = 0.5, max_low_speed_seconds: float = 3.0,
|
||||||
|
max_cte: float = 5.0, max_high_cte_seconds: float = 1.0):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.stuck_steps = stuck_steps
|
self.stuck_steps = stuck_steps
|
||||||
self.min_displacement = min_displacement
|
self.min_displacement = min_displacement
|
||||||
|
|
@ -156,11 +157,14 @@ class StuckTerminationWrapper(gym.Wrapper):
|
||||||
self.max_episode_seconds = max_episode_seconds
|
self.max_episode_seconds = max_episode_seconds
|
||||||
self.low_speed_threshold = low_speed_threshold
|
self.low_speed_threshold = low_speed_threshold
|
||||||
self.max_low_speed_seconds = max_low_speed_seconds
|
self.max_low_speed_seconds = max_low_speed_seconds
|
||||||
|
self.max_cte = max_cte
|
||||||
|
self.max_high_cte_seconds = max_high_cte_seconds
|
||||||
self._pos_buf: deque = deque(maxlen=stuck_steps)
|
self._pos_buf: deque = deque(maxlen=stuck_steps)
|
||||||
self._last_progress_pos = None
|
self._last_progress_pos = None
|
||||||
self._last_progress_t = None
|
self._last_progress_t = None
|
||||||
self._episode_start_t = None
|
self._episode_start_t = None
|
||||||
self._low_speed_start_t = None
|
self._low_speed_start_t = None
|
||||||
|
self._high_cte_start_t = None
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
self._pos_buf.clear()
|
self._pos_buf.clear()
|
||||||
|
|
@ -168,6 +172,7 @@ class StuckTerminationWrapper(gym.Wrapper):
|
||||||
self._last_progress_t = None
|
self._last_progress_t = None
|
||||||
self._episode_start_t = time.time()
|
self._episode_start_t = time.time()
|
||||||
self._low_speed_start_t = None
|
self._low_speed_start_t = None
|
||||||
|
self._high_cte_start_t = None
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
|
@ -224,6 +229,24 @@ class StuckTerminationWrapper(gym.Wrapper):
|
||||||
else:
|
else:
|
||||||
self._low_speed_start_t = None
|
self._low_speed_start_t = None
|
||||||
|
|
||||||
|
# CTE-based termination: catches car stuck at/past the road edge barrier.
|
||||||
|
# A car pressed sideways against a barrier has high CTE even while the
|
||||||
|
# speed and position checks are fooled by wheel oscillation.
|
||||||
|
if not terminated:
|
||||||
|
try:
|
||||||
|
cte = float(info.get('cte', 0.0) or 0.0)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
cte = 0.0
|
||||||
|
if abs(cte) > self.max_cte:
|
||||||
|
if self._high_cte_start_t is None:
|
||||||
|
self._high_cte_start_t = now
|
||||||
|
elif (now - self._high_cte_start_t) > self.max_high_cte_seconds:
|
||||||
|
terminated = True
|
||||||
|
info['stuck_termination'] = True
|
||||||
|
info['stuck_reason'] = 'high_cte_timeout'
|
||||||
|
else:
|
||||||
|
self._high_cte_start_t = None
|
||||||
|
|
||||||
# Hard episode wall-clock limit — fires regardless of car position or sim fps.
|
# Hard episode wall-clock limit — fires regardless of car position or sim fps.
|
||||||
# Catches cars sliding slowly along barriers that keep resetting the
|
# Catches cars sliding slowly along barriers that keep resetting the
|
||||||
# max_stuck_seconds timer by drifting 0.5m at a time.
|
# max_stuck_seconds timer by drifting 0.5m at a time.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue