28 lines
1.1 KiB
Python
28 lines
1.1 KiB
Python
import numpy as np
|
|
import gymnasium as gym
|
|
|
|
class DiscretizedActionWrapper(gym.ActionWrapper):
    """Expose a ``Discrete`` action space over a grid of (steer, throttle) pairs.

    Each discrete index selects one combination of an evenly spaced steering
    value and an evenly spaced throttle value; the wrapped environment then
    receives a 2-element ``float32`` array ``[steer, throttle]``.

    Indices are laid out steer-major, i.e.
    ``idx = steer_idx * n_throttle + throttle_idx``, which is the order in
    which ``action_list`` is built and the formula ``reverse_action`` inverts.

    Args:
        env: The environment to wrap. NOTE(review): this wrapper always emits
            a 2-element action; confirm the wrapped env's action space expects
            exactly ``[steer, throttle]``.
        n_steer: Number of steering bins (must be >= 1).
        n_throttle: Number of throttle bins (must be >= 1).
        steer_range: ``(low, high)`` endpoints of the steering bins.
        throttle_range: ``(low, high)`` endpoints of the throttle bins.

    Note:
        Bins come from ``np.linspace``, so with a single bin only the low
        endpoint of the range is used (e.g. ``n_steer=1`` -> steer is -1.0).

    Raises:
        ValueError: If ``n_steer`` or ``n_throttle`` is less than 1.
    """

    def __init__(
        self,
        env,
        n_steer=3,
        n_throttle=3,
        *,
        steer_range=(-1.0, 1.0),
        throttle_range=(0.0, 1.0),
    ):
        # Validate before wrapping: an empty grid would produce Discrete(0).
        if n_steer < 1 or n_throttle < 1:
            raise ValueError(
                f"n_steer and n_throttle must be >= 1, "
                f"got n_steer={n_steer}, n_throttle={n_throttle}"
            )
        super().__init__(env)
        self.n_steer = n_steer
        self.n_throttle = n_throttle

        # Evenly spaced bin values (endpoints included).
        self.steer_bins = np.linspace(steer_range[0], steer_range[1], n_steer)
        self.throttle_bins = np.linspace(
            throttle_range[0], throttle_range[1], n_throttle
        )

        # Steer is the outer loop, so consecutive indices vary throttle first.
        self.action_list = [
            (s, t) for s in self.steer_bins for t in self.throttle_bins
        ]
        self.action_space = gym.spaces.Discrete(len(self.action_list))

    def action(self, act_idx):
        """Map a discrete action index to a ``[steer, throttle]`` float32 array.

        Args:
            act_idx: Integer index in ``[0, len(action_list))``.

        Returns:
            ``np.ndarray`` of shape ``(2,)`` and dtype ``float32``.

        Raises:
            IndexError: If ``act_idx`` is outside the valid range (negative
                indices are rejected rather than wrapping around).
        """
        n_actions = len(self.action_list)
        if not 0 <= act_idx < n_actions:
            raise IndexError(
                f"action index {act_idx} out of range [0, {n_actions})"
            )
        steer, throttle = self.action_list[act_idx]
        return np.array([steer, throttle], dtype=np.float32)

    def reverse_action(self, action):
        """Map a continuous ``(steer, throttle)`` pair to the nearest index.

        Each component is snapped independently to its closest bin value, so
        inputs outside a bin range map to that range's end bin.

        Args:
            action: A 2-element sequence/array ``(steer, throttle)``.

        Returns:
            The discrete index as a Python ``int``, using the same steer-major
            layout as ``action_list``.
        """
        steer, throttle = action
        steer_idx = np.abs(self.steer_bins - steer).argmin()
        throttle_idx = np.abs(self.throttle_bins - throttle).argmin()
        return int(steer_idx * self.n_throttle + throttle_idx)
|