# Source code for compiler_gym.wrappers.commandline

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Iterable as IterableType
from typing import Dict, Iterable, List, Optional, Union

from compiler_gym.envs import CompilerEnv
from compiler_gym.spaces import Commandline, CommandlineFlag
from compiler_gym.util.gym_type_hints import StepType
from compiler_gym.wrappers.core import ActionWrapper, CompilerEnvWrapper


class CommandlineWithTerminalAction(CompilerEnvWrapper):
    """Creates a new action space with a special "end of episode" terminal
    action at the start. If step() is called with it, the "done" flag is set.

    Because the terminal action occupies index 0 of the new action space,
    every action of the wrapped environment is shifted up by one.
    """

    def __init__(
        self,
        env: CompilerEnv,
        terminal=CommandlineFlag(
            name="end-of-episode",
            flag="# end-of-episode",
            description="End the episode",
        ),
    ):
        """Constructor.

        :param env: The environment to wrap.
        :param terminal: The flag to use as the terminal action. Optional.
        :raises TypeError: If the action space of :code:`env` is not a
            :code:`Commandline` space.
        """
        super().__init__(env)

        if not isinstance(env.action_space, Commandline):
            raise TypeError(
                f"Unsupported action space: {type(env.action_space).__name__}"
            )

        # Redefine the action space, inserting the terminal action at the start.
        self.action_space = Commandline(
            items=[terminal]
            + [
                CommandlineFlag(
                    name=name,
                    flag=flag,
                    description=description,
                )
                for name, flag, description in zip(
                    env.action_space.names,
                    env.action_space.flags,
                    env.action_space.descriptions,
                )
            ],
            name=f"{type(self).__name__}<{env.action_space.name}>",
        )

    def step(self, action: Union[int, Iterable[int]]) -> StepType:
        """Take a step, translating actions back to the wrapped action space.

        :param action: Either a single action index, or an iterable of action
            indices, expressed in this wrapper's action space (where 0 is the
            terminal action). For an iterable, everything from the first
            terminal action onwards is discarded.
        :return: The :code:`(observation, reward, done, info)` tuple produced
            by the wrapped environment. If a terminal action was given and the
            wrapped environment did not itself report completion, :code:`done`
            is forced to :code:`True` and :code:`info["terminal_action"]` is
            set to :code:`True`.
        """
        # Dispatch on "is it iterable" rather than "is it an int" so that
        # integer-like scalars which are not `int` subclasses (e.g. NumPy
        # integers) take the scalar path, and so that any iterable (tuple,
        # generator, array) takes the sequence path.
        if isinstance(action, IterableType):
            # Materialize once: we need both .index() and slicing.
            requested = list(action)
            try:
                cutoff = requested.index(0)
                end_of_episode = True
            except ValueError:
                cutoff = len(requested)
                end_of_episode = False
            # Shift down by one to undo the inserted terminal action.
            action = [a - 1 for a in requested[:cutoff]]
        else:
            end_of_episode = action == 0
            # A lone terminal action is forwarded as an empty action list.
            action = [] if end_of_episode else action - 1

        observation, reward, done, info = self.env.step(action)
        if end_of_episode and not done:
            done = True
            info["terminal_action"] = True
        return observation, reward, done, info
class ConstrainedCommandline(ActionWrapper):
    """Constrains a Commandline action space to a subset of the original
    space's flags.
    """

    def __init__(
        self, env: CompilerEnv, flags: Iterable[str], name: Optional[str] = None
    ):
        """Constructor.

        :param env: The environment to wrap.
        :param flags: A list of entries from :code:`env.action_space.flags`
            denoting flags that are available in this wrapped environment.
        :param name: The name of the new action space.
        :raises TypeError: If :code:`flags` is empty, or if the action space
            of :code:`env` is not a :code:`Commandline` space.
        """
        super().__init__(env)

        # `flags` is truth-tested, iterated more than once, measured with
        # len(), and reused by fork(). Materialize it so that one-shot
        # iterables (e.g. generators) behave the same as lists.
        flags = [] if flags is None else list(flags)
        self._flags = flags
        # Keep the caller-supplied name so that fork() can rebuild an
        # identically named action space.
        self._name = name

        if not flags:
            raise TypeError("No flags provided")

        if not isinstance(env.action_space, Commandline):
            raise TypeError(
                "Can only wrap Commandline action space. "
                f"Received: {type(env.action_space).__name__}"
            )

        # Index i of the constrained space -> index in the wrapped space.
        # NOTE: this must be computed before self.action_space is reassigned
        # below, since it looks flags up in the current action space.
        self._forward_translation: List[int] = [self.action_space[f] for f in flags]
        # Index in the wrapped space -> index in the constrained space.
        self._reverse_translation: Dict[int, int] = {
            v: i for i, v in enumerate(self._forward_translation)
        }

        # Redefine the action space using this smaller set of flags.
        self.action_space = Commandline(
            items=[
                CommandlineFlag(
                    name=env.action_space.names[a],
                    flag=env.action_space.flags[a],
                    description=env.action_space.descriptions[a],
                )
                for a in (env.action_space.flags.index(f) for f in flags)
            ],
            name=f"{type(self).__name__}<{name or env.action_space.name}, {len(flags)}>",
        )

    def action(self, action: Union[int, List[int]]):
        """Translate constrained-space action(s) into the wrapped space.

        :param action: A single action index, or an iterable of indices.
        :return: The translated index, or a list of translated indices.
        """
        if isinstance(action, IterableType):
            return [self._forward_translation[a] for a in action]
        return self._forward_translation[action]

    def reverse_action(self, action: Union[int, List[int]]):
        """Translate wrapped-space action(s) into the constrained space.

        :param action: A single action index, or an iterable of indices.
        :return: The translated index, or a list of translated indices.
        :raises KeyError: If an action is not part of the constrained space.
        """
        if isinstance(action, IterableType):
            return [self._reverse_translation[a] for a in action]
        return self._reverse_translation[action]

    @property
    def actions(self) -> List[int]:
        """Reverse-translate actions back into the constrained space."""
        return self.reverse_action(self.env.actions)

    def fork(self) -> "ConstrainedCommandline":
        """Fork the wrapped environment and wrap it with the same constraints.

        :return: A new :code:`ConstrainedCommandline` around a fork of the
            wrapped environment, using the same flags and name.
        """
        # Pass the original `name` argument, not self.action_space.name: the
        # constructor decorates the name with the wrapper type and flag count,
        # so forwarding the decorated name would nest it a second time.
        return ConstrainedCommandline(
            env=self.env.fork(), flags=self._flags, name=self._name
        )