134 lines
4.7 KiB
Python
134 lines
4.7 KiB
Python
"""
|
|
Integration tests for donkeycar_sb3_runner.py — no live simulator required.
|
|
Uses mocked gym environment.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import tempfile
|
|
import pytest
|
|
import numpy as np
|
|
import gymnasium as gym
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'agent'))
|
|
|
|
|
|
class MockGymEnv(gym.Env):
    """Minimal mock of a DonkeyCar gym environment as a proper gymnasium.Env.

    Observations are ``OBS_SHAPE`` uint8 images and actions are a 2-element
    float32 Box with bounds [-1, 1] and [0, 1] (presumably steering and
    throttle -- confirm against the real simulator). Every episode
    terminates after ``EPISODE_LENGTH`` steps.

    Randomness comes from ``self.np_random`` so that ``reset(seed=...)``
    makes the observation/reward stream reproducible.
    """

    metadata = {'render_modes': []}

    # Image shape (height, width, channels) shared by reset() and step().
    OBS_SHAPE = (120, 160, 3)
    # Number of step() calls after which an episode reports terminated=True.
    EPISODE_LENGTH = 50

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=self.OBS_SHAPE, dtype=np.uint8
        )
        # Build the bounds directly as float32 so Box does not have to
        # down-cast float64 arrays to the declared dtype.
        self.action_space = gym.spaces.Box(
            low=np.array([-1.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )
        self._step_count = 0

    def reset(self, seed=None, **kwargs):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to (re)seed
                ``self.np_random``.
            **kwargs: Accepted for API compatibility (e.g. ``options``) and
                ignored.

        Returns:
            An all-zero uint8 image of shape ``OBS_SHAPE`` and an empty info
            dict.
        """
        # The gymnasium API expects subclasses to delegate seeding upward.
        super().reset(seed=seed)
        self._step_count = 0
        return np.zeros(self.OBS_SHAPE, dtype=np.uint8), {}

    def step(self, action):
        """Advance one step; the action itself is ignored.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` is a
            random uint8 image, ``reward`` is a float drawn uniformly from
            [0, 2), ``terminated`` becomes True from step ``EPISODE_LENGTH``
            onward, ``truncated`` is always False, and ``info`` carries fixed
            ``speed`` and ``cte`` values.
        """
        self._step_count += 1
        # Upper bound is exclusive, so 256 lets pixels cover the full 0..255
        # range declared by observation_space.
        obs = self.np_random.integers(0, 256, size=self.OBS_SHAPE, dtype=np.uint8)
        reward = float(self.np_random.uniform(0, 2))
        terminated = self._step_count >= self.EPISODE_LENGTH
        truncated = False
        info = {'speed': 2.0, 'cte': 0.5}
        return obs, reward, terminated, truncated, info

    def close(self):
        """Nothing to release: the mock holds no external resources."""
        pass
|
|
|
|
|
|
def test_make_env_ppo_no_discretization():
    """PPO should NOT apply DiscretizedActionWrapper.

    ``gymnasium.make`` is patched to hand back a ``MockGymEnv`` so no
    simulator is needed; the env returned by ``make_env`` must still expose
    a continuous Box action space.
    """
    with patch('gymnasium.make', return_value=MockGymEnv()):
        from donkeycar_sb3_runner import make_env
        env = make_env('donkey-generated-roads-v0', 'ppo', n_steer=7, n_throttle=3, reward_shaping=False)
        # Every gymnasium space (Discrete included) has a `.shape` attribute,
        # so probing for it cannot detect a wrongly applied discretization.
        # Check the concrete space type instead.
        assert isinstance(env.action_space, gym.spaces.Box), "PPO env should have continuous Box action space"
