donkeycar-rl-autoresearch/tests/test_discretize_action.py

121 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tests for discretize_action.py — no simulator required.
"""
import pytest
import numpy as np
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'agent'))
from discretize_action import DiscretizedActionWrapper
import gymnasium as gym
class MockEnv(gym.Env):
    """Minimal mock gymnasium.Env with a continuous Box action space.

    Action space: Box(low=[-1, 0], high=[1, 1]), i.e. (steer, throttle).
    Observation space: 120x160x3 uint8 image; every observation is all zeros.

    ``step`` stores the action it receives in ``last_action`` so tests can
    inspect what a wrapper actually forwarded to the underlying env.
    """

    metadata = {'render_modes': []}

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(
            low=np.array([-1.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0], dtype=np.float32),
        )
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(120, 160, 3), dtype=np.uint8
        )
        # Most recent action passed to step(); None until step() is called.
        self.last_action = None

    def _blank_obs(self):
        """Return an all-zero observation shaped like ``observation_space``."""
        # Derive the shape from the declared space so the two cannot drift apart.
        return np.zeros(self.observation_space.shape, dtype=np.uint8)

    def reset(self, seed=None, **kwargs):
        """Reset the env and return ``(obs, info)`` per the gymnasium API.

        ``seed`` is forwarded to ``gym.Env.reset`` so ``self.np_random`` is
        seeded as gymnasium requires; any other keyword arguments (e.g.
        ``options``) are accepted and ignored.
        """
        super().reset(seed=seed)
        return self._blank_obs(), {}

    def step(self, action):
        """Record ``action`` and return a fixed gymnasium 5-tuple transition.

        Returns ``(obs, 1.0, False, False, {'cte': 0.1, 'speed': 2.5})``:
        constant reward, never terminated or truncated.
        """
        self.last_action = action
        return self._blank_obs(), 1.0, False, False, {'cte': 0.1, 'speed': 2.5}

    def close(self):
        """No resources to release."""
        pass
# ---- Tests ----
def test_wrapper_creates_discrete_action_space():
    """Wrapping MockEnv yields a discrete space with n_steer * n_throttle actions."""
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=n_steer, n_throttle=n_throttle)
    space = wrapped.action_space
    assert hasattr(space, 'n'), "Wrapped env should have discrete action space"
    assert space.n == n_steer * n_throttle
def test_n_steer_n_throttle_product():
    """Action space size equals n_steer * n_throttle for a grid of bin counts."""
    # Steer-major ordering, matching the original nested loops.
    grid = [(s, t) for s in (3, 5, 7, 9) for t in (2, 3, 5)]
    for steer_bins, throttle_bins in grid:
        wrapped = DiscretizedActionWrapper(
            MockEnv(), n_steer=steer_bins, n_throttle=throttle_bins
        )
        assert wrapped.action_space.n == steer_bins * throttle_bins
def test_action_decode_center_steer():
    """Middle steer index should decode to steer ~ 0.0 for every throttle index.

    Uses the steer-major encoding assumed by the other decode tests here:
    action = steer_idx * n_throttle + throttle_idx.
    """
    env = MockEnv()
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(env, n_steer=n_steer, n_throttle=n_throttle)
    # Derive the centre from n_steer rather than hard-coding it, so the test
    # stays valid if the bin count above is changed.
    center_idx = n_steer // 2
    for throttle_idx in range(n_throttle):
        action = center_idx * n_throttle + throttle_idx
        steer = wrapped.action(action)[0]
        assert abs(steer) < 0.1, (
            f"Center steer should be ~0.0, got {steer} (throttle_idx={throttle_idx})"
        )
def test_action_decode_full_left_steer():
    """Action index 0 (steer_idx=0, throttle_idx=0) should decode to steer = -1.0."""
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=n_steer, n_throttle=n_throttle)
    leftmost_action = 0  # steer_idx=0, throttle_idx=0
    steer = wrapped.action(leftmost_action)[0]
    assert steer == pytest.approx(-1.0, abs=0.01), (
        f"Full left steer should be -1.0, got {steer}"
    )
def test_action_decode_full_right_steer():
    """Highest steer index (with throttle_idx=0) should decode to steer = 1.0."""
    n_steer, n_throttle = 5, 3
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=n_steer, n_throttle=n_throttle)
    # steer_idx = n_steer - 1, throttle_idx = 0
    steer = wrapped.action(n_throttle * (n_steer - 1))[0]
    assert steer == pytest.approx(1.0, abs=0.01), (
        f"Full right steer should be 1.0, got {steer}"
    )
def test_action_decode_all_valid():
    """Each discrete action decodes to steer in [-1, 1] and throttle in [0, 1]."""
    n_steer, n_throttle = 7, 3
    wrapped = DiscretizedActionWrapper(MockEnv(), n_steer=n_steer, n_throttle=n_throttle)
    total_actions = n_steer * n_throttle
    for action in range(total_actions):
        decoded = wrapped.action(action)
        steer = decoded[0]
        throttle = decoded[1]
        assert -1.0 <= steer <= 1.0, f"Steer {steer} out of range for action {action}"
        assert 0.0 <= throttle <= 1.0, f"Throttle {throttle} out of range for action {action}"
def test_step_passes_through():
    """Wrapped env.step() accepts an integer action and forwards the env's results.

    Both step APIs are tolerated: (obs, reward, done, info) and
    (obs, reward, terminated, truncated, info). In either form the reward is
    at index 1 and info is the last element, so those are compared against
    MockEnv's fixed values.
    """
    env = MockEnv()
    wrapped = DiscretizedActionWrapper(env, n_steer=5, n_throttle=3)
    wrapped.reset()
    result = wrapped.step(0)
    assert len(result) in (4, 5), "step() should return 4 or 5 values"
    reward, info = result[1], result[-1]
    assert reward == 1.0, f"Reward should pass through unchanged, got {reward}"
    # Check only the keys MockEnv sets, so a wrapper that adds extra info keys still passes.
    assert info.get('cte') == 0.1 and info.get('speed') == 2.5, (
        f"Info should pass through unchanged, got {info}"
    )
def test_reset_works():
    """Wrapped env.reset() should return an observation inside the observation space.

    gymnasium's reset() returns (obs, info) while old gym returns obs alone.
    A bare ``is not None`` check passes trivially on the tuple, so unpack it
    and validate the observation itself.
    """
    env = MockEnv()
    wrapped = DiscretizedActionWrapper(env, n_steer=5, n_throttle=3)
    result = wrapped.reset()
    obs = result[0] if isinstance(result, tuple) else result
    assert obs is not None
    assert wrapped.observation_space.contains(obs), (
        "reset() observation should lie in the wrapped observation space"
    )