80 lines
4.1 KiB
Python
80 lines
4.1 KiB
Python
import argparse
|
|
import gymnasium as gym
|
|
import gym_donkeycar
|
|
import argparse
|
|
import gymnasium as gym
|
|
import gym_donkeycar
|
|
import sys
|
|
import time
|
|
from discretize_action import DiscretizedActionWrapper
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for the DonkeyCar test runner.

    ``--timesteps`` and ``--log-dir`` are accepted but never read; they are
    kept only so callers that pass them keep working.
    """
    parser = argparse.ArgumentParser(description="Run multi-episode RL test loop for DonkeyCar Gym. No model training/saving.")
    parser.add_argument('--agent', type=str, default='dqn', help='RL agent type (only dqn supported in this runner)')
    parser.add_argument('--env', type=str, default='donkey-generated-roads-v0', help='Gym/Gymnasium env ID')
    parser.add_argument('--timesteps', type=int, default=5000, help='Unused (for outer loop compatibility)')
    parser.add_argument('--eval-episodes', type=int, default=10, help='Episodes for evaluation')
    parser.add_argument('--log-dir', type=str, default=None, help='Unused (kept for arg compatibility)')
    parser.add_argument('--seed', type=int, default=None, help='Optional seed')
    parser.add_argument('--n-steer', type=int, default=3, help='Number of steer bins (DQN only)')
    parser.add_argument('--n-throttle', type=int, default=3, help='Number of throttle bins (DQN only)')
    return parser


def parse_step_result(result):
    """Reduce an ``env.step()`` result to ``(reward, done)``.

    Two result layouts are understood:

    * 4 items: ``(obs, reward, done, info)``
    * 5 items: ``(obs, reward, terminated, truncated, info)``, where
      ``done`` is ``terminated or truncated``

    Returns ``None`` for any other length so the caller can report the
    unexpected shape instead of failing on tuple unpacking.
    """
    if len(result) == 4:
        _, reward, done, _ = result
        return reward, done
    if len(result) == 5:
        _, reward, terminated, truncated, _ = result
        return reward, terminated or truncated
    return None


def run_episodes(env, episodes, seed=None):
    """Drive ``env`` with randomly sampled actions for ``episodes`` episodes.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and an
            ``action_space`` with ``sample()`` (and ``seed()`` when a seed
            is supplied).
        episodes: Number of episodes to run; zero or less runs none.
        seed: Optional seed. It is applied once: to the action space and
            to the first ``reset()`` only. Re-seeding every reset would
            restart the environment's RNG for each episode.

    Returns:
        List with the summed reward of each episode, in order. An episode
        cut short by an unexpected ``step()`` result contributes the reward
        accumulated up to that point.
    """
    ep_rewards = []
    if seed is not None:
        # Sampling uses the space's own RNG, so it needs its own seed for
        # the random actions to be reproducible.
        env.action_space.seed(seed)

    for number in range(1, episodes + 1):
        if seed is not None and number == 1:
            env.reset(seed=seed)
        else:
            env.reset()
        print(f'[SB3 Runner][TEST] Episode {number}/{episodes} - reset at {time.ctime()}', flush=True)

        ep_reward = 0.0
        done = False
        t = 0
        while not done:
            action = env.action_space.sample()
            outcome = parse_step_result(env.step(action))
            if outcome is None:
                print('[SB3 Runner][MONITOR] UNEXPECTED step() result shape!', flush=True)
                break
            reward, done = outcome
            ep_reward += reward
            t += 1
            if t % 10 == 0 or done:
                print(f'[SB3 Runner][TEST] Step {t} done={done} reward={reward} {time.ctime()}', flush=True)
            if done:
                print(f'[SB3 Runner][TEST] Episode {number} ended after {t} steps, total_reward={ep_reward} at {time.ctime()}', flush=True)
        ep_rewards.append(ep_reward)
    return ep_rewards


def main(argv=None):
    """Run the test loop and return the process exit code.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` on success, ``100`` if the environment could not be created,
        ``102`` if an exception was raised while running episodes.
    """
    args = build_arg_parser().parse_args(argv)

    print('[SB3 Runner] Starting: Connecting to sim…', flush=True)
    try:
        env = gym.make(args.env)
    except Exception as e:
        print(f'[SB3 Runner][MONITOR ALERT] Failed to connect to sim: {e}', flush=True)
        return 100
    print(f'[SB3 Runner][MONITOR] Connected to gym env. {time.ctime()}', flush=True)

    if args.agent == 'dqn':
        env = DiscretizedActionWrapper(env, n_steer=args.n_steer, n_throttle=args.n_throttle)
        print(f'[SB3 Runner][MONITOR] Action discretization: steer={args.n_steer}, throttle={args.n_throttle}. {time.ctime()}', flush=True)

    exit_code = 0
    try:
        ep_rewards = run_episodes(env, args.eval_episodes, seed=args.seed)
        print(f'[SB3 Runner][TEST] All episode rewards: {ep_rewards}', flush=True)
        if ep_rewards:
            print(f'[SB3 Runner][TEST] mean_reward={sum(ep_rewards)/len(ep_rewards):.4f}', flush=True)
    except Exception as e:
        print(f'[SB3 Runner][MONITOR ALERT] Exception during episodes: {e} {time.ctime()}', flush=True)
        # Defer the exit so the environment is still closed below; exiting
        # here would leave the simulator connection open.
        exit_code = 102

    print(f'[SB3 Runner][MONITOR] Calling env.close() at {time.ctime()}', flush=True)
    try:
        env.close()
        print(f'[SB3 Runner][MONITOR] env.close() complete. {time.ctime()}', flush=True)
    except Exception as e:
        # Best effort: a failing close must not mask the run's exit code.
        print(f'[SB3 Runner][MONITOR ALERT] Exception during env.close(): {e} {time.ctime()}', flush=True)

    print(f'[SB3 Runner][MONITOR] Waiting 2s before process exit to avoid race. {time.ctime()}', flush=True)
    time.sleep(2)
    print(f'[SB3 Runner][MONITOR] Exiting RL runner at {time.ctime()}', flush=True)
    return exit_code


if __name__ == "__main__":
    sys.exit(main())
|