|
|
|
|
|
|
def test_make_env_dqn_discretization():
    """DQN should apply DiscretizedActionWrapper.

    With ``gymnasium.make`` patched to return a ``MockGymEnv``, the env built
    for 'dqn' must expose an action space with ``n`` equal to
    n_steer * n_throttle.
    """
    steer_bins, throttle_bins = 5, 3

    with patch('gymnasium.make', return_value=MockGymEnv()):
        from donkeycar_sb3_runner import make_env

        wrapped_env = make_env(
            'donkey-generated-roads-v0',
            'dqn',
            n_steer=steer_bins,
            n_throttle=throttle_bins,
            reward_shaping=False,
        )

        # Only discrete spaces carry an `n` attribute.
        assert hasattr(wrapped_env.action_space, 'n'), "DQN env should have Discrete action space"
        assert wrapped_env.action_space.n == steer_bins * throttle_bins
|
|
|
|
|
|
def test_save_model_creates_zip():
    """save_model() should save to <save_dir>/model and return that path + '.zip'.

    The model is a mock, so no archive is actually written; the test checks
    the path handed to ``model.save``, the returned path, and that the save
    directory exists afterwards.
    """
    from donkeycar_sb3_runner import save_model

    fake_model = MagicMock(save=MagicMock())

    with tempfile.TemporaryDirectory() as scratch:
        target_dir = os.path.join(scratch, 'trial-0001')
        model_stem = os.path.join(target_dir, 'model')

        returned_path = save_model(fake_model, target_dir)

        # save() receives the extension-less stem; the caller gets the .zip name.
        fake_model.save.assert_called_once_with(model_stem)
        assert returned_path == model_stem + '.zip'
        assert os.path.isdir(target_dir), "Save directory should be created"
|
|
|
|
|
|
def test_save_model_creates_directory():
    """save_model() should create save_dir (including parents) when it is missing."""
    from donkeycar_sb3_runner import save_model

    fake_model = MagicMock()

    with tempfile.TemporaryDirectory() as scratch:
        # Several levels deep so that intermediate directories are missing too.
        target_dir = os.path.join(scratch, 'nested', 'path', 'trial-042')
        assert not os.path.exists(target_dir)

        save_model(fake_model, target_dir)

        assert os.path.isdir(target_dir)
|
|
|
|
|
|
def test_teardown_calls_env_close():
    """teardown() should call env.close() exactly once and swallow its error."""
    from donkeycar_sb3_runner import teardown

    # An env whose close() always blows up, as if the simulator went away.
    broken_env = MagicMock(**{'close.side_effect': RuntimeError("sim disconnected")})

    # Must return normally despite the RuntimeError raised by close().
    teardown(broken_env)

    broken_env.close.assert_called_once()
|
|
|
|
|
|
def test_runner_script_has_no_syntax_errors():
    """The runner script should compile without syntax errors.

    The source is only compiled, never executed, so the check does not need
    the runner's own dependencies to be importable.
    """
    spec_path = os.path.join(os.path.dirname(__file__), '..', 'agent', 'donkeycar_sb3_runner.py')
    # Read raw bytes so compile() applies Python's own source-encoding rules
    # (UTF-8 by default, or an explicit coding cookie) instead of decoding
    # with the platform's locale encoding.
    with open(spec_path, 'rb') as f:
        source = f.read()
    compile(source, spec_path, 'exec')  # Raises SyntaxError if broken
|
|
|
|
|
|
def test_no_model_save_before_definition():
    """Runner source must not call model.save() before model is defined.

    This is a textual scan, not a semantic analysis: it walks the runner
    source top to bottom and fails if a non-comment line containing
    ``model.save`` appears before the first line containing ``model = PPO``
    or ``model = DQN``.
    """
    runner_path = os.path.join(os.path.dirname(__file__), '..', 'agent', 'donkeycar_sb3_runner.py')
    # Explicit encoding: Python source defaults to UTF-8, whereas open()
    # would otherwise fall back to the platform locale.
    with open(runner_path, encoding='utf-8') as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip('\n')
            # A comment that merely mentions model.save is not a call.
            if line.lstrip().startswith('#'):
                continue
            if 'model = PPO' in line or 'model = DQN' in line:
                # Everything after the definition is allowed to save.
                break
            if 'model.save' in line:
                pytest.fail(f"model.save() called before model is defined at line {lineno}: {line}")
|