import gymnasium as gym
from rlberry.envs.basewrapper import Wrapper
import numpy as np
from numpy import ndarray
def gym_make(id, wrap_spaces=False, **kwargs):
    """
    Same as gym.make, but wraps the environment
    to ensure unified seeding with rlberry.

    Parameters
    ----------
    id : str
        Environment id.
    wrap_spaces : bool, default = False
        If true, also wraps observation_space and action_space using classes in rlberry.spaces,
        that define a reseed() method.
    **kwargs : keywords arguments
        Additional arguments to pass to the gymnasium environment constructor.
        If the key "module_import" is present, it is removed from kwargs and
        the module it names is imported before the environment is created.

    Returns
    -------
    Wrapper
        The environment created by gym.make, wrapped in rlberry's Wrapper.

    Examples
    --------
    >>> from rlberry.envs import gym_make
    >>> env_ctor = gym_make
    >>> env_kwargs = {"id": "CartPole-v1"}
    >>> env = env_ctor(**env_kwargs)
    """
    if "module_import" in kwargs:
        import importlib

        # Popped so that it is not forwarded to gym.make. The module is
        # imported only for its import-time side effects (presumably the
        # registration of its environments -- confirm with callers).
        importlib.import_module(kwargs.pop("module_import"))
    env = gym.make(id, **kwargs)
    return Wrapper(env, wrap_spaces=wrap_spaces)
def atari_make(id, seed=None, **kwargs):
    """
    Adapter to manage Atari environments.

    Builds the environment with gym.make, then applies the Stable-Baselines3
    Atari preprocessing wrappers, resizing, gray-scaling and frame stacking,
    and finally wraps the result in CompatibleWrapper.

    Parameters
    ----------
    id : str
        Environment id.
    seed : int, optional
        Seed for the environment and its action/observation spaces.
        If None (default), no seeding is performed. 0 is a valid seed.
    **kwargs : keywords arguments
        Optional arguments to configure the environment. Recognized keys:

        * "render_mode" : passed to gym.make (default None).
        * "n_frame_stack" : number of stacked frames (default 4).
        * "atari_SB3_wrappers_dict" : dict overriding the wrapper defaults.
          Recognized keys are "noop_max" (30), "frame_skip" (4),
          "screen_size" (84), "terminal_on_life_loss" (False),
          "clip_reward" (True) and "action_repeat_probability" (0.0).
          See StableBaselines's make_atari_env :
          https://stable-baselines3.readthedocs.io/en/master/common/env_util.html#stable_baselines3.common.env_util.make_atari_env

        Any other keyword argument is accepted but currently not used
        (it is not forwarded to gym.make).

    Returns
    -------
    Atari env with wrapper to be used as Gymnasium env.

    Examples
    --------
    >>> from rlberry.envs.gym_make import atari_make
    >>> env_ctor = atari_make
    >>> env_kwargs = {"id": "ALE/Freeway-v5", "atari_SB3_wrappers_dict": dict(terminal_on_life_loss=False), "n_frame_stack": 5}
    >>> env = env_ctor(**env_kwargs)
    """
    from stable_baselines3.common.atari_wrappers import (  # isort:skip
        ClipRewardEnv,
        EpisodicLifeEnv,
        FireResetEnv,
        MaxAndSkipEnv,
        NoopResetEnv,
        StickyActionEnv,
    )
    from stable_baselines3.common.monitor import Monitor

    # Default values for Atari_SB3_wrappers
    wrapper_options = {
        "noop_max": 30,
        "frame_skip": 4,
        "screen_size": 84,
        # different from SB3 : some errors with the "terminal_on_life_loss"
        # wrapper : The 'false reset' can lead to make a step on a 'done'
        # environment, then a crash.
        "terminal_on_life_loss": False,
        "clip_reward": True,
        "action_repeat_probability": 0.0,
    }
    if "atari_SB3_wrappers_dict" in kwargs:
        overrides = kwargs.pop("atari_SB3_wrappers_dict")
        # Only the known options are taken into account; unknown keys in the
        # user-supplied dict are ignored.
        wrapper_options.update(
            {key: overrides[key] for key in wrapper_options if key in overrides}
        )

    render_mode = kwargs.pop("render_mode", None)
    n_frame_stack = kwargs.pop("n_frame_stack", 4)

    env = gym.make(id, render_mode=render_mode)
    env = Wrapper(env)
    env = Monitor(env)

    # Each preprocessing wrapper is applied only when its option enables it.
    if wrapper_options["action_repeat_probability"] > 0.0:
        env = StickyActionEnv(env, wrapper_options["action_repeat_probability"])
    if wrapper_options["noop_max"] > 0:
        env = NoopResetEnv(env, noop_max=wrapper_options["noop_max"])
    if wrapper_options["frame_skip"] > 1:
        env = MaxAndSkipEnv(env, skip=wrapper_options["frame_skip"])
    if wrapper_options["terminal_on_life_loss"]:
        env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    if wrapper_options["clip_reward"]:
        env = ClipRewardEnv(env)

    screen_size = wrapper_options["screen_size"]
    env = gym.wrappers.ResizeObservation(env, (screen_size, screen_size))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, n_frame_stack)

    # Compare against None explicitly: a plain truthiness test would
    # silently skip seeding when the caller asks for seed=0.
    if seed is not None:
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)

    env = CompatibleWrapper(env)  # Wrapper to make it compatible with rlberry
    env.render_mode = render_mode
    return env
class CompatibleWrapper(Wrapper):
    """
    Wrapper making an environment compatible with rlberry.

    It unwraps single-element array actions before stepping the wrapped
    environment, and returns the observations of step() and reset() as
    NumPy arrays.
    """

    def __init__(self, env):
        super().__init__(env)
        self.render_mode = None

    def step(self, action):
        """
        Run one step of the wrapped environment.

        Parameters
        ----------
        action : scalar or ndarray
            Action to apply. An ndarray holding exactly one element (whatever
            its number of dimensions, including 0-d) is replaced by that
            element before being passed to the wrapped environment.

        Returns
        -------
        tuple
            (observation, reward, terminated, truncated, info), where the
            observation is converted with np.array.
        """
        if isinstance(action, ndarray) and action.size == 1:
            # ``flat`` reaches the single element for any shape; plain
            # ``action[0]`` raises IndexError on 0-d arrays and returns a
            # nested array for shapes such as (1, 1).
            action = action.flat[0]
        next_observations, rewards, terminated, truncated, infos = self.env.step(action)
        return (
            np.array(next_observations),
            rewards,
            terminated,
            truncated,
            infos,
        )

    def reset(self, seed=None, options=None):
        """
        Reset the wrapped environment.

        Parameters
        ----------
        seed : optional
            Forwarded as-is to the wrapped environment's reset().
        options : optional
            Forwarded as-is to the wrapped environment's reset().

        Returns
        -------
        tuple
            (observation, info), where the observation is converted with
            np.array.
        """
        obs, infos = self.env.reset(seed=seed, options=options)
        return np.array(obs), infos