"""rlberry.wrappers.discretize_state: discretize environments with continuous states."""
import numpy as np
import rlberry.spaces as spaces
from rlberry.utils.binsearch import binary_search_nd, unravel_index_uniform_bin
from rlberry.envs import Wrapper
class DiscretizeStateWrapper(Wrapper):
    """
    Discretize an environment with continuous states and discrete actions.

    Each dimension of the wrapped environment's observation space (which must
    expose ``low`` and ``high`` bounds) is split into ``n_bins`` intervals of
    equal width. The wrapper's observation space is therefore
    ``spaces.Discrete(n_bins ** dim)``, where ``dim`` is the number of
    components of ``observation_space.low``.

    Parameters
    ----------
    _env : environment
        Environment to wrap. Its bounds are presumably finite; infinite
        bounds would yield non-finite bin edges.
    n_bins : int
        Number of intervals per dimension. Must be > 0.

    Attributes
    ----------
    n_bins : int
        Number of intervals per dimension.
    dim : int
        Number of observation dimensions.
    discretized_states : np.ndarray of shape (dim, n_bins ** dim)
        Column ``ii`` is ``get_continuous_state(ii, False)``, the lower
        corner of discrete state ``ii``.
    """

    def __init__(self, _env, n_bins):
        # initialize base class
        super().__init__(_env)
        self.n_bins = n_bins
        assert n_bins > 0, "DiscretizeStateWrapper requires n_bins > 0"

        # Offset added to the last bin edge to "close" the last interval,
        # presumably so that an observation equal to ``high`` still maps to
        # the last bin under binary_search_nd (assumes half-open bins).
        tol = 1e-8
        low = self.env.observation_space.low
        high = self.env.observation_space.high
        self.dim = len(low)
        n_states = n_bins**self.dim

        # _bins[dd]: the n_bins + 1 edges of dimension dd (last edge + tol).
        # _open_bins[dd]: the n_bins upper edges of dimension dd (no tol).
        self._bins = []
        self._open_bins = []
        for dd in range(self.dim):
            epsilon = (high[dd] - low[dd]) / n_bins
            bins_dd = [low[dd] + epsilon * bb for bb in range(n_bins + 1)]
            self._open_bins.append(tuple(bins_dd[1:]))
            bins_dd[-1] += tol  # "close" the last interval
            self._bins.append(tuple(bins_dd))

        # set observation space
        self.observation_space = spaces.Discrete(n_states)

        # Lower corner of every discrete state, one column per state.
        self.discretized_states = np.zeros((self.dim, n_states))
        for ii in range(n_states):
            self.discretized_states[:, ii] = self.get_continuous_state(ii, False)

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and discretize its first observation.

        ``seed`` and ``options`` are forwarded as keyword arguments: some
        environments (e.g. Gymnasium's ``Env.reset``) accept them
        keyword-only, so passing them positionally would raise TypeError.

        Returns
        -------
        (discrete_observation, info)
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self.get_discrete_state(obs), info

    def step(self, action):
        """Step the wrapped environment and discretize the next observation.

        Returns
        -------
        (discrete_next_observation, reward, terminated, truncated, info)
        """
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        next_observation = self.get_discrete_state(next_observation)
        return next_observation, reward, terminated, truncated, info

    def sample(self, discrete_state, action):
        """Sample a transition from a discrete state.

        The discrete state is mapped to a random continuous point inside its
        cell (see ``get_continuous_state(..., randomize=True)``). The wrapped
        environment's ``sample`` is called from that point, and the resulting
        next state is discretized.

        Returns
        -------
        (discrete_next_state, reward, terminated, truncated, info)
        """
        # map discrete state to continuous one
        assert self.observation_space.contains(discrete_state)
        continuous_state = self.get_continuous_state(discrete_state, randomize=True)
        # sample in the true environment
        next_state, reward, terminated, truncated, info = self.env.sample(
            continuous_state, action
        )
        # discretize next state
        next_state = self.get_discrete_state(next_state)
        return next_state, reward, terminated, truncated, info

    def get_discrete_state(self, continuous_state):
        """Return the discrete index of the cell containing ``continuous_state``."""
        return binary_search_nd(continuous_state, self._bins)

    def get_continuous_state(self, discrete_state, randomize=False):
        """Map a discrete state index to a continuous point of its cell.

        Parameters
        ----------
        discrete_state : int
            Index in ``[0, observation_space.n)``.
        randomize : bool
            If False, return the cell's lower corner (the left bin edge in
            every dimension). If True, add ``width * self.rng.uniform()`` in
            each dimension, where ``width`` is that dimension's bin width.
            This assumes ``self.rng`` is provided by the base wrapper and
            draws from [0, 1).

        Returns
        -------
        np.ndarray of shape (dim,)
        """
        assert (
            discrete_state >= 0 and discrete_state < self.observation_space.n
        ), "invalid discrete_state"
        # get multi-index (one bin index per dimension)
        index = unravel_index_uniform_bin(discrete_state, self.dim, self.n_bins)
        # get state
        continuous_state = np.zeros(self.dim)
        for dd in range(self.dim):
            continuous_state[dd] = self._bins[dd][index[dd]]
            if randomize:
                range_dd = (
                    self.env.observation_space.high[dd]
                    - self.env.observation_space.low[dd]
                )
                epsilon = range_dd / self.n_bins
                continuous_state[dd] += epsilon * self.rng.uniform()
        return continuous_state