Source code for rlberry.envs.gym_make

import gymnasium as gym

from rlberry.envs.basewrapper import Wrapper
import numpy as np
from numpy import ndarray


def gym_make(id, wrap_spaces=False, **kwargs):
    """
    Same as gym.make, but wraps the environment to ensure unified seeding with rlberry.

    Parameters
    ----------
    id : str
        Environment id.
    wrap_spaces : bool, default = False
        If true, also wraps observation_space and action_space using classes in
        rlberry.spaces, which define a reseed() method.
    **kwargs : keyword arguments
        Additional arguments to pass to the gymnasium environment constructor.

    Examples
    --------
    >>> from rlberry.envs import gym_make
    >>> env_ctor = gym_make
    >>> env_kwargs = {"id": "CartPole-v1"}
    >>> env = env_ctor(**env_kwargs)
    """
    if "module_import" in kwargs:
        __import__(kwargs.pop("module_import"))
    env = gym.make(id, **kwargs)
    return Wrapper(env, wrap_spaces=wrap_spaces)
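
# Usage sketch (illustrative, not part of the original module): extra kwargs
# are forwarded to gym.make, and "module_import" (handled above) imports a
# module before construction, e.g. one whose import registers custom
# environments. "my_envs_package" and "MyEnv-v0" are hypothetical names used
# only for illustration.
#
#     from rlberry.envs import gym_make
#
#     env = gym_make("CartPole-v1", wrap_spaces=True, render_mode="rgb_array")
#     env = gym_make("MyEnv-v0", module_import="my_envs_package")  # hypothetical
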
def atari_make(id, seed=None, **kwargs):
    """
    Adapter to manage Atari environments.

    Parameters
    ----------
    id : str
        Environment id.
    seed : int
        Seed for the environment.
    **kwargs : keyword arguments
        Optional arguments to configure the environment
        (render_mode, n_frame_stack, atari_SB3_wrappers_dict with the arguments
        of Stable Baselines3's make_atari_env:
        https://stable-baselines3.readthedocs.io/en/master/common/env_util.html#stable_baselines3.common.env_util.make_atari_env).

    Returns
    -------
    Atari env with wrappers, to be used as a Gymnasium env.

    Examples
    --------
    >>> from rlberry.envs.gym_make import atari_make
    >>> env_ctor = atari_make
    >>> env_kwargs = {"id": "ALE/Freeway-v5", "atari_SB3_wrappers_dict": dict(terminal_on_life_loss=False), "n_frame_stack": 5}
    >>> env = env_ctor(**env_kwargs)
    """
    from stable_baselines3.common.atari_wrappers import (  # isort:skip
        ClipRewardEnv,
        EpisodicLifeEnv,
        FireResetEnv,
        MaxAndSkipEnv,
        NoopResetEnv,
        StickyActionEnv,
    )
    from stable_baselines3.common.monitor import Monitor

    # Default values for the SB3 Atari wrappers.
    noop_max = 30
    frame_skip = 4
    screen_size = 84
    # Different from SB3: the "terminal_on_life_loss" wrapper can cause errors.
    # Its "false reset" may step a 'done' environment, which then crashes.
    terminal_on_life_loss = False
    clip_reward = True
    action_repeat_probability = 0.0

    if "atari_SB3_wrappers_dict" in kwargs.keys():
        atari_wrappers_dict = kwargs.pop("atari_SB3_wrappers_dict")
        if "noop_max" in atari_wrappers_dict.keys():
            noop_max = atari_wrappers_dict["noop_max"]
        if "frame_skip" in atari_wrappers_dict.keys():
            frame_skip = atari_wrappers_dict["frame_skip"]
        if "screen_size" in atari_wrappers_dict.keys():
            screen_size = atari_wrappers_dict["screen_size"]
        if "terminal_on_life_loss" in atari_wrappers_dict.keys():
            terminal_on_life_loss = atari_wrappers_dict["terminal_on_life_loss"]
        if "clip_reward" in atari_wrappers_dict.keys():
            clip_reward = atari_wrappers_dict["clip_reward"]
        if "action_repeat_probability" in atari_wrappers_dict.keys():
            action_repeat_probability = atari_wrappers_dict["action_repeat_probability"]

    render_mode = kwargs.pop("render_mode", None)
    n_frame_stack = kwargs.pop("n_frame_stack", 4)

    env = gym.make(id, render_mode=render_mode)
    env = Wrapper(env)
    env = Monitor(env)
    if action_repeat_probability > 0.0:
        env = StickyActionEnv(env, action_repeat_probability)
    if noop_max > 0:
        env = NoopResetEnv(env, noop_max=noop_max)
    if frame_skip > 1:
        env = MaxAndSkipEnv(env, skip=frame_skip)
    if terminal_on_life_loss:
        env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    if clip_reward:
        env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (screen_size, screen_size))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, n_frame_stack)
    if seed:
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
    env = CompatibleWrapper(env)  # wrapper to make it compatible with rlberry
    env.render_mode = render_mode
    return env


class CompatibleWrapper(Wrapper):
    def __init__(self, env):
        super(CompatibleWrapper, self).__init__(env)
        self.render_mode = None

    def step(self, action):
        # Unwrap size-1 ndarray actions to scalars before stepping.
        if isinstance(action, ndarray) and action.size == 1:
            action = action[0]
        next_observations, rewards, terminated, truncated, infos = self.env.step(action)
        return (
            np.array(next_observations),
            rewards,
            terminated,
            truncated,
            infos,
        )

    def reset(self, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        return np.array(obs), infos
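

# Usage sketch (illustrative addition, not part of the original module):
# atari_SB3_wrappers_dict overrides the wrapper defaults set in atari_make,
# and n_frame_stack controls the FrameStack depth. Running this assumes
# ale-py and the Atari ROMs are installed.
if __name__ == "__main__":
    env = atari_make(
        "ALE/Breakout-v5",
        seed=42,
        atari_SB3_wrappers_dict=dict(terminal_on_life_loss=False, clip_reward=True),
        n_frame_stack=4,
    )
    obs, info = env.reset()
    # CompatibleWrapper squeezes size-1 ndarray actions to scalars and returns
    # observations as np.array, following the Gymnasium 5-tuple step API.
    obs, reward, terminated, truncated, info = env.step(np.array([1]))
    assert isinstance(obs, np.ndarray)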