fix: ensure lr_schedule callable set when loading warm-start model (use get_schedule_fn) and update optimizer LR
This commit is contained in:
parent
eb92d119f9
commit
38dd5e9b1d
|
|
@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|||
from datetime import datetime
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
|
||||
from stable_baselines3.common.utils import get_schedule_fn
|
||||
import gymnasium as gym, numpy as np
|
||||
from donkeycar_sb3_runner import ThrottleClampWrapper
|
||||
|
||||
|
|
@ -163,11 +164,18 @@ loaded_env = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=No
|
|||
if os.path.exists(WARM_PATH):
|
||||
log(f'Loading warm-start model from {WARM_PATH} using base throttle_min=0.2 env')
|
||||
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
|
||||
# override lr and schedules
|
||||
# override lr and schedules — ensure lr_schedule callable exists
|
||||
model.learning_rate = LR
|
||||
model.lr_schedule = model.get_schedule_fn(LR) if hasattr(model,'get_schedule_fn') else None
|
||||
for pg in getattr(getattr(model.policy,'optimizer',None) or [], 'param_groups', []):
|
||||
try:
|
||||
model.lr_schedule = get_schedule_fn(LR)
|
||||
except Exception:
|
||||
model.lr_schedule = None
|
||||
# update optimizer param groups to new LR
|
||||
try:
|
||||
for pg in model.policy.optimizer.param_groups:
|
||||
pg['lr'] = LR
|
||||
except Exception:
|
||||
pass
|
||||
# Create the training env using base action space but enforce throttle_floor at runtime
|
||||
first_throttle_floor = phase_defs[0][1]
|
||||
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=first_throttle_floor)]))
|
||||
|
|
|
|||
Loading…
Reference in New Issue