rlberry.manager.read_writer_data
- rlberry.manager.read_writer_data(data_source, tag=None, preprocess_func=None, id_agent=None)
Given an ExperimentManager, a list of ExperimentManager, or a folder, read the data (corresponding to info) obtained in each episode. The dictionary returned by the agents’ .fit() method must contain a key equal to info.
- Parameters:
- data_source : ExperimentManager, or list of ExperimentManager, or str, or list of str
If ExperimentManager or list of ExperimentManager, load the data from them (the agents must be fitted).
If str, it must be the path of a directory whose subdirectories each contain pickle files. The data is loaded from the directory of the most recent experiment. This string should be equal to the value of the output_dir parameter of ExperimentManager.
If list of str, each string must be a directory containing pickle files; the data is loaded from these pickle files.
Note: the agent’s save function must save its writer at the key _writer. This is the default for rlberry agents.
- tag : str or list of str or None
Tag of data that we want to preprocess.
- preprocess_func : Callable or None or list of Callable or None
Function to apply to the ‘tag’ column before returning the data. For instance, if tag=episode_rewards, setting preprocess_func=np.cumsum returns cumulative rewards. If None, no preprocessing is done. If tag is a list, preprocess_func must be None or a list of Callable or None whose length matches that of tag.
- id_agent : int or None, default=None
If not None, returns the data only for agent ‘id_agent’.
- Returns:
- pandas DataFrame with the data from the writers.
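The tag and preprocess_func rules described in the parameters can be pictured on a toy table. This is a minimal sketch with pandas and NumPy, not rlberry's implementation: the long-format layout, the column names name, tag, global_step and value, and the helper preprocess are all assumptions made for illustration. It keeps the rows for each requested tag and applies the matching function per agent, so that tag="episode_rewards" with np.cumsum yields cumulative rewards.

```python
import numpy as np
import pandas as pd

# Toy writer data in long format: one row per logged scalar. The column
# names (name, tag, global_step, value) are assumptions for illustration,
# not a guaranteed description of what read_writer_data returns.
df = pd.DataFrame(
    {
        "name": ["A2C"] * 3 + ["DQN"] * 3,
        "tag": ["episode_rewards"] * 6,
        "global_step": [0, 1, 2, 0, 1, 2],
        "value": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
    }
)


def preprocess(data, tag, preprocess_func=None):
    """Hypothetical helper: keep rows for `tag` and transform their values.

    Follows the documented contract: `tag` is a str or a list of str, and
    when it is a list, `preprocess_func` is None or a list of the same length.
    """
    if isinstance(tag, str):
        tags, funcs = [tag], [preprocess_func]
    else:
        tags = list(tag)
        funcs = [None] * len(tags) if preprocess_func is None else list(preprocess_func)
    if len(funcs) != len(tags):
        raise ValueError("preprocess_func must match the length of tag")

    parts = []
    for t, func in zip(tags, funcs):
        # Cumulative functions such as np.cumsum assume rows in step order.
        part = data[data["tag"] == t].sort_values(["name", "global_step"]).copy()
        if func is not None:
            # Transform each agent separately so np.cumsum does not
            # accumulate rewards across different agents.
            part["value"] = part.groupby("name")["value"].transform(
                lambda s: func(s.to_numpy())
            )
        parts.append(part)
    return pd.concat(parts, ignore_index=True)


out = preprocess(df, "episode_rewards", np.cumsum)
print(out["value"].tolist())  # [1.0, 3.0, 6.0, 10.0, 30.0, 60.0]
```

Grouping by agent before transforming is a design choice of this sketch: it keeps one agent's cumulative sum from leaking into the next agent's rows.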
Examples
>>> from rlberry.agents.torch import A2CAgent, DQNAgent
>>> from rlberry.manager import ExperimentManager, read_writer_data
>>> from rlberry.envs import gym_make
>>>
>>> if __name__ == "__main__":
...     managers = [
...         ExperimentManager(
...             agent_class,
...             (gym_make, dict(id="CartPole-v1")),
...             fit_budget=1e4,
...             eval_kwargs=dict(eval_horizon=500),
...             n_fit=1,
...             parallelization="process",
...             mp_context="spawn",
...             seed=42,
...         )
...         for agent_class in [A2CAgent, DQNAgent]
...     ]
...     for manager in managers:
...         manager.fit()
...     data = read_writer_data(managers)
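When data_source is a str, the data is read from the directory of the latest experiment under output_dir. A minimal standard-library sketch of that selection step is given here; the helper latest_experiment_dir and the use of modification time as the "date" criterion are assumptions for illustration, not rlberry's code, which may decide differently (for example, from a timestamp in the directory name).

```python
import os
import tempfile
from pathlib import Path


def latest_experiment_dir(output_dir):
    """Hypothetical helper: pick the newest subdirectory of output_dir.

    Assumes "latest in date" means the most recent modification time;
    rlberry may use another rule.
    """
    subdirs = [p for p in Path(output_dir).iterdir() if p.is_dir()]
    if not subdirs:
        raise FileNotFoundError(f"no experiment directories in {output_dir}")
    return max(subdirs, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as root:
    for name, mtime in [("run_old", 1_000_000), ("run_new", 2_000_000)]:
        path = Path(root, name)
        path.mkdir()
        # Set explicit (atime, mtime) values so the result is deterministic.
        os.utime(path, (mtime, mtime))
    print(latest_experiment_dir(root).name)  # run_new
```

Plain files in output_dir are ignored, since only subdirectories are treated as candidate experiments.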