rlberry.agents.stable_baselines.StableBaselinesAgent

class rlberry.agents.stable_baselines.StableBaselinesAgent(env: Env | Tuple[Callable[..., Env], Mapping[str, Any]], algo_cls: Type[BaseAlgorithm] | None = None, policy: str | Type[BasePolicy] = 'MlpPolicy', verbose=0, tensorboard_log: str | None = None, eval_env: Env | Tuple[Callable[..., Env], Mapping[str, Any]] | None = None, copy_env: bool = True, seeder: Seeder | None = None, output_dir: str | None = None, _execution_metadata: ExecutionMetadata | None = None, _default_writer_kwargs: dict | None = None, _thread_shared_data: dict | None = None, **kwargs)

Bases: AgentWithSimplePolicy

Wraps a Stable-Baselines3 algorithm in an rlberry Agent.

Parameters:
env: gymnasium.Env or tuple (constructor, kwargs)

Environment on which the agent is trained.

algo_cls: stable_baselines3 Algorithm class

Class of the algorithm to wrap (e.g. A2C)

policy: str or stable_baselines3 Policy class

Policy to use (e.g. MlpPolicy)

verbose: int

Verbosity level: 0 for no output, 1 for training information, 2 for debug messages.

tensorboard_log: str

Directory in which to save the TensorBoard logs. If None, no TensorBoard logging is done.

eval_env: gymnasium.Env or tuple (constructor, kwargs)

Environment on which to evaluate the agent. If None, copied from env.

copy_env: bool

If True, makes a deep copy of the environment.

seeder: Seeder, int, or None

Seeder/seed for random number generation.

output_dir: str or Path

Directory that the agent can use to store data.

_execution_metadata: ExecutionMetadata, optional

Extra information about agent execution (e.g. the ID of the process in which the agent is running). Used by ExperimentManager.

_default_writer_kwargs: dict, optional

Parameters to initialize DefaultWriter (attribute self.writer). Used by ExperimentManager.

**kwargs: Keyword Arguments

Additional arguments passed to the constructor of algo_cls (the algorithm class being wrapped).
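The env and eval_env parameters accept either an environment instance or a (constructor, kwargs) tuple, and copy_env controls deep-copying. A minimal sketch of how such an argument could be resolved, assuming a hypothetical helper resolve_env and a stand-in ToyEnv class (neither is part of rlberry's API; the real agent may resolve environments differently):

```python
import copy
from typing import Any, Callable, Mapping, Tuple, Union


class ToyEnv:
    """Hypothetical stand-in for a gymnasium.Env."""

    def __init__(self, size: int = 3):
        self.size = size


# Either an instance or a (constructor, kwargs) tuple, mirroring the signature above.
EnvSpec = Union[ToyEnv, Tuple[Callable[..., ToyEnv], Mapping[str, Any]]]


def resolve_env(env: EnvSpec, copy_env: bool = True) -> ToyEnv:
    """Hypothetical helper: build or copy an environment from its spec."""
    if isinstance(env, tuple):
        constructor, kwargs = env
        # A (constructor, kwargs) tuple always yields a fresh instance.
        return constructor(**kwargs)
    # An existing instance is deep-copied only when copy_env is True.
    return copy.deepcopy(env) if copy_env else env


train_env = resolve_env((ToyEnv, {"size": 5}))
# If no eval_env is given, a copy of the training environment is used.
eval_env = resolve_env(train_env, copy_env=True)
```

Deep-copying keeps the evaluation environment's internal state independent of the training environment.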

Attributes:
output_dir

Directory that the agent can use to store data.

rng

Random number generator.

thread_shared_data

Data shared by agent instances among different threads.

unique_id

Unique identifier for the agent instance.

writer

Writer object to log the output (e.g. tensorboard SummaryWriter).

Notes

Other keyword arguments are passed to the algorithm’s constructor.
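A minimal sketch of this keyword-argument forwarding, assuming a hypothetical DummyAlgo in place of a Stable-Baselines3 class and a hypothetical make_wrapped_algo function. The actual constructor call inside StableBaselinesAgent may pass additional arguments.

```python
from typing import Any, Type


class DummyAlgo:
    """Hypothetical stand-in for a stable_baselines3 algorithm class such as A2C."""

    def __init__(self, policy, env, verbose=0, tensorboard_log=None, learning_rate=1e-3, n_steps=5):
        self.policy = policy
        self.env = env
        self.verbose = verbose
        self.tensorboard_log = tensorboard_log
        self.learning_rate = learning_rate
        self.n_steps = n_steps


def make_wrapped_algo(algo_cls: Type, env: Any, policy="MlpPolicy", verbose=0, tensorboard_log=None, **kwargs):
    # Arguments named in the wrapper's own signature are passed explicitly;
    # everything else collected in **kwargs goes straight to algo_cls.
    return algo_cls(policy, env, verbose=verbose, tensorboard_log=tensorboard_log, **kwargs)


algo = make_wrapped_algo(DummyAlgo, env="CartPole", learning_rate=3e-4, n_steps=16)
```

With the real class, a call such as StableBaselinesAgent(env, A2C, "MlpPolicy", learning_rate=3e-4) would presumably forward learning_rate to A2C in the same way.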

Examples

>>> from rlberry.envs import gym_make
>>> from stable_baselines3 import A2C
>>> from rlberry.agents import StableBaselinesAgent
>>> env_ctor, env_kwargs = gym_make, dict(id="CartPole-v1")
>>> env = env_ctor(**env_kwargs)
>>> agent = StableBaselinesAgent(env, A2C, "MlpPolicy", verbose=1)

Methods

eval([eval_horizon, n_simulations, gamma])

Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.

fit(budget[, tb_log_name, reset_num_timesteps])

Fit the agent.

get_params([deep])

Get parameters for this agent.

load(filename, **kwargs)

Load agent object.

policy(observation[, deterministic])

Get the policy for the given observation.

reseed([seed_seq])

Reseed the agent.

sample_parameters(trial)

Sample hyperparameters for hyperparameter optimization using Optuna (https://optuna.org/)

save(filename)

Save the agent to a file.

set_logger(logger)

Set the logger to a custom SB3 logger.

set_writer(writer)

Set self._writer.

eval(eval_horizon=100000, n_simulations=10, gamma=1.0, **kwargs)

Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.

Parameters:
eval_horizon: int, optional, default: 10**5

Maximum episode length, representing the horizon for each simulation.

n_simulations: int, optional, default: 10

Number of Monte Carlo simulations to perform for the evaluation.

gamma: float, optional, default: 1.0

Discount factor for future rewards.

Returns:
float

The mean, over n_simulations simulations, of the discounted sum of rewards obtained in each simulation.

References

[1] Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction. MIT Press.
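The Monte-Carlo estimator described above can be sketched in plain Python. This is a minimal illustration assuming a gymnasium-style reset()/step() interface; ConstantRewardEnv and mc_evaluate are hypothetical names, and rlberry's own eval may differ in details such as seeding.

```python
from typing import Callable


class ConstantRewardEnv:
    """Hypothetical toy environment: reward 1.0 per step, episode ends after `length` steps."""

    def __init__(self, length: int = 5):
        self.length = length
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return 0, {}  # (observation, info), as in gymnasium

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.length
        # (observation, reward, terminated, truncated, info), as in gymnasium
        return self.t, 1.0, terminated, False, {}


def mc_evaluate(env, policy: Callable, eval_horizon: int = 10**5,
                n_simulations: int = 10, gamma: float = 1.0) -> float:
    """Mean over n_simulations of the discounted return of each episode."""
    returns = []
    for _ in range(n_simulations):
        observation, _ = env.reset()
        total, discount = 0.0, 1.0
        # Each simulation is cut off after at most eval_horizon steps.
        for _ in range(eval_horizon):
            action = policy(observation)
            observation, reward, terminated, truncated, _ = env.step(action)
            total += discount * reward
            discount *= gamma
            if terminated or truncated:
                break
        returns.append(total)
    return sum(returns) / n_simulations
```

Rewards at step t are weighted by gamma**t, so with the default gamma=1.0 the result is the plain mean episode return: mc_evaluate(ConstantRewardEnv(5), lambda obs: 0) gives 5.0, while gamma=0.5 gives 1 + 0.5 + 0.25 + 0.125 + 0.0625 = 1.9375.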

fit(budget: int, tb_log_name: str | None = None, reset_num_timesteps: bool = False, **kwargs)

Fit the agent.

Parameters:
budget: int

Number of timesteps to train the agent for.

tb_log_name: str

Name of the log to use in tensorboard.

reset_num_timesteps: bool

Whether to reset the num_timesteps attribute of the wrapped algorithm.

**kwargs: Keyword Arguments

Extra keyword arguments passed to the learn method of the wrapped algorithm.
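A minimal sketch of how budget and reset_num_timesteps could interact, assuming fit forwards them to the wrapped algorithm's learn(total_timesteps=..., tb_log_name=..., reset_num_timesteps=...). CountingAlgo and TinyAgent are hypothetical stand-ins that only count timesteps; they mimic, but are not, Stable-Baselines3's bookkeeping.

```python
class CountingAlgo:
    """Hypothetical stand-in for an SB3 algorithm; it only tracks num_timesteps."""

    def __init__(self):
        self.num_timesteps = 0

    def learn(self, total_timesteps: int, tb_log_name: str = "run", reset_num_timesteps: bool = True):
        if reset_num_timesteps:
            # Start counting from zero again.
            self.num_timesteps = 0
        # Pretend to train for `total_timesteps` additional environment steps.
        self.num_timesteps += total_timesteps
        return self


class TinyAgent:
    """Hypothetical wrapper showing how fit(budget, ...) could delegate to learn()."""

    def __init__(self, algo):
        self.wrapped = algo

    def fit(self, budget: int, tb_log_name=None, reset_num_timesteps: bool = False, **kwargs):
        self.wrapped.learn(
            total_timesteps=budget,
            tb_log_name=tb_log_name or "run",  # hypothetical default run name
            reset_num_timesteps=reset_num_timesteps,
            **kwargs,
        )


agent = TinyAgent(CountingAlgo())
agent.fit(1000)
agent.fit(500)  # default False: the counter keeps going
after_continue = agent.wrapped.num_timesteps
agent.fit(200, reset_num_timesteps=True)  # the counter restarts
after_reset = agent.wrapped.num_timesteps
```

Because reset_num_timesteps defaults to False, successive calls to fit continue the same run (1500 steps in total here) rather than restarting the step counter.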

get_params(deep=True)

Get parameters for this agent.

Parameters:
deep: bool, default=True

If True, will return the parameters for this agent and contained subobjects.

Returns:
params: dict

Parameter names mapped to their values.
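The description above matches scikit-learn's get_params convention. A minimal sketch of that convention, assuming parameters are stored as attributes named after the __init__ arguments; ToyAgent is hypothetical, and rlberry's implementation may differ, for example in how deep=True treats sub-objects.

```python
import inspect


class ToyAgent:
    """Hypothetical agent that stores each constructor argument as an attribute."""

    def __init__(self, learning_rate=0.01, n_steps=5, sub=None):
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.sub = sub

    def get_params(self, deep=True):
        params = {}
        # Read back every __init__ argument (except self) from the matching attribute.
        for name in inspect.signature(type(self).__init__).parameters:
            if name == "self":
                continue
            value = getattr(self, name)
            params[name] = value
            # With deep=True, also expose parameters of sub-objects as "<name>__<param>".
            if deep and hasattr(value, "get_params"):
                for key, sub_value in value.get_params(deep=True).items():
                    params[f"{name}__{key}"] = sub_value
        return params


outer = ToyAgent(learning_rate=0.1, sub=ToyAgent(n_steps=7))
shallow = outer.get_params(deep=False)
deep = outer.get_params()
```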

classmethod load(filename, **kwargs)

Load agent object.

Parameters:
filename: str

Path to the object (pickle) to load.

**kwargs: Keyword Arguments

Arguments required by the __init__ method of the Agent subclass to load.

property output_dir

Directory that the agent can use to store data.

policy(observation, deterministic=True)

Get the policy for the given observation.

Parameters:
observation:

Observation to get the policy for.

deterministic: bool

If True, return a deterministic action; otherwise, sample an action from the policy.

Returns:
The chosen action.
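For a stochastic policy, deterministic=True usually means returning the most likely action (the mode of the action distribution), while deterministic=False samples from it. A minimal sketch for a discrete action space, assuming the action probabilities are already available; select_action is a hypothetical helper, not the Stable-Baselines3 predict implementation.

```python
import random
from typing import Optional, Sequence


def select_action(probs: Sequence[float], deterministic: bool = True,
                  rng: Optional[random.Random] = None) -> int:
    """Pick an action index from a discrete action distribution."""
    if deterministic:
        # Mode of the distribution: the action with the highest probability.
        return max(range(len(probs)), key=lambda a: probs[a])
    rng = rng or random.Random()
    # Otherwise, sample an action in proportion to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.1, 0.7, 0.2]
greedy = select_action(probs)  # always 1
sampled = select_action(probs, deterministic=False, rng=random.Random(0))
```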
reseed(seed_seq=None)

Reseed the agent.

property rng

Random number generator.

classmethod sample_parameters(trial)

Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)

Note: only the keyword arguments passed to __init__ are optimized, so make sure the Agent constructor includes all “optimizable” parameters.

Parameters:
trial: optuna.trial.Trial

Optuna trial used to sample the hyperparameters.
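A minimal sketch of a sample_parameters override. It calls Optuna's trial.suggest_float and trial.suggest_categorical, but to stay self-contained it runs against a hypothetical FakeTrial stub instead of a real optuna.trial.Trial; MyA2CAgent and the parameter ranges are illustrative only. As the note above says, every returned key must be a constructor argument of the agent.

```python
class FakeTrial:
    """Hypothetical stub exposing the two Optuna methods used below."""

    def suggest_float(self, name, low, high, log=False):
        return low  # a real trial would sample a value in [low, high]

    def suggest_categorical(self, name, choices):
        return choices[0]  # a real trial would pick one of the choices


class MyA2CAgent:
    """Hypothetical agent whose constructor accepts the sampled kwargs."""

    def __init__(self, learning_rate=7e-4, gamma=0.99, policy="MlpPolicy"):
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.policy = policy

    @classmethod
    def sample_parameters(cls, trial):
        # Keys must match __init__ arguments so they can be passed back as kwargs.
        return {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
            "gamma": trial.suggest_categorical("gamma", [0.95, 0.99, 0.999]),
        }


params = MyA2CAgent.sample_parameters(FakeTrial())
agent = MyA2CAgent(**params)
```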
save(filename)

Save the agent to a file.

Parameters:
filename: Path or str

File in which to save the Agent.

Returns:

pathlib.Path

If save() is successful, a Path object corresponding to the filename is returned. Otherwise, None is returned.
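A minimal sketch of the save/load contract described above (return a pathlib.Path on success, None otherwise), using pickle on a hypothetical PicklableAgent. This only illustrates the contract; the real save and load of StableBaselinesAgent may work differently.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Union


class PicklableAgent:
    """Hypothetical agent illustrating a pickle-based save/load round trip."""

    def __init__(self, budget_used: int = 0):
        self.budget_used = budget_used

    def save(self, filename: Union[str, Path]) -> Optional[Path]:
        path = Path(filename)
        try:
            with path.open("wb") as f:
                pickle.dump(self.__dict__, f)
        except OSError:
            return None  # signal failure, as the docstring describes
        return path

    @classmethod
    def load(cls, filename: Union[str, Path], **kwargs):
        # kwargs stand in for the __init__ arguments needed to rebuild the object.
        obj = cls(**kwargs)
        with Path(filename).open("rb") as f:
            obj.__dict__.update(pickle.load(f))
        return obj


tmpdir = tempfile.mkdtemp()
saved_path = PicklableAgent(budget_used=1000).save(Path(tmpdir) / "agent.pickle")
restored = PicklableAgent.load(saved_path)
```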

set_logger(logger)

Set the logger to a custom SB3 logger.

Parameters:
logger: stable_baselines3.common.logger.Logger

The logger to use.

set_writer(writer)

Set self._writer. If writer is not None, add the parameter values to the writer.

property thread_shared_data

Data shared by agent instances among different threads.

property unique_id

Unique identifier for the agent instance. Can be used, for example, to create files/directories for the agent to log data safely.

property writer

Writer object to log the output (e.g. tensorboard SummaryWriter).

Examples using rlberry.agents.stable_baselines.StableBaselinesAgent

Compare PPO and A2C on Acrobot with AdaStop
