CLEAN: Robust multi-episode RL runner, no legacy save/model logic; outer loop points to project dir runner.
This commit is contained in:
parent
c98bc7ef38
commit
4a4e61d463
|
|
@ -51,7 +51,7 @@ def run_sweep():
|
||||||
mlog.write(f"[MONITOR {time.ctime()}] Launching inner RL job for config {i+1} repeat {r+1}\n")
|
mlog.write(f"[MONITOR {time.ctime()}] Launching inner RL job for config {i+1} repeat {r+1}\n")
|
||||||
mlog.flush()
|
mlog.flush()
|
||||||
cmd = [
|
cmd = [
|
||||||
'python3', '/home/paulh/.pi/agent/donkeycar_sb3_runner.py',
|
'python3', '/home/paulh/projects/donkeycar-rl-autoresearch/agent/donkeycar_sb3_runner.py',
|
||||||
'--agent', 'dqn',
|
'--agent', 'dqn',
|
||||||
'--env', 'donkey-generated-roads-v0',
|
'--env', 'donkey-generated-roads-v0',
|
||||||
'--timesteps', str(params['timesteps']),
|
'--timesteps', str(params['timesteps']),
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,42 @@
|
||||||
import argparse
|
import argparse
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import gym_donkeycar
|
import gym_donkeycar
|
||||||
from stable_baselines3 import DQN, PPO
|
import argparse
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
import gymnasium as gym
|
||||||
import os
|
import gym_donkeycar
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from discretize_action import DiscretizedActionWrapper
|
from discretize_action import DiscretizedActionWrapper
|
||||||
|
|
||||||
AGENT_MAP = {
|
if __name__ == "__main__":
|
||||||
'dqn': DQN,
|
parser = argparse.ArgumentParser(description="Run multi-episode RL test loop for DonkeyCar Gym. No model training/saving.")
|
||||||
'ppo': PPO, # For later extension
|
parser.add_argument('--agent', type=str, default='dqn', help='RL agent type (only dqn supported in this runner)')
|
||||||
}
|
parser.add_argument('--env', type=str, default='donkey-generated-roads-v0', help='Gym/Gymnasium env ID')
|
||||||
|
parser.add_argument('--timesteps', type=int, default=5000, help='Unused (for outer loop compatibility)')
|
||||||
def run_training(env_id, agent_name, total_timesteps, reward_shaping=False, eval_episodes=10, log_dir=None, seed=None, dqn_discretize=True, n_steer=3, n_throttle=3):
|
parser.add_argument('--eval-episodes', type=int, default=10, help='Episodes for evaluation')
|
||||||
assert agent_name in AGENT_MAP, f"Agent '{agent_name}' not recognized. Available: {list(AGENT_MAP.keys())}"
|
parser.add_argument('--log-dir', type=str, default=None, help='Unused (kept for arg compatibility)')
|
||||||
AgentClass = AGENT_MAP[agent_name]
|
parser.add_argument('--seed', type=int, default=None, help='Optional seed')
|
||||||
|
parser.add_argument('--n-steer', type=int, default=3, help='Number of steer bins (DQN only)')
|
||||||
|
parser.add_argument('--n-throttle', type=int, default=3, help='Number of throttle bins (DQN only)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
print('[SB3 Runner] Starting: Connecting to sim…', flush=True)
|
print('[SB3 Runner] Starting: Connecting to sim…', flush=True)
|
||||||
start = time.time()
|
|
||||||
try:
|
try:
|
||||||
env = gym.make(env_id)
|
env = gym.make(args.env)
|
||||||
print(f'[SB3 Runner][MONITOR] Connected to gym env. {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR] Connected to gym env. {time.ctime()}', flush=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'[SB3 Runner][MONITOR ALERT] Failed to connect to sim: {str(e)}', flush=True)
|
print(f'[SB3 Runner][MONITOR ALERT] Failed to connect to sim: {str(e)}', flush=True)
|
||||||
sys.exit(100)
|
sys.exit(100)
|
||||||
if agent_name == 'dqn' and dqn_discretize:
|
if args.agent == 'dqn':
|
||||||
env = DiscretizedActionWrapper(env, n_steer=n_steer, n_throttle=n_throttle)
|
env = DiscretizedActionWrapper(env, n_steer=args.n_steer, n_throttle=args.n_throttle)
|
||||||
print(f'[SB3 Runner][MONITOR] Action discretization: steer={n_steer}, throttle={n_throttle}. {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR] Action discretization: steer={args.n_steer}, throttle={args.n_throttle}. {time.ctime()}', flush=True)
|
||||||
EPISODES = 10 # Number of full env.reset runs for this special test
|
EPISODES = args.eval_episodes
|
||||||
try:
|
try:
|
||||||
ep_rewards = []
|
ep_rewards = []
|
||||||
for episode in range(EPISODES):
|
for episode in range(EPISODES):
|
||||||
ep_reward = 0.0
|
ep_reward = 0.0
|
||||||
if seed is not None:
|
if args.seed is not None:
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=args.seed)
|
||||||
else:
|
else:
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
print(f'[SB3 Runner][TEST] Episode {episode+1}/{EPISODES} - reset at {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][TEST] Episode {episode+1}/{EPISODES} - reset at {time.ctime()}', flush=True)
|
||||||
|
|
@ -43,7 +45,7 @@ def run_training(env_id, agent_name, total_timesteps, reward_shaping=False, eval
|
||||||
while not done:
|
while not done:
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
result = env.step(action)
|
result = env.step(action)
|
||||||
if len(result) in (4, 5): # obs, reward, done, info or obs, reward, done, truncated, info
|
if len(result) in (4, 5):
|
||||||
if len(result) == 4:
|
if len(result) == 4:
|
||||||
obs, reward, done, info = result
|
obs, reward, done, info = result
|
||||||
else:
|
else:
|
||||||
|
|
@ -66,7 +68,6 @@ def run_training(env_id, agent_name, total_timesteps, reward_shaping=False, eval
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'[SB3 Runner][MONITOR ALERT] Exception during episodes: {str(e)} {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR ALERT] Exception during episodes: {str(e)} {time.ctime()}', flush=True)
|
||||||
sys.exit(102)
|
sys.exit(102)
|
||||||
# ---- NEW: Ensure teardown and sleep for race avoidance ----
|
|
||||||
print(f'[SB3 Runner][MONITOR] Calling env.close() at {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR] Calling env.close() at {time.ctime()}', flush=True)
|
||||||
try:
|
try:
|
||||||
env.close()
|
env.close()
|
||||||
|
|
@ -76,37 +77,3 @@ def run_training(env_id, agent_name, total_timesteps, reward_shaping=False, eval
|
||||||
print(f'[SB3 Runner][MONITOR] Waiting 2s before process exit to avoid race. {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR] Waiting 2s before process exit to avoid race. {time.ctime()}', flush=True)
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
print(f'[SB3 Runner][MONITOR] Exiting RL runner at {time.ctime()}', flush=True)
|
print(f'[SB3 Runner][MONITOR] Exiting RL runner at {time.ctime()}', flush=True)
|
||||||
|
|
||||||
# Save if needed
|
|
||||||
if log_dir:
|
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
|
||||||
save_path = os.path.join(log_dir, f'{agent_name}_model')
|
|
||||||
model.save(save_path)
|
|
||||||
print(f"[SB3 Runner] Model saved to {save_path}")
|
|
||||||
|
|
||||||
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=eval_episodes, return_episode_rewards=False)
|
|
||||||
print(f"[SB3 Runner] Eval episodes={eval_episodes}: mean_reward={mean_reward:.3f} std={std_reward:.3f}")
|
|
||||||
return mean_reward, std_reward
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Train/Eval an RL agent on DonkeyCar Gym using SB3.")
|
|
||||||
parser.add_argument('--agent', type=str, default='dqn', choices=AGENT_MAP.keys(), help='RL agent type')
|
|
||||||
parser.add_argument('--env', type=str, default='donkey-generated-roads-v0', help='Gym/Gymnasium env ID')
|
|
||||||
parser.add_argument('--timesteps', type=int, default=5000, help='Total training timesteps')
|
|
||||||
parser.add_argument('--eval-episodes', type=int, default=10, help='Episodes for evaluation after training')
|
|
||||||
parser.add_argument('--log-dir', type=str, default=None, help='Directory to save models')
|
|
||||||
parser.add_argument('--seed', type=int, default=None, help='Random seed')
|
|
||||||
parser.add_argument('--n-steer', type=int, default=3, help='Number of steer bins (DQN only)')
|
|
||||||
parser.add_argument('--n-throttle', type=int, default=3, help='Number of throttle bins (DQN only)')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
run_training(
|
|
||||||
env_id=args.env,
|
|
||||||
agent_name=args.agent,
|
|
||||||
total_timesteps=args.timesteps,
|
|
||||||
eval_episodes=args.eval_episodes,
|
|
||||||
log_dir=args.log_dir,
|
|
||||||
seed=args.seed,
|
|
||||||
n_steer=args.n_steer,
|
|
||||||
n_throttle=args.n_throttle
|
|
||||||
)
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue