import gymnasium as gym
from rlberry.seeding import Seeder, safe_reseed
import numpy as np
from rlberry.envs.interface import Model
from rlberry.rendering import RenderInterface
from rlberry.spaces.from_gym import convert_space_from_gym
from rlberry.rendering.utils import video_write, gif_write
class Wrapper(Model, RenderInterface):
    """
    Wraps a given environment, similar to OpenAI gym's wrapper [1,2] (now updated to gymnasium).

    Can also be used to wrap gym environments.

    Note:
        The input environment is not copied (Wrapper.env points
        to the input env).

    Parameters
    ----------
    env: gymnasium.Env
        Environment to be wrapped.
    wrap_spaces: bool, default = False
        If True, gymnasium.spaces are converted to rlberry.spaces, which define a reseed() method.

    Attributes
    ----------
    env : gymnasium.Env
        The wrapped environment
    metadata : dict
        The metadata of the wrapped environment
    render_mode : str
        The render_mode of the wrapped environment
    frames : list
        Values returned by render(), appended by step() when render_mode is
        "rgb_array" and cleared by reset().

    See also:
        https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python

    [1] https://github.com/openai/gym/blob/master/gym/core.py
    [2] https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/core.py
    """

    def __init__(self, env, wrap_spaces=False):
        # Init base class
        Model.__init__(self)

        # Save reference to env (no copy is made)
        self.env = env
        self.metadata = self.env.metadata
        self.render_mode = self.env.render_mode
        self.frames = []

        if wrap_spaces:
            self.observation_space = convert_space_from_gym(self.env.observation_space)
            self.action_space = convert_space_from_gym(self.env.action_space)
        else:
            self.observation_space = self.env.observation_space
            self.action_space = self.env.action_space

        # Fall back to an unbounded range when the wrapped env does not
        # expose a reward_range attribute.
        try:
            self.reward_range = self.env.reward_range
        except AttributeError:
            self.reward_range = (-np.inf, np.inf)

    @property
    def unwrapped(self):
        """Return the ``unwrapped`` attribute of the wrapped environment."""
        return self.env.unwrapped

    @property
    def spec(self):
        """Return the ``spec`` attribute of the wrapped environment."""
        return self.env.spec

    @classmethod
    def class_name(cls):
        """Return the name of the class this method is called on."""
        return cls.__name__

    def __getattr__(self, attr):
        """
        Forward lookups that failed on the wrapper to the wrapped environment.

        Python only calls this method after the normal attribute lookup has
        failed. Dunder names are rejected to avoid infinite recursion when
        deep copying. See https://stackoverflow.com/a/47300262

        "env" is rejected as well: if it is requested here, it is not set on
        the instance yet, and evaluating ``self.env`` below would call this
        method again without end.
        """
        if attr[:2] == "__" or attr == "env":
            raise AttributeError(attr)
        # Read the instance dict directly rather than via getattr(self, ...),
        # which could re-enter __getattr__.
        instance_dict = self.__dict__
        if attr in instance_dict:
            return instance_dict[attr]
        return getattr(self.env, attr)

    def reseed(self, seed_seq=None):
        """
        Replace the wrapper's seeder and reseed the wrapped env and the spaces.

        Parameters
        ----------
        seed_seq : optional
            Value passed to ``Seeder`` to build the new seeder. If None, the
            new seeder is spawned from the current ``self.seeder``.
        """
        # self.seeder
        if seed_seq is None:
            self.seeder = self.seeder.spawn()
        else:
            self.seeder = Seeder(seed_seq)

        # get a seed for gym environment; spaces are reseeded below.
        if isinstance(self.env, Model):
            # seed rlberry Model
            self.env.reseed(self.seeder)
        elif isinstance(self.env, gym.Env):
            # seed gym.Env that is not a rlberry Model: this is done by
            # calling reset() with an integer seed, so the env is reset too.
            seed_val = self.seeder.rng.integers(2**32).item()
            self.env.reset(seed=seed_val)
        else:
            # other
            safe_reseed(self.env, self.seeder, reseed_spaces=False)

        safe_reseed(self.observation_space, self.seeder)
        safe_reseed(self.action_space, self.seeder)

    def reset(self, seed=None, options=None):
        """
        Clear the recorded frames and reset the wrapped environment.

        ``seed`` and ``options`` are forwarded to ``self.env.reset``, whose
        return value is returned unchanged.
        """
        # NOTE: this reads the env's render_mode directly, whereas step()
        # reads the wrapper's copy taken in __init__.
        if self.env.render_mode == "human":
            self.render()
        self.frames = []
        return self.env.reset(seed=seed, options=options)

    def step(self, action):
        """
        Render according to ``render_mode``, then step the wrapped environment.

        Rendering happens before ``self.env.step(action)`` is called. In
        "rgb_array" mode the value returned by render() is appended to
        ``self.frames``. Returns whatever ``self.env.step(action)`` returns.
        """
        if self.render_mode == "human":
            self.render()
        elif self.render_mode == "rgb_array":
            self.frames.append(self.render())
        return self.env.step(action)

    def sample(self, state, action):
        """Return ``self.env.sample(state, action)``."""
        return self.env.sample(state, action)

    def render(self, **kwargs):
        """Return ``self.env.render(**kwargs)``."""
        return self.env.render(**kwargs)

    def close(self):
        """Return ``self.env.close()``."""
        return self.env.close()

    def seed(self, seed=None):
        """Seed the wrapped environment by calling ``self.env.reset(seed=seed)``."""
        return self.env.reset(seed=seed)

    def compute_reward(self, achieved_goal, desired_goal, info):
        """Return ``self.env.compute_reward(achieved_goal, desired_goal, info)``."""
        return self.env.compute_reward(achieved_goal, desired_goal, info)

    def is_online(self):
        """
        Return True if the wrapped env can be reset and stepped without error.

        Side effect: this calls ``reset()`` and one ``step()`` with a sampled
        action on the wrapped environment.
        """
        # Deliberately broad: any failure means "not online".
        try:
            self.env.reset()
            self.env.step(self.env.action_space.sample())
            return True
        except Exception:
            return False

    def is_generative(self):
        """
        Return True if ``self.env.sample`` can be called without error on an
        observation and an action sampled from the wrapped env's spaces.
        """
        # Deliberately broad: any failure means "not generative".
        try:
            self.env.sample(
                self.env.observation_space.sample(), self.env.action_space.sample()
            )
            return True
        except Exception:
            return False

    def get_video(self, **kwargs):
        """Return the list of recorded frames (``kwargs`` are not used here)."""
        return self.frames

    def save_video(self, filename, framerate=25, **kwargs):
        """
        Pass the frames from ``get_video(**kwargs)`` to ``video_write``.

        Parameters
        ----------
        filename : passed to ``video_write`` as its first argument.
        framerate : int, default = 25
            Passed to ``video_write`` as ``framerate``.
        """
        video_data = self.get_video(**kwargs)
        video_write(filename, video_data, framerate=framerate)

    def save_gif(self, filename, **kwargs):
        """Pass the frames from ``get_video(**kwargs)`` to ``gif_write``."""
        video_data = self.get_video(**kwargs)
        gif_write(filename, video_data)

    def __repr__(self):
        return str(self)

    def __str__(self):
        return "<{}{}>".format(type(self).__name__, self.env)