fix: ensure lr_schedule callable set when loading warm-start model (use get_schedule_fn) and update optimizer LR

This commit is contained in:
Paul Huliganga 2026-04-19 20:14:35 -04:00
parent eb92d119f9
commit 38dd5e9b1d
1 changed files with 12 additions and 4 deletions

View File

@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from datetime import datetime
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.utils import get_schedule_fn
import gymnasium as gym, numpy as np
from donkeycar_sb3_runner import ThrottleClampWrapper
@ -163,11 +164,18 @@ loaded_env = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=No
if os.path.exists(WARM_PATH):
log(f'Loading warm-start model from {WARM_PATH} using base throttle_min=0.2 env')
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
# override lr and schedules
# override lr and schedules — ensure lr_schedule callable exists
model.learning_rate = LR
model.lr_schedule = model.get_schedule_fn(LR) if hasattr(model,'get_schedule_fn') else None
for pg in getattr(getattr(model.policy,'optimizer',None) or [], 'param_groups', []):
try:
model.lr_schedule = get_schedule_fn(LR)
except Exception:
model.lr_schedule = None
# update optimizer param groups to new LR
try:
for pg in model.policy.optimizer.param_groups:
pg['lr'] = LR
except Exception:
pass
# Create the training env using base action space but enforce throttle_floor at runtime
first_throttle_floor = phase_defs[0][1]
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=first_throttle_floor)]))