|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import functools |
| 4 | +import math |
4 | 5 | from collections import deque, defaultdict |
5 | 6 | from functools import partial |
6 | 7 | from typing import ( |
|
12 | 13 | import gymnasium as gym |
13 | 14 | import numpy as np |
14 | 15 | from scipy.spatial.transform import Rotation |
| 16 | +from scipy.stats import norm |
15 | 17 | from transformation import Transformation |
16 | 18 |
|
17 | 19 | from ap_gym import ( |
@@ -62,11 +64,30 @@ def __init__( |
62 | 64 | "At least one of frame_position_mode or frame_rotation_mode must not be 'dont_use'" |
63 | 65 | ) |
64 | 66 |
|
| 67 | + max_expected_translation_perturbation_norm = 0 |
| 68 | + if config.perturb_object_pose: |
| 69 | + cumulative_std = config.translation_perturbation_scale * math.sqrt( |
| 70 | + config.step_limit |
| 71 | + ) |
| 72 | + # For all objects starting directly at the edge of the platform, 99.99% of the time they will stay within |
| 73 | + # this bound. |
| 74 | + max_expected_translation_perturbation = cumulative_std * norm.ppf(0.9999) |
| 75 | + max_expected_translation_perturbation_norm = ( |
| 76 | + max_expected_translation_perturbation / (np.min(config.cell_size) / 2) |
| 77 | + ) |
| 78 | + |
| 79 | + prediction_bound = 1.0 + max_expected_translation_perturbation_norm |
| 80 | + |
65 | 81 | super().__init__( |
66 | 82 | config, |
67 | 83 | num_envs, |
68 | | - single_prediction_space=gym.spaces.Box(-1, 1, shape=(target_dims,)), |
69 | | - single_prediction_target_space=gym.spaces.Box(-1, 1, shape=(target_dims,)), |
| 84 | + # We allow for more than [-1, 1] range to account for objects moving beyond the platform during the episode |
| 85 | + single_prediction_space=gym.spaces.Box( |
| 86 | + -prediction_bound, prediction_bound, shape=(target_dims,) |
| 87 | + ), |
| 88 | + single_prediction_target_space=gym.spaces.Box( |
| 89 | + -prediction_bound, prediction_bound, shape=(target_dims,) |
| 90 | + ), |
70 | 91 | loss_fn=MSELossFn(), |
71 | 92 | render_mode=render_mode, |
72 | 93 | ) |
|
0 commit comments