Source code for compiler_gym.spaces.named_discrete

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numbers
from typing import Iterable, List, Optional, Union

from gym.spaces import Discrete


class NamedDiscrete(Discrete):
    """An extension of the :code:`Discrete` space in which each point in the
    space has a name. Additionally, the space itself may have a name.

    :ivar name: The name of the space. :code:`None` if the space has no name.

    :vartype name: Optional[str]

    :ivar names: A list of names for each element in the space. The index of a
        name in this list is its numeric value.

    :vartype names: List[str]

    Example usage:

    >>> space = NamedDiscrete(["a", "b", "c"])
    >>> space.n
    3
    >>> space["a"]
    0
    >>> space.names[0]
    'a'
    >>> space.sample()  # doctest: +SKIP
    1
    """

    def __init__(self, items: Iterable[str], name: Optional[str] = None):
        """Constructor.

        :param items: A list of names for items in the space. Each item is
            converted to :code:`str`. The iterable is consumed exactly once,
            so a generator may be passed.

        :param name: The name of the space.
        """
        self.name: Optional[str] = name
        # Materialize the names before calling the base constructor so that the
        # size of the space (n) is exactly the number of names.
        self.names: List[str] = [str(x) for x in items]
        super().__init__(n=len(self.names))

    def __getitem__(self, name: str) -> int:
        """Lookup the numeric value of a point in the space.

        If a name occurs more than once in :code:`names`, the index of its
        first occurrence is returned.

        :param name: A name.

        :return: The numeric value.

        :raises ValueError: If the name is not in the space.
        """
        return self.names.index(name)

    def __repr__(self) -> str:
        return f"NamedDiscrete([{', '.join(self.names)}])"

    def to_string(self, values: Union[int, Iterable[int]]) -> str:
        """Convert an action, or sequence of actions, to string.

        :param values: A numeric value, or list of numeric values. Any integral
            scalar is accepted, including numpy integer types.

        :return: The name of a single value, or the names of a sequence of
            values joined by single spaces.

        :raises IndexError: If a value is out of range. Values are used
            directly as list indices, so negative values index from the end
            of :code:`names`.
        """
        # numbers.Integral, not int, so that numpy integer scalars (e.g.
        # np.int64, which some samplers return) are treated as a single value
        # rather than being iterated, which would raise TypeError.
        if isinstance(values, numbers.Integral):
            return self.names[values]
        return " ".join(self.names[v] for v in values)

    def from_string(self, values: Union[str, Iterable[str]]) -> Union[int, List[int]]:
        """Convert a name, or list of names, to numeric values.

        Note that a single string is always treated as one name. It is not
        split on whitespace.

        :param values: A name, or list of names.

        :return: A numeric value, or list of numeric values.

        :raises ValueError: If any name is not in the space.
        """
        if isinstance(values, str):
            return self.names.index(values)
        return [self.names.index(v) for v in values]