Source code for compiler_gym.wrappers.core

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, Optional, Union

import gym

from compiler_gym.envs import CompilerEnv
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ObservationType, StepType
from compiler_gym.views import ObservationSpaceSpec


class CompilerEnvWrapper(gym.Wrapper):
    """Wraps a :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` environment
    to allow a modular transformation.

    This class is the base class for all wrappers. This class must be used
    rather than :code:`gym.Wrapper` to support the CompilerGym API extensions
    such as the :code:`fork()` method.
    """

    def __init__(self, env: CompilerEnv):
        """Constructor.

        :param env: The environment to wrap. It is stored as-is; this
            constructor performs no type check on it.
        """
        # No call to gym.Wrapper superclass constructor here because we need to
        # avoid setting the observation_space member variable, which in the
        # CompilerEnv class is a property with a custom setter. Instead, the
        # observation_space and observation_space_spec properties below
        # delegate to the wrapped environment.
        self.env = env

        self.action_space = self.env.action_space
        self.reward_range = self.env.reward_range
        self.metadata = self.env.metadata

    def step(self, action, observations=None, rewards=None):
        """Forward a step to the wrapped environment.

        :param action: The action, passed through unchanged.
        :param observations: Passed through unchanged as the
            :code:`observations` keyword argument.
        :param rewards: Passed through unchanged as the :code:`rewards`
            keyword argument.
        :return: Whatever the wrapped environment's :code:`step()` returns.
        """
        return self.env.step(action, observations=observations, rewards=rewards)

    def reset(self, *args, **kwargs) -> ObservationType:
        """Forward a reset to the wrapped environment.

        :return: Whatever the wrapped environment's :code:`reset()` returns.
        """
        return self.env.reset(*args, **kwargs)

    def fork(self) -> CompilerEnv:
        """Fork the wrapped environment and wrap the fork in a new instance of
        this wrapper's own type.

        The new wrapper is constructed with :code:`env` as its only argument,
        so a subclass whose constructor requires further arguments must
        override this method.

        :return: A new wrapper of the same type around the forked environment.
        """
        forked_env = self.env.fork()
        try:
            return type(self)(env=forked_env)
        except Exception:
            # The fork already exists at this point. If it cannot be wrapped,
            # close it before propagating the error so that it is not leaked.
            forked_env.close()
            raise

    @property
    def observation_space(self):
        """The :code:`space` of the wrapped environment's observation space
        spec, or :code:`None` if the wrapped environment has no spec set.
        """
        spec = self.env.observation_space_spec
        if spec:
            return spec.space
        return None

    @observation_space.setter
    def observation_space(
        self, observation_space: Optional[Union[str, ObservationSpaceSpec]]
    ) -> None:
        self.env.observation_space = observation_space

    @property
    def observation_space_spec(self):
        """The wrapped environment's :code:`observation_space_spec`."""
        return self.env.observation_space_spec

    @observation_space_spec.setter
    def observation_space_spec(
        self, observation_space_spec: Optional[ObservationSpaceSpec]
    ) -> None:
        self.env.observation_space_spec = observation_space_spec

    @property
    def reward_space(self) -> Optional[Reward]:
        """The wrapped environment's :code:`reward_space`."""
        return self.env.reward_space

    @reward_space.setter
    def reward_space(self, reward_space: Optional[Union[str, Reward]]) -> None:
        self.env.reward_space = reward_space
class ActionWrapper(CompilerEnvWrapper):
    """Wraps a :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` environment
    to allow an action space transformation.

    Subclasses implement :meth:`action`, which :meth:`step` applies to every
    incoming action before handing it to the wrapped environment.
    """

    def step(
        self, action: Union[int, Iterable[int]], observations=None, rewards=None
    ) -> StepType:
        """Translate the action with :meth:`action`, then step the wrapped
        environment with the translated value.

        :param action: The action as supplied by the caller of this wrapper.
        :param observations: Passed through unchanged to the wrapped
            environment.
        :param rewards: Passed through unchanged to the wrapped environment.
        :return: Whatever the wrapped environment's :code:`step()` returns.
        """
        wrapped_step = self.env.step
        return wrapped_step(
            self.action(action), observations=observations, rewards=rewards
        )

    def action(self, action):
        """Translate an action given to this wrapper into the action that is
        passed to the wrapped environment's :code:`step()`.

        Must be overridden by subclasses.

        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def reverse_action(self, action):
        """Translate an action in the opposite direction to :meth:`action`.

        Nothing in this class calls this method. NOTE(review): presumably it
        maps a wrapped-environment action back to this wrapper's action
        space -- confirm against callers.

        Must be overridden by subclasses that need it.

        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class ObservationWrapper(CompilerEnvWrapper):
    """Wraps a :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` environment
    to allow an observation space transformation.

    Subclasses implement :meth:`observation`, which is applied to the
    observation produced by both :meth:`reset` and :meth:`step`.
    """

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and return its observation after
        translation by :meth:`observation`.

        All arguments are forwarded unchanged to the wrapped environment.
        """
        raw_observation = self.env.reset(*args, **kwargs)
        return self.observation(raw_observation)

    def step(self, *args, **kwargs):
        """Step the wrapped environment and translate the observation.

        All arguments are forwarded unchanged. The wrapped environment's
        result is unpacked as an (observation, reward, done, info) 4-tuple;
        only the observation is passed through :meth:`observation`.
        """
        raw_observation, reward, done, info = self.env.step(*args, **kwargs)
        translated = self.observation(raw_observation)
        return translated, reward, done, info

    def observation(self, observation):
        """Translate an observation from the wrapped environment into the
        value returned to the caller of this wrapper.

        Must be overridden by subclasses.

        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class RewardWrapper(CompilerEnvWrapper):
    """Wraps a :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` environment
    to allow a reward space transformation.

    Subclasses implement :meth:`reward`, which :meth:`step` applies to the
    reward returned by the wrapped environment.
    """

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment, forwarding all arguments, and return
        its result unchanged.
        """
        return self.env.reset(*args, **kwargs)

    def step(self, *args, **kwargs):
        """Step the wrapped environment and translate the reward.

        All arguments are forwarded unchanged. The reward is passed through
        :meth:`reward` only when both it and :code:`self.episode_reward` are
        not :code:`None`; otherwise the step result is returned untouched.
        When the reward is translated, :code:`episode_reward` on
        :code:`self.unwrapped` is adjusted so that it accumulates the
        translated reward instead of the raw one.
        """
        observation, reward, done, info = self.env.step(*args, **kwargs)

        # Nothing to translate, or no running total to keep consistent.
        # episode_reward is not defined in this class; presumably it resolves
        # to the wrapped environment's attribute.
        if reward is None or self.episode_reward is None:
            return observation, reward, done, info

        # Undo the episode_reward update and reapply it once we have
        # transformed the reward.
        #
        # TODO(cummins): Refactor step() so that we don't have to do this
        # recalculation of episode_reward, as this is prone to errors if, say,
        # the base reward returns NaN or an invalid type.
        self.unwrapped.episode_reward -= reward
        reward = self.reward(reward)
        self.unwrapped.episode_reward += reward

        return observation, reward, done, info

    def reward(self, reward):
        """Translate a reward from the wrapped environment into the value
        returned to the caller of this wrapper.

        Must be overridden by subclasses.

        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError