Source code for rlberry.envs.interface.model
import gymnasium as gym
import numpy as np
import inspect
from rlberry.seeding import Seeder
import rlberry
# Module-level logger taken from the rlberry package; used below to emit
# warnings from Model's capability checks.
logger = rlberry.logger
class Model(gym.Env):
    """
    Base class for an environment model.

    Attributes
    ----------
    name : string
        environment identifier
    observation_space : rlberry.spaces.Space
        observation space
    action_space : rlberry.spaces.Space
        action space
    reward_range : tuple
        tuple (r_min, r_max) containing the minimum and the maximum reward
    seeder : rlberry.seeding.Seeder
        Seeder, containing random number generator.
    """

    name = ""

    def __init__(self):
        # Spaces are left unset here; subclasses are expected to assign them.
        self.observation_space = None
        self.action_space = None
        self.reward_range: tuple = (-np.inf, np.inf)
        # random number generator
        self.seeder = Seeder()

    def reseed(self, seed_seq=None):
        """
        Get new random number generator for the model.

        The observation and action spaces (when set) are reseeded from the
        new seeder as well, so that sampling from them is reproducible.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        if seed_seq is None:
            # Derive a new seeder from the current one.
            self.seeder = self.seeder.spawn()
        else:
            self.seeder = Seeder(seed_seq)
        # Propagate the new seed to the spaces. They are None until a
        # subclass assigns them (see __init__), so skip unset spaces instead
        # of failing with AttributeError on None.
        for space in (self.observation_space, self.action_space):
            if space is not None:
                space.reseed(self.seeder.seed_seq)

    def sample(self, state, action):
        """
        Execute a step from a state-action pair.

        Parameters
        ----------
        state : object
            state from which to sample
        action : object
            action to take in the environment

        Returns
        -------
        observation : object
        reward : float
        done : bool
        info : dict

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses override this method.
        """
        raise NotImplementedError("sample() method not implemented.")

    def is_online(self):
        """
        Returns true if reset() and step() methods are implemented.

        The check actually calls ``reset()`` and then one ``step()`` with an
        action drawn from ``action_space.sample()``, so it has the side
        effects of those calls.

        Returns
        -------
        bool
            False if any of those calls raises NotImplementedError,
            True otherwise.

        Raises
        ------
        Exception
            Any exception other than NotImplementedError is propagated.
        """
        # Implicit literal concatenation keeps the message free of the
        # source indentation that a backslash continuation would embed.
        logger.warning(
            "Checking if Model is online calls reset() and step() methods."
        )
        try:
            self.reset()
            self.step(self.action_space.sample())
        except NotImplementedError:
            return False
        return True

    def is_generative(self):
        """
        Returns true if sample() method is implemented.

        The check actually calls ``sample()`` on a state drawn from
        ``observation_space.sample()`` and an action drawn from
        ``action_space.sample()``.

        Returns
        -------
        bool
            False if any of those calls raises NotImplementedError,
            True otherwise.

        Raises
        ------
        Exception
            Any exception other than NotImplementedError is propagated.
        """
        logger.warning("Checking if Model is generative calls sample() method.")
        try:
            self.sample(self.observation_space.sample(), self.action_space.sample())
        except NotImplementedError:
            return False
        return True

    @classmethod
    def _get_param_names(cls):
        """
        Get parameter names for the Model.

        Returns
        -------
        list of str
            Sorted names of the constructor's parameters, excluding ``self``
            and any ``**kwargs`` parameter. Empty if the class has no
            explicit constructor.
        """
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []
        # Introspect the constructor arguments to find the model parameters
        # to represent; 'self' and **kwargs are not model parameters.
        init_signature = inspect.signature(init)
        return sorted(
            p.name
            for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        )

    def get_params(self, deep=True):
        """
        Get parameters for this model.

        Each parameter value is read from the attribute of the same name,
        so constructor arguments must be stored under their own names.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this model and
            contained subobjects.

        Returns
        -------
        params : dict
            Parameter names mapped to their values. Nested parameters of
            sub-objects exposing ``get_params`` are keyed as
            ``"<param>__<subparam>"``.
        """
        out = dict()
        for key in self._get_param_names():
            value = getattr(self, key)
            if deep and hasattr(value, "get_params"):
                deep_items = value.get_params().items()
                out.update((key + "__" + k, val) for k, val in deep_items)
            out[key] = value
        return out

    @property
    def unwrapped(self):
        """This model itself (a Model is never wrapped)."""
        return self

    @property
    def rng(self):
        """Random number generator."""
        return self.seeder.rng