rlberry.agents.Agent

class rlberry.agents.Agent(env: Env | Tuple[Callable[[...], Env], Mapping[str, Any]] | None = None, eval_env: Env | Tuple[Callable[[...], Env], Mapping[str, Any]] | None = None, copy_env: bool = True, compress_pickle: bool = True, seeder: Seeder | int | None = None, output_dir: str | None = None, writer_extra: str | None = None, _execution_metadata: ExecutionMetadata | None = None, _default_writer_kwargs: dict | None = None, _thread_shared_data: dict | None = None)[source]

Bases: ABC

Basic interface for agents.

If the class inheriting from Agent uses the torch library, it is highly recommended to inherit from AgentTorch instead.

Note

1 - Abstract class: it cannot be instantiated, and the abstract methods must be overridden by the inheriting agent class (a minimal subclass sketch is shown after the attribute list below).
2 - Classes that implement this interface can pass **kwargs to Agent.__init__(), but the keys must match its parameters.
Parameters:
env : gymnasium.Env or tuple (constructor, kwargs)

Environment on which to train the agent.

eval_env : gymnasium.Env or tuple (constructor, kwargs)

Environment on which to evaluate the agent. If None, copied from env.

copy_env : bool

If true, makes a deep copy of the environment.

compress_pickle : bool

If true, compress the save files using bz2.

seeder : Seeder, int, or None

Seeder/seed for random number generation.

output_dir : str or Path

Directory that the agent can use to store data.

_execution_metadata : ExecutionMetadata, optional

Extra information about the agent's execution (e.g. the id of the process in which the agent is running). Used by ExperimentManager.

_default_writer_kwargs : dict, optional

Parameters to initialize DefaultWriter (attribute self.writer). Used by ExperimentManager.

_thread_shared_data : dict, optional

Used by ExperimentManager to share data across Agent instances created in different threads.

Attributes:
name : string

Agent identifier (not necessarily unique).

env : gymnasium.Env or tuple (constructor, kwargs)

Environment on which to train the agent.

eval_env : gymnasium.Env or tuple (constructor, kwargs)

Environment on which to evaluate the agent. If None, copied from env.

writer : object, default: None

Writer object to log the output (e.g. tensorboard SummaryWriter).

seeder : Seeder, int, or None

Seeder/seed for random number generation.

rng : numpy.random._generator.Generator

Random number generator.

output_dir : str or Path

Directory that the agent can use to store data.

unique_id : str

Unique identifier for the agent instance.

writer_extra : str in {“reward”, “action”, “action_and_reward”}, optional

Scalar that will be recorded in the writer.

thread_shared_data : dict

Data shared by agent instances among different threads.
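
As a minimal illustration, the sketch below defines a concrete agent that overrides fit() and eval(). The random policy, the horizon hyperparameter, and the gymnasium-style reset/step calls are assumptions of the example, not part of Agent itself.

from rlberry.agents import Agent


class MyRandomAgent(Agent):
    """Toy agent acting uniformly at random (illustrative sketch)."""

    name = "MyRandomAgent"

    def __init__(self, env, horizon=100, **kwargs):
        # **kwargs are forwarded to Agent.__init__ (eval_env, seeder, output_dir, ...)
        Agent.__init__(self, env, **kwargs)
        self.horizon = horizon  # hypothetical hyperparameter

    def fit(self, budget, **kwargs):
        # Here, budget is interpreted as a number of environment steps.
        observation, info = self.env.reset()
        for _ in range(budget):
            action = self.env.action_space.sample()
            observation, reward, terminated, truncated, info = self.env.step(action)
            if terminated or truncated:
                observation, info = self.env.reset()

    def eval(self, eval_horizon=None, **kwargs):
        # Return of a single episode played on eval_env.
        eval_horizon = eval_horizon or self.horizon
        observation, info = self.eval_env.reset()
        episode_return = 0.0
        for _ in range(eval_horizon):
            action = self.eval_env.action_space.sample()
            observation, reward, terminated, truncated, info = self.eval_env.step(action)
            episode_return += reward
            if terminated or truncated:
                break
        return episode_return

The environment can also be given as a tuple (constructor, kwargs), for instance MyRandomAgent((gym.make, {"id": "CartPole-v1"})) with gymnasium imported as gym.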

Methods

eval(**kwargs)

Abstract method.

fit(budget, **kwargs)

Abstract method to be overridden by the inheriting agent class.

get_params([deep])

Get parameters for this agent.

load(filename, **kwargs)

Load agent object from filepath.

reseed([seed_seq])

Get new random number generator for the agent.

sample_parameters(trial)

Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)

save(filename)

Save agent object.

set_writer(writer)

Set self._writer.

abstract eval(**kwargs)[source]

Abstract method.

Returns a float measuring the quality of the agent (e.g. MC policy evaluation).

Parameters:
eval_env: object

Environment for evaluation.

**kwargs: Keyword Arguments

Extra parameters specific to the implemented evaluation.
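
For instance, a Monte Carlo evaluation can average the episode return over several rollouts on eval_env. In the sketch below (an eval() override for an agent like the MyRandomAgent sketch above), the policy() method and the keyword arguments are assumptions of the example.

def eval(self, n_simulations=10, eval_horizon=500, **kwargs):
    # Average episode return over n_simulations rollouts on eval_env.
    returns = []
    for _ in range(n_simulations):
        observation, info = self.eval_env.reset()
        episode_return = 0.0
        for _ in range(eval_horizon):
            action = self.policy(observation)  # hypothetical policy method
            observation, reward, terminated, truncated, info = self.eval_env.step(action)
            episode_return += reward
            if terminated or truncated:
                break
        returns.append(episode_return)
    return sum(returns) / len(returns)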

abstract fit(budget: int, **kwargs)[source]

Abstract method to be overridden by the inheriting agent class.

Train the agent with a fixed budget, using the provided environment.

Parameters:
budget: int

Computational (or sample complexity) budget. It can be, for instance:

  • The number of timesteps taken by the environment (env.step) or the number of episodes;

  • The number of iterations for algorithms such as value/policy iteration;

  • The number of searches in MCTS (Monte-Carlo Tree Search) algorithms;

among others.

Ideally, calling

fit(budget1)
fit(budget2)

should be equivalent to one call

fit(budget1 + budget2)

This property is needed to speed up hyperparameter optimization (by allowing early stopping), but it is not strictly required elsewhere in the library.

If the agent does not require a budget, set it to -1.

**kwargs: Keyword Arguments

Extra parameters specific to the implemented fit.
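
Using the MyRandomAgent sketch above for illustration, the additivity property means that the two ways of spending the budget below should produce equivalently trained agents:

# env: a gymnasium environment instantiated elsewhere
agent = MyRandomAgent(env)
agent.fit(budget=500)
agent.fit(budget=500)       # two partial calls...

agent_bis = MyRandomAgent(env)
agent_bis.fit(budget=1000)  # ...should be equivalent to one call with the total budget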

get_params(deep=True)[source]

Get parameters for this agent.

Parameters:
deep : bool, default=True

If True, will return the parameters for this agent and contained subobjects.

Returns:
params : dict

Parameter names mapped to their values.
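
Usage follows the scikit-learn convention; the horizon parameter below is the hypothetical hyperparameter of the MyRandomAgent sketch above.

agent = MyRandomAgent(env, horizon=200)
params = agent.get_params()
print(params)  # contains the constructor keyword arguments, e.g. {'horizon': 200, ...}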

classmethod load(filename, **kwargs)[source]

Load agent object from filepath.

If overridden, the save() method must also be overridden.

Parameters:
filename: str

Path to the object (pickle) to load.

**kwargs: Keyword Arguments

Arguments required by the __init__ method of the Agent subclass to load.
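
A hedged usage sketch (the file name is illustrative, and the required kwargs depend on the subclass __init__; here MyRandomAgent from the sketch above needs its env):

loaded_agent = MyRandomAgent.load("saved_agent.pickle", env=env)
value = loaded_agent.eval()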

property output_dir

Directory that the agent can use to store data.

reseed(seed_seq=None)[source]

Get new random number generator for the agent.

Parameters:
seed_seq : numpy.random.SeedSequence, rlberry.seeding.seeder.Seeder or int, default: None

Seed sequence from which to spawn the random number generator. If None, generate a random seed. If int, use as entropy for SeedSequence. If a Seeder, use seeder.seed_seq.
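
For example (assuming Seeder is importable from rlberry.seeding, as suggested by the type above):

from rlberry.seeding import Seeder

agent.reseed(123)          # int: used as entropy for a new SeedSequence
agent.reseed(Seeder(456))  # Seeder: seeder.seed_seq is used
agent.reseed()             # None: a random seed is generated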

property rng

Random number generator.

classmethod sample_parameters(trial)[source]

Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)

Note: only the kwargs sent to __init__ are optimized. Make sure to include in the Agent constructor all “optimizable” parameters.

Parameters:
trial: optuna.trial

Optuna trial object used to sample the hyperparameters.
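
A typical override uses the Optuna trial API; the hyperparameter names below are hypothetical and must match the agent's __init__ kwargs (e.g. added to a subclass like the MyRandomAgent sketch above).

@classmethod
def sample_parameters(cls, trial):
    # Only kwargs accepted by __init__ can be optimized.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    gamma = trial.suggest_categorical("gamma", [0.95, 0.99])
    return {"learning_rate": learning_rate, "gamma": gamma}
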
save(filename)[source]

Save agent object. By default, the agent is pickled.

If overridden, the load() method must also be overridden.

Before saving, consider setting writer to None if it can’t be pickled (tensorboard writers keep references to files and cannot be pickled).

Note: dill is used when pickle fails (see https://stackoverflow.com/a/25353243, for instance). Pickle is tried first, since it is faster.

Parameters:
filename: Path or str

File in which to save the Agent.

Returns:
pathlib.Path

If save() is successful, a Path object corresponding to the filename is returned. Otherwise, None is returned.

Warning

The returned filename might differ from the input filename: for instance, the method can append the correct suffix to the name before saving.
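
A hedged roundtrip sketch (the file name and the writer handling are illustrative):

agent.set_writer(None)                     # drop writers that cannot be pickled
saved_path = agent.save("my_agent.pickle")
if saved_path is not None:
    # saved_path may carry a different suffix than the requested file name
    restored = MyRandomAgent.load(saved_path, env=env)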


set_writer(writer)[source]

Set self._writer. If the writer is not None, add the agent's parameter values to the writer.
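
For instance, with a tensorboard SummaryWriter (the torch dependency in this snippet is an assumption of the example):

from torch.utils.tensorboard import SummaryWriter

agent.set_writer(SummaryWriter(log_dir="logs/my_agent"))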

property thread_shared_data

Data shared by agent instances among different threads.

property unique_id

Unique identifier for the agent instance. Can be used, for example, to create files/directories for the agent to log data safely.

property writer

Writer object to log the output (e.g. tensorboard SummaryWriter).

Examples using rlberry.agents.Agent

Record reward during training and then plot it

Illustration of plotting tools on Bandits

A demo of Experiment Manager

Checkpointing

UCB Bandit cumulative regret

EXP3 Bandit cumulative regret

Comparison of Thompson sampling and UCB on Bernoulli and Gaussian bandits

Comparison subplots of various index based bandits algorithms

A demo of Bandit BAI on a real dataset to select mirrors