fix: ensure lr_schedule is set to a callable when loading the warm-start model (use get_schedule_fn), and update the optimizer LR
This commit is contained in:
parent
eb92d119f9
commit
38dd5e9b1d
|
|
@@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
|
||||||
|
from stable_baselines3.common.utils import get_schedule_fn
|
||||||
import gymnasium as gym, numpy as np
|
import gymnasium as gym, numpy as np
|
||||||
from donkeycar_sb3_runner import ThrottleClampWrapper
|
from donkeycar_sb3_runner import ThrottleClampWrapper
|
||||||
|
|
||||||
|
|
@@ -163,11 +164,18 @@ loaded_env = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=No
|
||||||
if os.path.exists(WARM_PATH):
|
if os.path.exists(WARM_PATH):
|
||||||
log(f'Loading warm-start model from {WARM_PATH} using base throttle_min=0.2 env')
|
log(f'Loading warm-start model from {WARM_PATH} using base throttle_min=0.2 env')
|
||||||
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
|
model = PPO.load(WARM_PATH, env=loaded_env, device='cpu')
|
||||||
# override lr and schedules
|
# override lr and schedules — ensure lr_schedule callable exists
|
||||||
model.learning_rate = LR
|
model.learning_rate = LR
|
||||||
model.lr_schedule = model.get_schedule_fn(LR) if hasattr(model,'get_schedule_fn') else None
|
try:
|
||||||
for pg in getattr(getattr(model.policy,'optimizer',None) or [], 'param_groups', []):
|
model.lr_schedule = get_schedule_fn(LR)
|
||||||
pg['lr'] = LR
|
except Exception:
|
||||||
|
model.lr_schedule = None
|
||||||
|
# update optimizer param groups to new LR
|
||||||
|
try:
|
||||||
|
for pg in model.policy.optimizer.param_groups:
|
||||||
|
pg['lr'] = LR
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
# Create the training env using base action space but enforce throttle_floor at runtime
|
# Create the training env using base action space but enforce throttle_floor at runtime
|
||||||
first_throttle_floor = phase_defs[0][1]
|
first_throttle_floor = phase_defs[0][1]
|
||||||
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=first_throttle_floor)]))
|
env0 = VecTransposeImage(DummyVecEnv([make_env_base(0.2, throttle_floor=first_throttle_floor)]))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue