Skip to content

Commit 517d67e

Browse files
author
Tim Schneider
committed
Setting prediction bounds properly for TactilePoseEstimationEnv, accounting for object movement
1 parent e64f639 commit 517d67e

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

tactile_mnist/tactile_perception_vector_env.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class TactilePerceptionConfig:
8484
cell_size: tuple[float, float] = tuple(CELL_SIZE)
8585
cell_padding: tuple[float, float] = tuple(CELL_PADDING)
8686
smallest_dimension_up: bool = False
87+
translation_perturbation_scale: float = 1e-3
88+
rotation_perturbation_scale: float = 5e-2
8789

8890

8991
class GenericMeshDataPoint(Protocol):
@@ -660,8 +662,12 @@ def execute_step(
660662
self.__current_sensor_target_poses_platform_frame
661663
)
662664
if self.__config.perturb_object_pose:
663-
translation_perturbation = self.np_random.normal(scale=1e-3, size=2)
664-
rotation_perturbation = self.np_random.normal(scale=5e-2)
665+
translation_perturbation = self.np_random.normal(
666+
scale=self.config.translation_perturbation_scale, size=2
667+
)
668+
rotation_perturbation = self.np_random.normal(
669+
scale=self.config.rotation_perturbation_scale
670+
)
665671
perturbation = Transformation.from_pos_euler(
666672
np.concatenate([translation_perturbation, [0]]),
667673
[0, 0, rotation_perturbation],

tactile_mnist/tactile_pose_estimation_env.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import functools
4+
import math
45
from collections import deque, defaultdict
56
from functools import partial
67
from typing import (
@@ -12,6 +13,7 @@
1213
import gymnasium as gym
1314
import numpy as np
1415
from scipy.spatial.transform import Rotation
16+
from scipy.stats import norm
1517
from transformation import Transformation
1618

1719
from ap_gym import (
@@ -62,11 +64,30 @@ def __init__(
6264
"At least one of frame_position_mode or frame_rotation_mode must not be 'dont_use'"
6365
)
6466

67+
max_expected_translation_perturbation_norm = 0
68+
if config.perturb_object_pose:
69+
cumulative_std = config.translation_perturbation_scale * math.sqrt(
70+
config.step_limit
71+
)
72+
# For all objects starting directly at the edge of the platform, 99.99% of the time they will stay within
73+
# this bound.
74+
max_expected_translation_perturbation = cumulative_std * norm.ppf(0.9999)
75+
max_expected_translation_perturbation_norm = (
76+
max_expected_translation_perturbation / (np.min(config.cell_size) / 2)
77+
)
78+
79+
prediction_bound = 1.0 + max_expected_translation_perturbation_norm
80+
6581
super().__init__(
6682
config,
6783
num_envs,
68-
single_prediction_space=gym.spaces.Box(-1, 1, shape=(target_dims,)),
69-
single_prediction_target_space=gym.spaces.Box(-1, 1, shape=(target_dims,)),
84+
# We allow for more than [-1, 1] range to account for objects moving beyond the platform during the episode
85+
single_prediction_space=gym.spaces.Box(
86+
-prediction_bound, prediction_bound, shape=(target_dims,)
87+
),
88+
single_prediction_target_space=gym.spaces.Box(
89+
-prediction_bound, prediction_bound, shape=(target_dims,)
90+
),
7091
loss_fn=MSELossFn(),
7192
render_mode=render_mode,
7293
)

0 commit comments

Comments
 (0)