Source code for real_robot.utils.registration

from __future__ import annotations

from copy import deepcopy
from functools import partial
from typing import Type

import gymnasium as gym
from gymnasium.envs.registration import EnvSpec as GymEnvSpec

from real_robot.envs.base_env import XArmBaseEnv
from real_robot.utils.logger import get_logger
from real_robot.utils.wrappers.observation import (
    PointCloudObservationWrapper,
    RGBDObservationWrapper,
    RobotSegmentationObservationWrapper,
)


[docs] class EnvSpec: def __init__( self, uid: str, cls: Type[XArmBaseEnv], max_episode_steps=None, default_kwargs: dict = None, ): """A specification for a real_robot environment.""" self.uid = uid self.cls = cls self.max_episode_steps = max_episode_steps self.default_kwargs = {} if default_kwargs is None else default_kwargs
[docs] def make(self, **kwargs): _kwargs = self.default_kwargs.copy() _kwargs.update(kwargs) return self.cls(**_kwargs)
@property def gym_spec(self): """Return a gym EnvSpec for this env""" entry_point = self.cls.__module__ + ":" + self.cls.__name__ return GymEnvSpec( self.uid, entry_point, max_episode_steps=self.max_episode_steps, kwargs=self.default_kwargs, )
REGISTERED_ENVS: dict[str, EnvSpec] = {}
[docs] def register( name: str, cls: Type[XArmBaseEnv], max_episode_steps: int | None = None, default_kwargs: dict | None = None, ): """Register a real_robot environment.""" if name in REGISTERED_ENVS: get_logger("registration").warning(f"Env {name} already registered") if not issubclass(cls, XArmBaseEnv): raise TypeError(f"Env {name} must inherit from XArmBaseEnv") REGISTERED_ENVS[name] = EnvSpec( name, cls, max_episode_steps=max_episode_steps, default_kwargs=default_kwargs )
[docs] def make(env_id, as_gym=True, enable_segmentation=False, **kwargs): """Instantiate a real_robot environment. Args: env_id (str): Environment ID. as_gym (bool, optional): Add TimeLimit wrapper as gym. enable_segmentation (bool, optional): Whether to include Segmentation in observations. **kwargs: Keyword arguments to pass to the environment. """ if env_id not in REGISTERED_ENVS: raise KeyError("Env {} not found in registry".format(env_id)) env_spec = REGISTERED_ENVS[env_id] # Dispatch observation mode obs_mode = kwargs.get("obs_mode") if obs_mode is None: obs_mode = env_spec.cls.SUPPORTED_OBS_MODES[0] if obs_mode not in ["state", "state_dict", "none", "particles"]: kwargs["obs_mode"] = "image" # Add segmentation texture if "robot_seg" in obs_mode: enable_segmentation = True if enable_segmentation: camera_cfgs = kwargs.get("camera_cfgs", {}) camera_cfgs["add_segmentation"] = True kwargs["camera_cfgs"] = camera_cfgs env = env_spec.make(**kwargs) # Dispatch observation wrapper if "rgb" in obs_mode: env = RGBDObservationWrapper(env, obs_mode=obs_mode) elif "pointcloud" in obs_mode: env = PointCloudObservationWrapper(env) # Add robot segmentation wrapper if "robot_seg" in obs_mode: env = RobotSegmentationObservationWrapper(env) # Set observation mode on the wrapper if isinstance(env, gym.Wrapper): env.obs_mode = obs_mode # Compatible with gym.make if as_gym: env.unwrapped.spec = env_spec.gym_spec if env_spec.max_episode_steps is not None: env = gym.wrappers.TimeLimit( env, max_episode_steps=env_spec.max_episode_steps ) return env
[docs] def register_env(uid: str, max_episode_steps=None, override=False, **kwargs): """A decorator to register real_robot environments. :param uid: unique id of the environment. :param max_episode_steps: maximum number of steps in an episode. :param override: whether to override the environment if it is already registered. Notes: - `max_episode_steps` is processed differently from other keyword arguments in gym. `gym.make` wraps the env with `gym.wrappers.TimeLimit` to limit the maximum number of steps. - `gym.EnvSpec` uses kwargs instead of **kwargs! """ def _register_env(cls): if uid in REGISTERED_ENVS: if override: from gymnasium.envs.registration import registry get_logger("registration").warning(f"Override registered env {uid}") REGISTERED_ENVS.pop(uid) registry.pop(uid) else: get_logger("registration").warning( f"Env {uid} is already registered. Skip registration." ) return cls # Register for real_robot register( uid, cls, max_episode_steps=max_episode_steps, default_kwargs=deepcopy(kwargs), ) # Register for gym gym.register( uid, entry_point=partial(make, env_id=uid, as_gym=False), max_episode_steps=max_episode_steps, kwargs=deepcopy(kwargs), ) return cls return _register_env