121 lines
4.1 KiB
Python
121 lines
4.1 KiB
Python
"""
Tests for discretize_action.py — no simulator required.
"""
|
||
|
||
import pytest
|
||
import numpy as np
|
||
import sys
|
||
import os
|
||
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'agent'))
|
||
|
||
from discretize_action import DiscretizedActionWrapper
|
||
|
||
|
||
import gymnasium as gym
|
||
|
||
|
||
class MockEnv(gym.Env):
    """Minimal mock gymnasium.Env with a continuous Box action space.

    Action space: 2-D Box with index 0 in [-1, 1] and index 1 in [0, 1]
    (the tests read these as steer and throttle respectively).
    Observation space: (120, 160, 3) uint8 image; every observation is zeros.

    ``step`` always returns reward 1.0, never terminates or truncates, and
    returns a fixed ``info`` dict. The most recent action passed to ``step``
    is kept in ``last_action`` so tests can inspect what reached the env.
    """

    metadata = {'render_modes': []}

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(
            low=np.array([-1.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0], dtype=np.float32),
        )
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(120, 160, 3), dtype=np.uint8
        )
        # Last action received by step(); None until step() is first called.
        self.last_action = None

    def _blank_obs(self):
        """Return an all-zero observation matching ``observation_space``."""
        return np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)

    def reset(self, seed=None, **kwargs):
        """Reset the mock env and return ``(obs, info)``.

        Extra keyword arguments (e.g. ``options``) are accepted and ignored.
        """
        # gymnasium.Env.reset seeds self.np_random; subclasses should call it.
        super().reset(seed=seed)
        return self._blank_obs(), {}

    def step(self, action):
        """Record ``action`` and return a fixed 5-tuple step result."""
        self.last_action = action
        return self._blank_obs(), 1.0, False, False, {'cte': 0.1, 'speed': 2.5}

    def close(self):
        pass
|
||
|
||
|
||
# ---- Tests ----
|
||
|
||
def test_wrapper_creates_discrete_action_space():
    """The wrapper should expose a Discrete action space of size n_steer * n_throttle."""
    env = MockEnv()
    wrapped = DiscretizedActionWrapper(env, n_steer=5, n_throttle=3)
    # isinstance is stricter than hasattr(..., 'n'): MultiBinary also has an `n`.
    assert isinstance(wrapped.action_space, gym.spaces.Discrete), \
        "Wrapped env should have discrete action space"
    assert wrapped.action_space.n == 5 * 3
|
||
|
||
|
||
def test_n_steer_n_throttle_product():
    """Discrete action count must equal n_steer × n_throttle for every grid size tried."""
    # Steer sizes form the outer loop and throttle sizes the inner loop.
    grid = [(s, t) for s in (3, 5, 7, 9) for t in (2, 3, 5)]
    for steer_bins, throttle_bins in grid:
        wrapper = DiscretizedActionWrapper(
            MockEnv(), n_steer=steer_bins, n_throttle=throttle_bins
        )
        expected = steer_bins * throttle_bins
        assert wrapper.action_space.n == expected
|
||
|
||
|
||
def test_action_decode_center_steer():
    """The middle steer index should decode to steer ≈ 0.0 for every throttle index.

    Assumes the wrapper uses a steer-major layout,
    action = steer_idx * n_throttle + throttle_idx. The full-left and
    full-right decode tests rely on the same layout.
    """
    env = MockEnv()
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(env, n_steer=n_steer, n_throttle=n_throttle)
    # Derive the center from n_steer instead of hard-coding it.
    center_idx = n_steer // 2
    for throttle_idx in range(n_throttle):
        center_steer_action = center_idx * n_throttle + throttle_idx
        steer = wrapped.action(center_steer_action)[0]
        assert abs(steer) < 0.1, (
            f"Center steer should be ~0.0, got {steer} (throttle_idx={throttle_idx})"
        )
|
||
|
||
|
||
def test_action_decode_full_left_steer():
    """Action 0 (steer_idx=0, throttle_idx=0) must map to the leftmost steer, -1.0."""
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=5, n_throttle=3)
    first_action = 0
    steer = wrapped.action(first_action)[0]
    assert steer == pytest.approx(-1.0, abs=0.01), f"Full left steer should be -1.0, got {steer}"
|
||
|
||
|
||
def test_action_decode_full_right_steer():
    """The highest steer index (with throttle_idx=0) must map to steer = +1.0."""
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=n_steer, n_throttle=n_throttle)
    # This index assumes a steer-major layout: steer_idx * n_throttle + throttle_idx.
    steer = wrapped.action((n_steer - 1) * n_throttle)[0]
    assert steer == pytest.approx(1.0, abs=0.01), f"Full right steer should be 1.0, got {steer}"
|
||
|
||
|
||
def test_action_decode_all_valid():
    """Each index in the discrete space must decode to an in-bounds (steer, throttle) pair."""
    env = MockEnv()
    n_steer, n_throttle = 7, 3
    wrapped = DiscretizedActionWrapper(env, n_steer=n_steer, n_throttle=n_throttle)
    n_actions = n_steer * n_throttle
    # Lazily decode indices in ascending order, one at a time.
    decoded = ((a, wrapped.action(a)) for a in range(n_actions))
    for action, continuous in decoded:
        steer = continuous[0]
        throttle = continuous[1]
        assert -1.0 <= steer <= 1.0, f"Steer {steer} out of range for action {action}"
        assert 0.0 <= throttle <= 1.0, f"Throttle {throttle} out of range for action {action}"
|
||
|
||
|
||
def test_step_passes_through():
    """Wrapped env.step() should accept an integer action and relay the inner env's result.

    A gymnasium step returns a 5-tuple (obs, reward, terminated, truncated, info).
    The legacy 4-tuple gym API is not valid for a gymnasium.Env, so it is not accepted here.
    """
    env = MockEnv()
    wrapped = DiscretizedActionWrapper(env, n_steer=5, n_throttle=3)
    wrapped.reset()
    result = wrapped.step(0)
    assert len(result) == 5, "step() should return (obs, reward, terminated, truncated, info)"
    obs, reward, terminated, truncated, info = result
    # MockEnv.step always yields a zero image, reward 1.0, no termination,
    # and info {'cte': 0.1, 'speed': 2.5}.
    assert obs.shape == (120, 160, 3)
    assert reward == pytest.approx(1.0)
    assert not terminated
    assert not truncated
    assert info['cte'] == pytest.approx(0.1)
|
||
|
||
|
||
def test_reset_works():
    """Wrapped env.reset() should return gymnasium's (obs, info) pair."""
    env = MockEnv()
    wrapped = DiscretizedActionWrapper(env, n_steer=5, n_throttle=3)
    result = wrapped.reset()
    # reset() returns a tuple, so a bare `is not None` check on it can never
    # fail. Unpack it and check each piece instead.
    assert len(result) == 2, "reset() should return (obs, info)"
    obs, info = result
    assert obs is not None
    assert obs.shape == (120, 160, 3)
    assert isinstance(info, dict)
|