rlberry.agents.stable_baselines.StableBaselinesAgent¶
- class rlberry.agents.stable_baselines.StableBaselinesAgent(env: Env | Tuple[Callable[[...], Env], Mapping[str, Any]], algo_cls: Type[BaseAlgorithm] | None = None, policy: str | Type[BasePolicy] = 'MlpPolicy', verbose=0, tensorboard_log: str | None = None, eval_env: Env | Tuple[Callable[[...], Env], Mapping[str, Any]] | None = None, copy_env: bool = True, seeder: Seeder | None = None, output_dir: str | None = None, _execution_metadata: ExecutionMetadata | None = None, _default_writer_kwargs: dict | None = None, _thread_shared_data: dict | None = None, **kwargs)[source]¶
Bases: AgentWithSimplePolicy
Wraps a Stable Baselines3 algorithm as an rlberry Agent.
- Parameters:
- env: gymnasium.Env
Environment
- algo_cls: stable_baselines3 Algorithm class
Class of the algorithm to wrap (e.g. A2C)
- policy: str or stable_baselines3 Policy class
Policy to use (e.g. MlpPolicy)
- verbose: int
Verbosity level: 0 none, 1 training information, 2 tensorflow debug
- tensorboard_log: str
Path to the directory where to save the tensorboard logs (if None, no logging)
- eval_env: gymnasium.Env or tuple (constructor, kwargs)
Environment on which to evaluate the agent. If None, copied from env.
- copy_env: bool
If true, makes a deep copy of the environment.
- seeder: Seeder, int, or None
Seeder/seed for random number generation.
- output_dir: str or Path
Directory that the agent can use to store data.
- writer_extra (through class Agent): str in {“reward”, “action”, “action_and_reward”}
Scalar that will be recorded in the writer.
- _execution_metadata: ExecutionMetadata, optional
Extra information about agent execution (e.g. the id of the process where the agent is running). Used by ExperimentManager.
- _default_writer_kwargs: dict, optional
Parameters to initialize DefaultWriter (attribute self.writer). Used by ExperimentManager.
- **kwargs: Keyword Arguments
Arguments to be passed to the algo_cls constructor (the class of the algorithm to wrap).
- Attributes:
output_dir
Directory that the agent can use to store data.
rng
Random number generator.
thread_shared_data
Data shared by agent instances among different threads.
unique_id
Unique identifier for the agent instance.
writer
Writer object to log the output (e.g. tensorboard SummaryWriter).
Notes
Other keyword arguments are passed to the algorithm’s constructor.
Examples
>>> from rlberry.envs import gym_make
>>> from stable_baselines3 import A2C
>>> from rlberry.agents import StableBaselinesAgent
>>> env_ctor, env_kwargs = gym_make, dict(id="CartPole-v1")
>>> env = env_ctor(**env_kwargs)
>>> agent = StableBaselinesAgent(env, A2C, "MlpPolicy", verbose=1)
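Continuing the example above, a minimal sketch of training and evaluating the wrapped agent (the budget and evaluation settings below are illustrative, not recommended values):
>>> agent.fit(budget=10_000)
>>> mean_reward = agent.eval(eval_horizon=500, n_simulations=5)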
Methods
eval([eval_horizon, n_simulations, gamma])
Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.
fit(budget[, tb_log_name, reset_num_timesteps])
Fit the agent.
get_params([deep])
Get parameters for this agent.
load(filename, **kwargs)
Load agent object.
policy(observation[, deterministic])
Get the policy for the given observation.
reseed([seed_seq])
Reseed the agent.
sample_parameters(trial)
Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/).
save(filename)
Save the agent to a file.
set_logger(logger)
Set the logger to a custom SB3 logger.
set_writer(writer)
Set self._writer.
- eval(eval_horizon=100000, n_simulations=10, gamma=1.0)¶
Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.
- Parameters:
- eval_horizon: int, optional, default: 10**5
Maximum episode length, representing the horizon for each simulation.
- n_simulations: int, optional, default: 10
Number of Monte Carlo simulations to perform for the evaluation.
- gamma: float, optional, default: 1.0
Discount factor for future rewards.
- Returns:
- float
The mean value over ‘n_simulations’ of the sum of rewards obtained in each simulation.
References
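As an illustration of calling eval on an agent constructed and fitted as in the Examples section (the settings below are arbitrary):
# Average discounted return over 20 Monte-Carlo rollouts of at most 500 steps.
mean_return = agent.eval(eval_horizon=500, n_simulations=20, gamma=0.99)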
- fit(budget: int, tb_log_name: str | None = None, reset_num_timesteps: bool = False, **kwargs)[source]¶
Fit the agent.
- Parameters:
- budget: int
Number of timesteps to train the agent for.
- tb_log_name: str
Name of the log to use in tensorboard.
- reset_num_timesteps: bool
Whether or not to reset the num_timesteps attribute.
- **kwargs: Keyword Arguments
Extra arguments required by the ‘learn’ method of the Wrapped Agent.
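A minimal sketch, assuming an agent built as in the Examples section; the budget and log name are illustrative:
# Train for 50,000 timesteps; extra keyword arguments would be forwarded
# to the wrapped algorithm's learn() method.
agent.fit(budget=50_000, tb_log_name="a2c_cartpole")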
- get_params(deep=True)¶
Get parameters for this agent.
- Parameters:
- deep: bool, default=True
If True, will return the parameters for this agent and contained subobjects.
- Returns:
- params: dict
Parameter names mapped to their values.
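For instance, the constructor arguments of an agent can be inspected as follows (the exact keys depend on what was passed to __init__):
# Dictionary of the keyword arguments the agent was created with.
params = agent.get_params(deep=True)
print(params.keys())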
- classmethod load(filename, **kwargs)[source]¶
Load agent object.
- Parameters:
- filename: str
Path to the object (pickle) to load.
- **kwargs: Keyword Arguments
Arguments required by the __init__ method of the Agent subclass to load.
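A sketch of reloading an agent from a file produced by save(), using the names from the Examples section; the path is illustrative, and which keyword arguments must be supplied again is an assumption here (per the description above, they are forwarded to __init__):
# Rebuild the agent from a pickled file; kwargs are forwarded to __init__.
loaded_agent = StableBaselinesAgent.load(
    "saved_agent.pickle", env=env, algo_cls=A2C
)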
- property output_dir¶
Directory that the agent can use to store data.
- policy(observation, deterministic=True)[source]¶
Get the policy for the given observation.
- Parameters:
- observation:
Observation to get the policy for.
- deterministic: bool
Whether to return a deterministic policy or not.
- Returns:
- The chosen action.
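For example, with a Gymnasium environment created as in the Examples section:
# Query the trained policy for a single deterministic action.
observation, info = env.reset()
action = agent.policy(observation, deterministic=True)
observation, reward, terminated, truncated, info = env.step(action)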
- property rng¶
Random number generator.
- classmethod sample_parameters(trial)¶
Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)
Note: only the kwargs sent to __init__ are optimized. Make sure to include in the Agent constructor all “optimizable” parameters.
- Parameters:
- trial: optuna.trial
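Agents that support hyperparameter optimization typically override this classmethod. A hedged sketch of such an override for an A2C-based agent; the hyperparameter names and ranges below are illustrative assumptions, not rlberry or Stable Baselines3 defaults:
class TunedA2CAgent(StableBaselinesAgent):
    @classmethod
    def sample_parameters(cls, trial):
        # Suggest values for keyword arguments that __init__ forwards to A2C.
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
        gamma = trial.suggest_float("gamma", 0.90, 0.999)
        return dict(learning_rate=learning_rate, gamma=gamma)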
- save(filename)[source]¶
Save the agent to a file.
- Parameters:
- filename: Path or str
File in which to save the Agent.
- Returns:
- pathlib.Path
If save() is successful, a Path object corresponding to the filename is returned. Otherwise, None is returned.
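A short usage sketch; the filename is illustrative:
# Persist the trained agent; a pathlib.Path is returned on success,
# None otherwise.
saved_path = agent.save("saved_agent.pickle")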
- set_logger(logger)[source]¶
Set the logger to a custom SB3 logger.
- Parameters:
- logger: stable_baselines3.common.logger.Logger
The logger to use.
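A sketch using Stable Baselines3's logger factory; the output directory and formats are illustrative:
from stable_baselines3.common.logger import configure

# Build an SB3 logger writing to stdout and CSV files, then attach it.
sb3_logger = configure("./sb3_logs", ["stdout", "csv"])
agent.set_logger(sb3_logger)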
- set_writer(writer)¶
Set self._writer. If the writer is not None, add parameter values to the writer.
- property thread_shared_data¶
Data shared by agent instances among different threads.
- property unique_id¶
Unique identifier for the agent instance. Can be used, for example, to create files/directories for the agent to log data safely.
- property writer¶
Writer object to log the output (e.g. tensorboard SummaryWriter).
Examples using rlberry.agents.stable_baselines.StableBaselinesAgent¶
Compare PPO and A2C on Acrobot with AdaStop