-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_linearizer_mpc_pendulum.py
More file actions
116 lines (90 loc) · 2.46 KB
/
Copy pathtest_linearizer_mpc_pendulum.py
File metadata and controls
116 lines (90 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import time
from typing import Any
import env
from mpc import EnvironmentPredictiveControlWithoutOptimizer
from system import Pendulum, angle_normalize
# Import make_vec_env to allow parallelization
from stable_baselines3.common.env_util import make_vec_env
import torch
from copy import deepcopy
def cost(predicted_state, target_state, action=None, cost_dict=None):
    """Return the scalar cost of a predicted trajectory.

    The cost is the element-wise squared error between ``predicted_state``
    and ``target_state`` scaled by ``cost_dict["state_weight"]``, plus (when
    ``action`` is given) the L2 norm of ``action`` over its last dimension
    scaled by ``cost_dict["action_weight"]``. Each term is averaged over
    dimension 1 and then summed over everything that remains.

    Args:
        predicted_state: 3-D tensor; its shape is unpacked as
            ``(batch_size, prediction_horizon, _)``.
        target_state: Tensor compared against ``predicted_state`` with an
            unreduced MSE loss.
        action: Optional tensor of actions. When ``None`` the action penalty
            is skipped instead of failing inside ``torch.norm``.
        cost_dict: Optional mapping with the keys ``"state_weight"`` and
            ``"action_weight"``. When ``None``, the state weight is 0.5 for
            every element of ``predicted_state`` and the action weight is
            0.01 with shape ``(batch_size, prediction_horizon, 1)``.

    Returns:
        A 0-dimensional tensor holding the total cost.
    """
    batch_size, prediction_horizon, _ = predicted_state.shape
    device = predicted_state.device

    if cost_dict is None:
        cost_dict = dict(
            state_weight=torch.ones_like(predicted_state, device=device) * 0.5,
            action_weight=torch.ones(
                batch_size, prediction_horizon, 1, device=device
            )
            * 0.01,
        )

    # Weighted squared tracking error, averaged over dim 1, then summed.
    state_error = torch.nn.functional.mse_loss(
        predicted_state,
        target_state,
        reduction="none",
    )
    total = (state_error * cost_dict["state_weight"]).mean(1).sum()

    # ``action`` is optional in the signature, so only penalise it when given.
    if action is not None:
        action_magnitude = torch.norm(
            action,
            p=2,
            dim=-1,
            keepdim=True,
        )
        total = total + (
            action_magnitude * cost_dict["action_weight"]
        ).mean(1).sum()

    return total
def obs_to_state_target(obs) -> tuple[Any, Any]:
    """Split an observation into the current state and the target state.

    The observation itself is returned unchanged as the state. The target
    has the same shape, dtype and device as the observation, is zero
    everywhere, and has index 0 of its last dimension set to 1.0.

    Args:
        obs: Observation tensor.

    Returns:
        A ``(state, target)`` pair.
    """
    goal = torch.zeros_like(obs)
    # Only the first component of the last dimension is non-zero in the goal.
    goal[..., 0] = 1.0
    return obs, goal
# Create environment
env = make_vec_env(
    "Pendulum-v1",
    n_envs=1,
    seed=42,
    env_kwargs=dict(
        g=10.0,
    ),
)
# Separate copy of the environment: it receives the same actions as `env`
# below and is the one that gets rendered.
env_render = deepcopy(env)

# Create Model Predictive Control model
mpc = EnvironmentPredictiveControlWithoutOptimizer(
    env,
    cost,
    action_size=1,
    prediction_horizon=1,
    num_optimization_step=50,
    lr=0.1,
    std=0.2,
    device="cpu",
)

# Seed and reset both copies the same way before the loop starts.
env.seed(42)
env_render.seed(42)
observation = env.reset()
env_render.reset()
observation = torch.Tensor(observation.copy())
state, target = obs_to_state_target(observation)

# The loop has no exit condition of its own; it ends only on an interrupt or
# an exception. The `finally` clause makes sure both environments are closed
# in that case instead of being left open.
try:
    while True:
        action, cost_value = mpc(state, target)
        action_ = action.clone().detach().numpy()
        # Index 0 along axis 1 is what gets applied to the environments
        # (presumably the first step of the planned horizon - confirm
        # against the controller's output layout).
        action_selected = action_[:, 0]
        observation, reward, _, information = env.step(action_selected)
        # Apply the same action to the rendering copy.
        env_render.step(action_selected)
        observation = torch.Tensor(observation.copy())
        state, target = obs_to_state_target(observation)
        env_render.render("human")
finally:
    env.close()
    env_render.close()