Source code for compiler_gym.views.observation_space_spec

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from typing import Callable, Optional, Union

import networkx as nx
import numpy as np
from gym.spaces import Box, Space

from compiler_gym.service.proto import Observation, ObservationSpace, ScalarRange
from compiler_gym.spaces.scalar import Scalar
from compiler_gym.spaces.sequence import Sequence
from compiler_gym.util.gym_type_hints import ObservationType


def _json2nx(observation):
    json_data = json.loads(observation.string_value)
    return nx.readwrite.json_graph.node_link_graph(
        json_data, multigraph=True, directed=True
    )


def _scalar_range2tuple(sr: ScalarRange, defaults=(-np.inf, np.inf)):
    """Convert a ScalarRange to a tuple of (min, max) bounds."""
    return (
        sr.min.value if sr.HasField("min") else defaults[0],
        sr.max.value if sr.HasField("max") else defaults[1],
    )


[docs]class ObservationSpaceSpec: """Specification of an observation space. :ivar id: The name of the observation space. :vartype id: str :ivar index: The index into the list of observation spaces that the service supports. :vartype index: int :ivar space: The space. :vartype space: Space :ivar deterministic: Whether the observation space is deterministic. :vartype deterministic: bool :ivar platform_dependent: Whether the observation values depend on the execution environment of the service. :vartype platform_dependent: bool :ivar default_value: A default observation. This value will be returned by :func:`CompilerEnv.step() <compiler_gym.envs.CompilerEnv.step>` if :func:`CompilerEnv.observation_space <compiler_gym.envs.CompilerEnv.observation_space>` is set and the service terminates. """ def __init__( self, id: str, index: int, space: Space, translate: Callable[[Union[ObservationType, Observation]], ObservationType], to_string: Callable[[ObservationType], str], deterministic: bool, platform_dependent: bool, default_value: ObservationType, ): """Constructor. Don't call directly, use make_derived_space().""" self.id: str = id self.index: int = index self.space = space self.deterministic = deterministic self.platform_dependent = platform_dependent self.default_value = default_value self.translate = translate self.to_string = to_string def __hash__(self) -> int: # Quickly hash observation spaces by comparing the index into the list # of spaces returned by the environment. This means that you should not # hash between observation spaces from different environments as this # will cause collisions, e.g. # # # not okay: # >>> obs = set(env.observation.spaces).union( # other_env.observation.spaces # ) # # If you want to hash between environments, consider using the string id # to identify the observation spaces. return self.index def __repr__(self) -> str: return f"ObservationSpaceSpec({self.id})" def __eq__(self, rhs) -> bool: """Equality check.""" if not isinstance(rhs, ObservationSpaceSpec): return False return ( self.id == rhs.id and self.index == rhs.index and self.space == rhs.space and self.platform_dependent == rhs.platform_dependent and self.deterministic == rhs.deterministic ) @classmethod def from_proto(cls, index: int, proto: ObservationSpace): """Construct a space from an ObservationSpace message.""" shape_type = proto.WhichOneof("shape") def make_box(scalar_range_list, dtype, defaults): bounds = [_scalar_range2tuple(r, defaults) for r in scalar_range_list] return Box( low=np.array([b[0] for b in bounds], dtype=dtype), high=np.array([b[1] for b in bounds], dtype=dtype), dtype=dtype, ) def make_scalar(scalar_range, dtype, defaults): scalar_range_tuple = _scalar_range2tuple(scalar_range, defaults) return Scalar( min=dtype(scalar_range_tuple[0]), max=dtype(scalar_range_tuple[1]), dtype=dtype, ) def make_seq(size_range, dtype, defaults, scalar_range=None): return Sequence( size_range=_scalar_range2tuple(size_range, defaults), dtype=dtype, opaque_data_format=proto.opaque_data_format, scalar_range=scalar_range, ) # Translate from protocol buffer specification to python. There are # three variables to derive: # (1) space: the gym.Space instance describing the space. # (2) translate: is a callback that translates from an Observation # message to a python type. # (3) to_string: is a callback that translates from a python type to a # string for printing. if proto.opaque_data_format == "json://networkx/MultiDiGraph": # TODO(cummins): Add a Graph space. space = make_seq(proto.string_size_range, str, (0, None)) def translate(observation): return nx.readwrite.json_graph.node_link_graph( json.loads(observation.string_value), multigraph=True, directed=True ) def to_string(observation): return json.dumps( nx.readwrite.json_graph.node_link_data(observation), indent=2 ) elif proto.opaque_data_format == "json://": space = make_seq(proto.string_size_range, str, (0, None)) def translate(observation): return json.loads(observation.string_value) def to_string(observation): return json.dumps(observation, indent=2) elif shape_type == "int64_range_list": space = make_box( proto.int64_range_list.range, np.int64, (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ) def translate(observation): return np.array(observation.int64_list.value, dtype=np.int64) to_string = str elif shape_type == "double_range_list": space = make_box( proto.double_range_list.range, np.float64, (-np.inf, np.inf) ) def translate(observation): return np.array(observation.double_list.value, dtype=np.float64) to_string = str elif shape_type == "string_size_range": space = make_seq(proto.string_size_range, str, (0, None)) def translate(observation): return observation.string_value to_string = str elif shape_type == "binary_size_range": space = make_seq(proto.binary_size_range, bytes, (0, None)) def translate(observation): return observation.binary_value to_string = str elif shape_type == "scalar_int64_range": space = make_scalar( proto.scalar_int64_range, int, (np.iinfo(np.int64).min, np.iinfo(np.int64).max), ) def translate(observation): return int(observation.scalar_int64) to_string = str elif shape_type == "scalar_double_range": space = make_scalar(proto.scalar_double_range, float, (-np.inf, np.inf)) def translate(observation): return float(observation.scalar_double) to_string = str elif shape_type == "double_sequence": space = make_seq( proto.double_sequence.length_range, np.float64, (-np.inf, np.inf), make_scalar( proto.double_sequence.scalar_range, np.float64, (-np.inf, np.inf) ), ) def translate(observation): return np.array(observation.double_list.value, dtype=np.float64) to_string = str else: raise TypeError( f"Unknown shape '{shape_type}' for ObservationSpace:\n{proto}" ) return cls( id=proto.name, index=index, space=space, translate=translate, to_string=to_string, deterministic=proto.deterministic, platform_dependent=proto.platform_dependent, default_value=translate(proto.default_value), )
[docs] def make_derived_space( self, id: str, translate: Callable[[ObservationType], ObservationType], space: Optional[Space] = None, deterministic: Optional[bool] = None, default_value: Optional[ObservationType] = None, platform_dependent: Optional[bool] = None, to_string: Callable[[ObservationType], str] = None, ) -> "ObservationSpaceSpec": """Create a derived observation space. :param id: The name of the derived observation space. :param translate: A callback function to compute a derived observation from the base observation. :param space: The :code:`gym.Space` describing the observation space. :param deterministic: Whether the observation space is deterministic. If not provided, the value is inherited from the base observation space. :param default_value: The default value for the observation space. If not provided, the value is derived from the default value of the base observation space. :param platform_dependent: Whether the derived observation space is platform-dependent. If not provided, the value is inherited from the base observation space. :param to_string: A callback to convert and observation to a string representation. If not provided, the callback is inherited from the base observation space. :return: A new ObservationSpaceSpec. """ return ObservationSpaceSpec( id=id, index=self.index, space=space or self.space, translate=lambda observation: translate(self.translate(observation)), to_string=to_string or self.to_string, default_value=( translate(self.default_value) if default_value is None else default_value ), deterministic=( self.deterministic if deterministic is None else deterministic ), platform_dependent=( self.platform_dependent if platform_dependent is None else platform_dependent ), )