Source code for real_robot.utils.wrappers.observation

from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from real_robot.utils.common import (
    flatten_dict_keys,
    flatten_dict_space_keys,
    merge_dicts,
)


class RGBDObservationWrapper(gym.ObservationWrapper):
    """Map RealSense camera capture to rgb and depth.

    In ``"rgb"`` mode the ``"depth"`` entry of every camera under the
    ``"image"`` key is dropped, both from the observation space and from each
    observation. Any other ``obs_mode`` leaves spaces and observations as-is.
    """

    def __init__(self, env, obs_mode="rgbd"):
        """
        :param env: environment to wrap. When ``obs_mode == 'rgb'`` its
                    observation space must contain an ``"image"`` entry.
        :param obs_mode: if obs_mode == 'rgb', use only Color
        """
        super().__init__(env)

        self.obs_mode = obs_mode
        # Work on a copy so the wrapped env's own space object is not modified
        self.observation_space = deepcopy(env.observation_space)

        # Remove depth from camera obs space
        if self.obs_mode == "rgb":
            for cam_space in self.observation_space["image"].spaces.values():
                cam_space.spaces.pop("depth", None)

    def observation(self, observation: dict):
        """Applied on obs before returning from self.reset() and self.step()

        :param observation: observation dict holding per-camera dicts under
                            the ``"image"`` key. It is modified in place.
        :return: the same ``observation`` object.
        :raises KeyError: if ``observation`` has no ``"image"`` key (the
                          lookup happens in every ``obs_mode``).
        """
        image_obs = observation["image"]

        if self.obs_mode == "rgb":
            # Plain loop: the pops are side effects, no result list is needed
            for cam_obs in image_obs.values():
                cam_obs.pop("depth", None)
        return observation
def merge_dict_spaces(dict_spaces: Sequence[spaces.Dict]):
    """Combine several ``spaces.Dict`` into one with concatenated ``Box`` bounds.

    The sub-space mappings of all inputs are passed to ``merge_dicts``
    (presumably grouping the spaces found under each key -- confirm against
    ``real_robot.utils.common``). For every key, the ``low`` and ``high``
    arrays of the grouped boxes are concatenated into a single new ``Box``.

    :param dict_spaces: dict spaces whose leaves are all ``spaces.Box``.
    :return: a ``spaces.Dict`` with one concatenated ``Box`` per key.
    :raises AssertionError: if any grouped space is not a ``spaces.Box``.
    """
    grouped = merge_dicts([dict_space.spaces for dict_space in dict_spaces])

    merged = OrderedDict()
    for name, boxes in grouped.items():
        # Single pass: validate each box and collect its bounds together
        bounds = []
        for box in boxes:
            assert isinstance(box, spaces.Box), type(box)
            bounds.append((box.low, box.high))

        low = np.concatenate([lower for lower, _ in bounds])
        high = np.concatenate([upper for _, upper in bounds])
        # The new box takes its dtype from the concatenated lower bound
        merged[name] = spaces.Box(low=low, high=high, dtype=low.dtype)

    return spaces.Dict(merged)
