fix: keep action-space matching by loading model with base throttle 0.2 and applying runtime throttle_floor wrapper for phase1
This commit is contained in:
parent
41d12dede2
commit
eb92d119f9
|
|
@ -91,10 +91,33 @@ class V5RewardWrapper(gym.Wrapper):
|
||||||
return obs, reward, terminated or force_terminate, info
|
return obs, reward, terminated or force_terminate, info
|
||||||
|
|
||||||
# env factory
|
# env factory
|
||||||
def make_env(throttle_min):
|
def make_env_base(base_throttle=0.2, throttle_floor=None):
|
||||||
|
"""Create env with underlying action space based on base_throttle (must match saved model).
|
||||||
|
If throttle_floor is provided, wrap the env to enforce a minimum throttle at action runtime
|
||||||
|
without changing the action_space (so model loading is compatible).
|
||||||
|
"""
|
||||||
def _init():
|
def _init():
|
||||||
raw = gym.make('donkey-mountain-track-v0', conf={'host': HOST, 'port': PORT})
|
raw = gym.make('donkey-mountain-track-v0', conf={'host': HOST, 'port': PORT})
|
||||||
env = ThrottleClampWrapper(raw, throttle_min=throttle_min)
|
env = ThrottleClampWrapper(raw, throttle_min=base_throttle)
|
||||||
|
# If a runtime throttle floor is requested, apply wrapper that enforces it
|
||||||
|
if throttle_floor is not None:
|
||||||
|
class ThrottleFloorWrapper(gym.Wrapper):
|
||||||
|
def __init__(self, env, floor):
|
||||||
|
super().__init__(env)
|
||||||
|
self.floor = floor
|
||||||
|
def step(self, action):
|
||||||
|
# action is [steer, throttle]
|
||||||
|
act = np.array(action)
|
||||||
|
# Ensure throttle element >= floor (maps in [-1,1]? assume throttle in [0,1])
|
||||||
|
try:
|
||||||
|
# clamp second element
|
||||||
|
act[1] = max(act[1], self.floor)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return self.env.step(act)
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
return self.env.reset(**kwargs)
|
||||||
|
env = ThrottleFloorWrapper(env, throttle_floor)
|
||||||
env = V5RewardWrapper(env)
|
env = V5RewardWrapper(env)
|
||||||
return env
|
return env
|
||||||
return _init
|
return _init
|
||||||
|
|
@ -134,27 +157,24 @@ def log(s):
|
||||||
phase_defs = [ (PH1_STEPS, 0.4), (PH2_STEPS, 0.2) ]
|
phase_defs = [ (PH1_STEPS, 0.4), (PH2_STEPS, 0.2) ]
|
||||||
|
|
||||||
# create initial env and model (warm start)
|
# create initial env and model (warm start)
|
||||||
# Important: load the warm-start model using the SAME action space it was trained with
|
# Load model with base action space (throttle_min=0.2). We'll enforce a runtime
|
||||||
# (throttle_min=0.2) so we can then switch envs for phase 1 if needed.
|
# throttle FLOOR during phase 1 via a wrapper, but keep the action space unchanged.
|
||||||
loaded_env = VecTransposeImage(DummyVecEnv([make_env(0.2)]))
|
loaded_env = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=None)]))
|
||||||
if os.path.exists(WARM_PATH):
|
if os.path.exists(WARM_PATH):
|
||||||
log(f'Loading warm-start model from {WARM_PATH} using throttle_min=0.2 env')
|
log(f'Loading warm-start model from {WARM_PATH} using base throttle_min=0.2 env')
|
||||||
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
|
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
|
||||||
# override lr and schedules
|
# override lr and schedules
|
||||||
model.learning_rate = LR
|
model.learning_rate = LR
|
||||||
model.lr_schedule = model.get_schedule_fn(LR) if hasattr(model,'get_schedule_fn') else None
|
model.lr_schedule = model.get_schedule_fn(LR) if hasattr(model,'get_schedule_fn') else None
|
||||||
for pg in getattr(getattr(model.policy,'optimizer',None) or [], 'param_groups', []):
|
for pg in getattr(getattr(model.policy,'optimizer',None) or [], 'param_groups', []):
|
||||||
pg['lr'] = LR
|
pg['lr'] = LR
|
||||||
# Now create the actual training env with the first throttle setting
|
# Create the training env using base action space but enforce throttle_floor at runtime
|
||||||
first_throttle = phase_defs[0][1]
|
first_throttle_floor = phase_defs[0][1]
|
||||||
env0 = VecTransposeImage(DummyVecEnv([make_env(first_throttle)]))
|
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=first_throttle_floor)]))
|
||||||
if first_throttle != 0.2:
|
|
||||||
log(f'Switching model to env with throttle_min={first_throttle}')
|
|
||||||
model.set_env(env0)
|
model.set_env(env0)
|
||||||
else:
|
else:
|
||||||
log('No warm-start found — creating fresh model with first throttle')
|
log('No warm-start found — creating fresh model with base throttle_min=0.2')
|
||||||
first_throttle = phase_defs[0][1]
|
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=phase_defs[0][1])]))
|
||||||
env0 = VecTransposeImage(DummyVecEnv([make_env(first_throttle)]))
|
|
||||||
model = PPO('CnnPolicy', env0, learning_rate=LR, verbose=1, device='cpu')
|
model = PPO('CnnPolicy', env0, learning_rate=LR, verbose=1, device='cpu')
|
||||||
loaded_env.close()
|
loaded_env.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue