rlberry.manager.read_writer_data

rlberry.manager.read_writer_data(data_source, tag=None, preprocess_func=None, id_agent=None)

Given an ExperimentManager, a list of ExperimentManager, or a folder, read the data (corresponding to tag) obtained in each episode. The agents must have logged this tag with their writer.

Parameters:
data_source: ExperimentManager, list of ExperimentManager, str, or list of str
  • If ExperimentManager or list of ExperimentManager, load the data from the manager(s); the agents must already be fitted.

  • If str, the string must be the path of a directory whose subdirectories contain pickle files. The data are loaded from the directory of the most recent experiment. This string should be equal to the value of the output_dir parameter of ExperimentManager.

  • If list of str, each string must be the path of a directory containing pickle files; the data are loaded from these pickle files.

Note: the agent’s save function must save its writer at the key _writer. This is the default for rlberry agents.
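To make the str case more concrete, here is a minimal sketch of selecting "the directory of the most recent experiment" and collecting its pickle files. It assumes that "most recent" can be approximated by the subdirectory with the latest modification time and that the saved files end in .pickle; latest_experiment_dir and list_pickle_files are hypothetical helpers, not part of rlberry, and rlberry's actual directory layout may differ.

```python
from pathlib import Path


def latest_experiment_dir(output_dir: str) -> Path:
    """Return the most recently modified subdirectory of output_dir (hypothetical helper)."""
    subdirs = [p for p in Path(output_dir).iterdir() if p.is_dir()]
    if not subdirs:
        raise FileNotFoundError(f"no experiment subdirectory in {output_dir!r}")
    # Assumption: "most recent experiment" is approximated by modification time.
    return max(subdirs, key=lambda p: p.stat().st_mtime)


def list_pickle_files(directory: Path) -> list[Path]:
    """Recursively collect pickle files below directory (assumed *.pickle suffix)."""
    return sorted(directory.rglob("*.pickle"))
```

Using modification time keeps the sketch independent of any naming scheme; a real implementation could instead parse a timestamp embedded in the directory names.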

tag: str or list of str or None

Tag, or list of tags, of the data to read and preprocess.

preprocess_func: Callable, None, or list of (Callable or None)

Function to apply to the ‘tag’ column before returning the data. For instance, if tag="episode_rewards", setting preprocess_func=np.cumsum returns cumulative rewards. If None, no preprocessing is applied. If tag is a list, preprocess_func must be None or a list of Callable or None whose length matches that of tag.
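The pairing rule between tag and preprocess_func can be illustrated with a small stand-alone sketch. It assumes the writer data is a long-format list of (tag, value) records and applies each function to the values of its own tag. apply_preprocess is a hypothetical helper, not rlberry's implementation, and a cumsum built on itertools.accumulate stands in for np.cumsum so that only the standard library is needed.

```python
from itertools import accumulate
from typing import Callable, Optional, Sequence, Union

# A preprocessing function maps a list of values to a new list, or is None.
Func = Optional[Callable[[list], list]]


def apply_preprocess(
    records: Sequence[tuple[str, float]],
    tag: Union[str, Sequence[str]],
    preprocess_func: Union[Func, Sequence[Func]] = None,
) -> dict[str, list]:
    """Group values by tag and apply the matching function (hypothetical helper)."""
    tags = [tag] if isinstance(tag, str) else list(tag)
    if preprocess_func is None:
        funcs = [None] * len(tags)
    elif callable(preprocess_func):
        # A single callable is only accepted for a single tag.
        if len(tags) != 1:
            raise ValueError("with a list of tags, preprocess_func must be None or a list")
        funcs = [preprocess_func]
    else:
        funcs = list(preprocess_func)
        if len(funcs) != len(tags):
            raise ValueError("preprocess_func must have the same length as tag")
    out = {}
    for t, f in zip(tags, funcs):
        values = [v for rec_tag, v in records if rec_tag == t]
        out[t] = f(values) if f is not None else values
    return out


def cumsum(values: list) -> list:
    """Cumulative sum, playing the role of np.cumsum in this sketch."""
    return list(accumulate(values))
```

For example, with records = [("episode_rewards", 1.0), ("loss", 0.5), ("episode_rewards", 2.0)], calling apply_preprocess(records, ["episode_rewards", "loss"], [cumsum, None]) returns cumulative rewards for the first tag and leaves the second untouched.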

id_agent: int or None, default=None

If not None, returns the data only for agent ‘id_agent’.

Returns:
Pandas DataFrame with data from writers.

Examples

>>> from rlberry.agents.torch import A2CAgent, DQNAgent
>>> from rlberry.manager import ExperimentManager, read_writer_data
>>> from rlberry.envs import gym_make
>>>
>>> if __name__ == "__main__":
...     # One manager per agent class, each trained once on CartPole
...     managers = [
...         ExperimentManager(
...             agent_class,
...             (gym_make, dict(id="CartPole-v1")),
...             fit_budget=1e4,
...             eval_kwargs=dict(eval_horizon=500),
...             n_fit=1,
...             parallelization="process",
...             mp_context="spawn",
...             seed=42,
...         )
...         for agent_class in [A2CAgent, DQNAgent]
...     ]
...     for manager in managers:
...         manager.fit()  # agents must be fitted before their writers are read
...     # Collect the writer data of all fitted agents into one DataFrame
...     data = read_writer_data(managers)