class PointCloudObservationWrapper(gym.ObservationWrapper):
    """Convert Position textures to world-space point cloud.

    NOTE: construction is currently disabled -- ``__init__`` raises
    ``NotImplementedError`` before doing anything else.
    """

    def __init__(self, env):
        """
        :param env: environment to wrap.
        :raises NotImplementedError: always; the statements after the raise
                                     are unreachable until it is removed.
        """
        raise NotImplementedError("Check impl for XArmBaseEnv")
        super().__init__(env)
        self.observation_space = deepcopy(env.observation_space)
        self.update_observation_space(self.observation_space)
        # Per-key output arrays reused by observation() across calls
        self._buffer = {}

    @staticmethod
    def update_observation_space(space: spaces.Dict):
        """Swap the image spaces of ``space`` for a merged point cloud space.

        ``"image"`` and ``"camera_param"`` are removed from ``space`` and a
        ``"pointcloud"`` entry is added. ``space`` is modified in place.

        :param space: dict space holding ``"image"`` and ``"camera_param"``.
        """
        # Replace image observation spaces with point cloud ones
        image_space: spaces.Dict = space.spaces.pop("image")
        space.spaces.pop("camera_param")

        per_camera_spaces = []
        for cam_uid in image_space:
            cam_image_space = image_space[cam_uid]

            # One point per pixel of the Position texture
            height, width = cam_image_space["Position"].shape[:2]
            n_points = height * width

            cam_pcd_space = OrderedDict()
            cam_pcd_space["xyzw"] = spaces.Box(
                low=-np.inf, high=np.inf, shape=(n_points, 4), dtype=np.float32
            )

            # Extra keys
            available = cam_image_space.spaces
            if "Color" in available:
                cam_pcd_space["rgb"] = spaces.Box(
                    low=0, high=255, shape=(n_points, 3), dtype=np.uint8
                )
            if "Segmentation" in available:
                cam_pcd_space["Segmentation"] = spaces.Box(
                    low=0, high=(2**32 - 1), shape=(n_points, 4), dtype=np.uint32
                )

            per_camera_spaces.append(spaces.Dict(cam_pcd_space))

        space.spaces["pointcloud"] = merge_dict_spaces(per_camera_spaces)

    def observation(self, observation: dict):
        """Replace per-camera images with one concatenated point cloud.

        ``"image"`` and ``"camera_param"`` are popped from ``observation`` and
        a ``"pointcloud"`` dict is stored instead. The ``Position`` arrays of
        the input are written to in place, and the returned arrays are the
        wrapper's reusable buffers.

        :param observation: dict with ``"image"`` and ``"camera_param"``.
        :return: the same ``observation`` object.
        """
        image_obs = observation.pop("image")
        camera_params = observation.pop("camera_param")

        per_camera_clouds = []
        for cam_uid, images in image_obs.items():
            cam_pcd = {}

            # Each pixel is (x, y, z, z_buffer_depth) in OpenGL camera space
            position = images["Position"]
            # Overwrite channel 3 in place with the mask ``z < 0``
            # (an earlier variant, ``z_buffer_depth < 1``, was disabled).
            position[..., 3] = position[..., 2] < 0

            # Convert to world space
            cam2world = camera_params[cam_uid]["cam2world_gl"]
            cam_pcd["xyzw"] = position.reshape(-1, 4) @ cam2world.T

            # Extra keys
            if "Color" in images:
                # assumes Color is float in [0, 1] -- TODO confirm
                color = images["Color"][..., :3]
                color = np.clip(color * 255, 0, 255).astype(np.uint8)
                cam_pcd["rgb"] = color.reshape(-1, 3)
            if "Segmentation" in images:
                cam_pcd["Segmentation"] = images["Segmentation"].reshape(-1, 4)

            per_camera_clouds.append(cam_pcd)

        pointcloud_obs = merge_dicts(per_camera_clouds)
        for key, chunks in pointcloud_obs.items():
            # First call has no buffer (None), so concatenate allocates one;
            # later calls write into the previously returned array.
            stacked = np.concatenate(chunks, out=self._buffer.get(key))
            pointcloud_obs[key] = stacked
            self._buffer[key] = stacked

        observation["pointcloud"] = pointcloud_obs
        return observation
