-
Notifications
You must be signed in to change notification settings - Fork 0
/
rocket-euler-sim.py
113 lines (94 loc) · 1.89 KB
/
rocket-euler-sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
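# Rocket Simulation with the Explicit Euler Method
# ------------------------------------------------
#
# Discretizes the dynamics below with explicit Euler and uses IPOPT to find
# the piecewise-constant thrust profile minimizing sum(u^2) that brings the
# rocket from rest at s = 0 to s = 10 m with final velocity v = 0.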
#
# @from: https://github.com/casadi/casadi/blob/master/docs/examples/python/rocket.py
#
# System Dynamics
#
# sdot = v,
# vdot = (u - 0.05 * v^2)/m
# mdot = -0.1*u^2
#
# States and control
#
# m - mass (kg)
# s - position (m)
# v - velocity (m/s)
# u - thrust (control input); it also drives fuel burn via mdot
#
from casadi import (
MX,
Function,
nlpsol,
vertcat,
mtimes,
linspace,
)
from pylab import (
plot,
grid,
show,
legend,
xlabel,
)
# Constants
T = 0.2  # duration of ONE control interval (s); the full horizon is nu*T
N = 20  # number of explicit Euler steps per control interval
dt = T/N  # Euler step size (s)
# Control
u = MX.sym("u")
# State x = [s, v, m]
x = MX.sym("x", 3)
s = x[0]  # position
v = x[1]  # speed
m = x[2]  # mass
# ODE right hand side
sdot = v
vdot = (u - 0.05 * v**2)/m
mdot = -0.1*u**2
xdot = vertcat(sdot, vdot, mdot)
# ODE right hand side function: (x, u) -> xdot
f = Function('f', [x, u], [xdot])
# Integrate with explicit Euler over one control interval of length T,
# holding u constant. xj starts as an alias of the symbolic input x, so it
# is rebound explicitly rather than updated with `+=`; this guarantees x
# itself is never modified and stays a pure symbolic input of F below.
xj = x
for j in range(N):
    fj = f(xj, u)
    xj = xj + dt*fj
# Discrete time dynamics function: state at the start of a control interval
# and the (constant) control on it -> state at the end of that interval
F = Function('F', [x, u], [xj])
# Number of control segments; each one lasts T seconds and is integrated by F
nu = 50
# Decision variables: one constant control value per segment
U = MX.sym("U", nu)
# Initial state [s, v, m]: at rest at the origin with unit mass
X0 = MX([0, 0, 1])
# Chain the discrete dynamics over every segment to obtain the final state
X = X0
for k in range(nu):
    X = F(X, U[k])
# Objective: sum of squared controls (u'*u in Matlab)
J = mtimes(U.T, U)
# Constraint expressions: final position and velocity (x(1:2) in Matlab)
G = X[0:2]
# NLP in CasADi's dict form: decision variables, objective, constraints
nlp = {'x': U, 'f': J, 'g': G}
# Allocate an IPOPT-based NLP solver with a tight tolerance
opts = {"ipopt.tol": 1e-10, "expand": True}
solver = nlpsol("solver", "ipopt", nlp, opts)
# Solver arguments:
#   lbx/ubx - box bounds applied to every control value
#   x0      - initial guess for every control value
#   lbg/ubg - equal lower/upper bounds turn g into equality constraints,
#             pinning the final position to 10 and the final velocity to 0
arg = {
    "lbx": -0.5,
    "ubx": 0.5,
    "x0": 0.4,
    "lbg": [10, 0],
    "ubg": [10, 0],
}
# Solve the problem
res = solver(**arg)
# Get the solution and plot it against time.
# Each of the nu control segments lasts T seconds (T is the length of one
# control interval, not the whole horizon), so segment k starts at t = k*T
# and the full horizon spans nu*T seconds.
tgrid = linspace(0, (nu - 1)*T, nu).full()
# Controls (and their bound multipliers) are constant on each segment, so
# draw them as steps that hold each value until the next segment starts.
plot(tgrid, res["x"].full(), drawstyle='steps-post')
plot(tgrid, res["lam_x"].full(), drawstyle='steps-post')
legend(['u', 'lambda_u'])
xlabel('t (s)')
grid()
show()