Source code for rlberry.wrappers.discretize_state

import numpy as np

import rlberry.spaces as spaces
from rlberry.utils.binsearch import binary_search_nd, unravel_index_uniform_bin
from rlberry.envs import Wrapper


class DiscretizeStateWrapper(Wrapper):
    """
    Discretize an environment with continuous states and discrete actions.

    Every dimension of the wrapped environment's observation space is split
    into ``n_bins`` intervals of equal width between ``low`` and ``high``.
    The wrapper then exposes a ``spaces.Discrete`` observation space with
    ``n_bins ** dim`` states, where ``dim`` is the number of dimensions of
    the wrapped observation space.

    Parameters
    ----------
    _env : environment
        Environment to wrap. Its ``observation_space`` must expose ``low``
        and ``high`` bounds that can be indexed per dimension.
        NOTE(review): bin edges are computed from ``high - low``, so the
        bounds are presumably expected to be finite -- confirm with callers.
    n_bins : int
        Number of bins per dimension. Must be > 0.

    Attributes
    ----------
    n_bins : int
        Number of bins per dimension.
    dim : int
        Number of dimensions of the wrapped observation space.
    observation_space : spaces.Discrete
        Discrete space with ``n_bins ** dim`` states.
    discretized_states : numpy.ndarray, shape (dim, n_bins ** dim)
        Column ``ii`` holds the (non-randomized) continuous state associated
        with the discrete state ``ii``.
    """

    def __init__(self, _env, n_bins):
        # initialize base class
        super().__init__(_env)

        # validate before the value is stored or used to build the bins
        assert n_bins > 0, "DiscretizeStateWrapper requires n_bins > 0"
        self.n_bins = n_bins

        # Tolerance added to the last edge so that observations lying
        # exactly on the upper bound still fall inside the last bin.
        tol = 1e-8

        low = self.env.observation_space.low
        high = self.env.observation_space.high
        self.dim = len(low)
        n_states = n_bins**self.dim

        # _bins[dd]: the n_bins + 1 edges of dimension dd (last edge + tol).
        # _open_bins[dd]: the same edges without the first one and without
        # the tolerance.
        self._bins = []
        self._open_bins = []
        for dd in range(self.dim):
            epsilon = (high[dd] - low[dd]) / n_bins
            bins_dd = [low[dd] + epsilon * bb for bb in range(n_bins + 1)]
            self._open_bins.append(tuple(bins_dd[1:]))
            bins_dd[-1] += tol  # "close" the last interval
            self._bins.append(tuple(bins_dd))

        # set observation space (must happen before get_continuous_state is
        # called below, since it checks against observation_space.n)
        self.observation_space = spaces.Discrete(n_states)

        # List of discretized states
        self.discretized_states = np.zeros((self.dim, n_states))
        for ii in range(n_states):
            self.discretized_states[:, ii] = self.get_continuous_state(ii, False)

    def reset(self, seed=None, options=None):
        """
        Reset the wrapped environment and return the discretized observation.

        Parameters
        ----------
        seed : optional
            Forwarded unchanged to the wrapped environment's ``reset``.
        options : optional
            Forwarded unchanged to the wrapped environment's ``reset``.

        Returns
        -------
        tuple
            ``(discrete_observation, info)``.
        """
        # Forward by keyword: this works both for environments that accept
        # these arguments positionally and for those that declare them
        # keyword-only.
        obs, info = self.env.reset(seed=seed, options=options)
        return self.get_discrete_state(obs), info

    def step(self, action):
        """
        Take a step in the wrapped environment and discretize the result.

        Parameters
        ----------
        action
            Action forwarded to the wrapped environment's ``step``.

        Returns
        -------
        tuple
            ``(discrete_next_observation, reward, terminated, truncated,
            info)``, where all but the first entry come unchanged from the
            wrapped environment.
        """
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        next_observation = self.get_discrete_state(next_observation)
        return next_observation, reward, terminated, truncated, info

    def sample(self, discrete_state, action):
        """
        Sample a transition from a discrete state.

        A continuous state is drawn inside the bin of ``discrete_state``,
        the wrapped environment's ``sample`` is called from it, and the
        resulting next state is discretized.

        Parameters
        ----------
        discrete_state : int
            Discrete state; must belong to ``self.observation_space``.
        action
            Action forwarded to the wrapped environment's ``sample``.

        Returns
        -------
        tuple
            ``(discrete_next_state, reward, terminated, truncated, info)``.
        """
        # map discrete state to continuous one
        assert self.observation_space.contains(discrete_state)
        continuous_state = self.get_continuous_state(discrete_state, randomize=True)
        # sample in the true environment
        next_state, reward, terminated, truncated, info = self.env.sample(
            continuous_state, action
        )
        # discretize next state
        next_state = self.get_discrete_state(next_state)
        return next_state, reward, terminated, truncated, info

    def get_discrete_state(self, continuous_state):
        """
        Map a continuous state to its discrete index.

        The lookup is delegated to ``binary_search_nd`` over the bin edges
        built in ``__init__``.
        """
        return binary_search_nd(continuous_state, self._bins)

    def get_continuous_state(self, discrete_state, randomize=False):
        """
        Map a discrete state to a continuous state inside its bin.

        Parameters
        ----------
        discrete_state : int
            Discrete state in ``[0, self.observation_space.n)``.
        randomize : bool, default False
            If False, each coordinate is the lower edge of the corresponding
            bin. If True, ``epsilon * self.rng.uniform()`` is added to each
            coordinate, where ``epsilon`` is the bin width of that dimension
            (a random point of the bin, assuming ``self.rng.uniform()``
            draws from [0, 1) -- ``self.rng`` is provided outside this
            class).

        Returns
        -------
        numpy.ndarray, shape (dim,)
            Continuous state.
        """
        assert (
            discrete_state >= 0 and discrete_state < self.observation_space.n
        ), "invalid discrete_state"
        # get multi-index (one bin index per dimension)
        index = unravel_index_uniform_bin(discrete_state, self.dim, self.n_bins)

        # get state
        continuous_state = np.zeros(self.dim)
        for dd in range(self.dim):
            continuous_state[dd] = self._bins[dd][index[dd]]
            if randomize:
                range_dd = (
                    self.env.observation_space.high[dd]
                    - self.env.observation_space.low[dd]
                )
                epsilon = range_dd / self.n_bins
                continuous_state[dd] += epsilon * self.rng.uniform()
        return continuous_state