class RobotSegmentationObservationWrapper(gym.ObservationWrapper):
    """Add a binary mask for robot links.

    For every ``"Segmentation"`` array found under ``"image"`` (per camera)
    or ``"pointcloud"``, a ``"robot_seg"`` boolean mask is added. With
    ``replace=True`` the original ``"Segmentation"`` entry is removed.

    NOTE: construction is currently disabled -- ``__init__`` raises
    ``NotImplementedError`` before doing anything else.
    """

    def __init__(self, env, replace=True):
        """
        :param env: environment to wrap; must expose ``robot_link_ids``.
        :param replace: if True, drop ``"Segmentation"`` after adding the mask.
        :raises NotImplementedError: always; the statements after the raise
                                     are unreachable until it is removed.
        """
        raise NotImplementedError("Check impl for XArmBaseEnv")
        super().__init__(env)
        self.observation_space = deepcopy(env.observation_space)
        self.init_observation_space(self.observation_space, replace=replace)
        self.replace = replace
        # Cache robot link ids
        self.robot_link_ids = self.env.robot_link_ids

    @staticmethod
    def init_observation_space(space: spaces.Dict, replace: bool):
        """Add ``"robot_seg"`` spaces next to every ``"Segmentation"`` space.

        ``space`` is modified in place.

        :param space: observation space, optionally holding ``"image"``
                      and/or ``"pointcloud"`` sub-spaces.
        :param replace: if True, remove the ``"Segmentation"`` spaces.
        """
        # Update image observation spaces
        if "image" in space.spaces:
            image_space = space["image"]
            for cam_uid in image_space:
                cam_space = image_space[cam_uid]
                if "Segmentation" not in cam_space.spaces:
                    continue
                height, width = cam_space["Segmentation"].shape[:2]
                new_space = spaces.Box(
                    low=0, high=1, shape=(height, width, 1), dtype="bool"
                )
                if replace:
                    cam_space.spaces.pop("Segmentation")
                cam_space.spaces["robot_seg"] = new_space

        # Update pointcloud observation spaces
        if "pointcloud" in space.spaces:
            pcd_space = space["pointcloud"]
            if "Segmentation" in pcd_space.spaces:
                n = pcd_space["Segmentation"].shape[0]
                new_space = spaces.Box(low=0, high=1, shape=(n, 1), dtype="bool")
                if replace:
                    pcd_space.spaces.pop("Segmentation")
                pcd_space.spaces["robot_seg"] = new_space

    def reset(self, **kwargs):
        """Reset the wrapped env, refresh the link ids and transform the obs.

        Gymnasium's ``Env.reset`` returns an ``(observation, info)`` tuple.
        Passing that tuple straight to ``self.observation`` would leave the
        observation untransformed, so the tuple is unpacked first. A wrapped
        env that returns only the observation is still supported.

        :param kwargs: forwarded to ``self.env.reset``.
        :return: ``(observation, info)`` when the wrapped env returns a tuple,
                 otherwise just the transformed observation.
        """
        result = self.env.reset(**kwargs)
        # Re-read after reset, matching the value cached in __init__
        self.robot_link_ids = self.env.robot_link_ids

        if isinstance(result, tuple):
            observation, *extras = result
            return (self.observation(observation), *extras)
        return self.observation(result)

    def _add_robot_seg(self, entries: dict) -> None:
        # Add a "robot_seg" mask to one dict that may hold "Segmentation".
        if "Segmentation" not in entries:
            return
        seg = entries["Segmentation"]
        # Slice 1:2 keeps a trailing axis of size 1 on the mask; the mask is
        # True where channel 1 holds one of the cached robot link ids.
        robot_seg = np.isin(seg[..., 1:2], self.robot_link_ids)
        if self.replace:
            entries.pop("Segmentation")
        entries["robot_seg"] = robot_seg

    def observation_image(self, observation: dict):
        """Add robot masks to every camera dict under ``"image"``.

        :param observation: dict with an ``"image"`` entry; modified in place.
        :return: the same ``observation`` object.
        """
        for cam_images in observation["image"].values():
            self._add_robot_seg(cam_images)
        return observation

    def observation_pointcloud(self, observation: dict):
        """Add a robot mask to the dict under ``"pointcloud"``.

        :param observation: dict with a ``"pointcloud"`` entry; modified in
                            place.
        :return: the same ``observation`` object.
        """
        self._add_robot_seg(observation["pointcloud"])
        return observation

    def observation(self, observation: dict):
        """Apply the image and point cloud transforms that are applicable.

        :param observation: observation dict; modified in place.
        :return: the transformed observation.
        """
        if "image" in observation:
            observation = self.observation_image(observation)
        if "pointcloud" in observation:
            observation = self.observation_pointcloud(observation)
        return observation
class FlattenObservationWrapper(gym.ObservationWrapper):
    """Run observations and their space through the key-flattening helpers.

    The space goes through ``flatten_dict_space_keys`` and each observation
    through ``flatten_dict_keys`` (presumably collapsing nested dict keys
    into a single level -- confirm against ``real_robot.utils.common``).

    NOTE: construction is currently disabled -- ``__init__`` raises
    ``NotImplementedError`` before doing anything else.
    """

    def __init__(self, env) -> None:
        """
        :param env: environment to wrap.
        :raises NotImplementedError: always; the statements after the raise
                                     are unreachable until it is removed.
        """
        raise NotImplementedError("Check impl for XArmBaseEnv")
        super().__init__(env)
        flattened_space = flatten_dict_space_keys(self.observation_space)
        self.observation_space = flattened_space

    def observation(self, observation):
        """Return ``observation`` passed through ``flatten_dict_keys``."""
        flattened = flatten_dict_keys(observation)
        return flattened