Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nmpc cyclic exemple #66

Merged
merged 9 commits into from
Aug 29, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed time continuity and past stim
Kevin CO committed Aug 27, 2024
commit 39239f192c2aad55299bc23e2c734b1dfd907617
18 changes: 9 additions & 9 deletions cocofest/examples/getting_started/mhe_try.py
Original file line number Diff line number Diff line change
@@ -4,9 +4,9 @@
import matplotlib.pyplot as plt


time = np.linspace(0, 1, 100)
force = abs(np.sin(time * 5) + np.random.normal(scale=0.1, size=len(time))) * 100
force_tracking = [time, force]
time1 = np.linspace(0, 6, 600)
force1 = abs(np.sin(time1 * 5) + np.random.normal(scale=0.1, size=len(time1))) * 100
force_tracking = [time1, force1]

minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
mhe = OcpFesMhe(model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
@@ -29,10 +29,10 @@
mhe.prepare_mhe()
mhe.solve()
# print(mhe)
plt.plot(np.linspace(0, 6, 260), mhe.result_states["F"])
time = [j for sub in mhe.result["time"] for j in sub]
force = [j for sub in mhe.result["states"]["F"] for j in sub]
# fatigue = [j for sub in mhe.result["states"]["A"] for j in sub]
# plt.plot(time, fatigue)
plt.plot(time, force)
plt.plot(time1, force1)
plt.show()


# sol = ocp.solve()
# sol.graphs()

146 changes: 42 additions & 104 deletions cocofest/optimization/fes_ocp_mhe.py
Original file line number Diff line number Diff line change
@@ -174,99 +174,46 @@ def prepare_mhe(self):

return self.ocp

# def update_mhe(self, previous_sol):
# sol_states = sol.decision_states(to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES])
# cn_results.append(list(sol_states["Cn"][0]))
# f_results.append(list(sol_states["F"][0]))
# a_results.append(list(sol_states["A"][0]))
# tau1_results.append(list(sol_states["Tau1"][0]))
# km_results.append(list(sol_states["Km"][0]))
# sol_time = sol.decision_time(to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES])
# sol_time = list(sol_time.reshape(sol_time.shape[0]))
#
# sol_time_stim_parameters = sol.decision_parameters()["pulse_apparition_time"]
#
# if previous_stim:
# # stim_prev = list(sol_time_stim_parameters - sol_time[-1])
# update_previous_stim = list(np.array(previous_stim) - sol_time[-1])
# previous_stim = update_previous_stim + stim_prev
# else:
# stim_prev = list(sol_time_stim_parameters - sol_time[-1])
# previous_stim = stim_prev
#
# if i != 0:
# sol_time = [x + time[-1][-1] for x in sol_time]
#
# time.append(sol_time)
# keys = list(sol_states.keys())
#
# for key in keys:
# ocp.nlp[0].x_bounds[key].max[0][0] = sol_states[key][-1][-1]
# ocp.nlp[0].x_bounds[key].min[0][0] = sol_states[key][-1][-1]
# for j in range(len(ocp.nlp)):
# ocp.nlp[j].model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10,
# stim_prev=previous_stim)
# for key in keys:
# ocp.nlp[j].x_init[key].init[0][0] = sol_states[key][-1][-1]

def update_time(self, sol, index):
if index == 0:
sol_time = sol.decision_time(to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES])
sol_time = list(sol_time.reshape(sol_time.shape[0]))
sol_time = [t - sol_time[0] for t in sol_time]



# self.time.append(sol_time)
else:
sol_time = sol.decision_time(to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES])
sol_time = list(sol_time.reshape(sol_time.shape[0]))


sol_time = sol.decision_time(to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES])
sol_time = list(sol_time.reshape(sol_time.shape[0]))
final_time = sol_time[-1] + self.time[-1][-1]

if index != 0:
sol_time = [x + self.time[-1][-1] for x in sol_time]
self.time.append(sol_time)

def update_states_bounds(self, sol_states, state_keys):
def update_states_bounds(self, sol_states):
state_keys = list(self.ocp.nlp[0].states.keys())
index_to_keep = 1 * self.n_stim - 1 # todo: update this when more simultaneous cycles than 3
for key in state_keys:
self.ocp.nlp[0].x_bounds[key].max[0][0] = sol_states[self.n_stim//self.n_simultaneous_cycles - 1][key][0][-1]
self.ocp.nlp[0].x_bounds[key].min[0][0] = sol_states[self.n_stim//self.n_simultaneous_cycles - 1][key][0][-1]
for j in range(self.n_stim//self.n_simultaneous_cycles - 1, len(self.ocp.nlp)):
self.ocp.nlp[0].x_bounds[key].max[0][0] = sol_states[index_to_keep][key][0][-1]
self.ocp.nlp[0].x_bounds[key].min[0][0] = sol_states[index_to_keep][key][0][-1]
for j in range(index_to_keep, len(self.ocp.nlp)):
self.ocp.nlp[j].x_init[key].init[0][0] = sol_states[j][key][0][0]

def update_parameters(self, sol, parameters_keys):
sol_parameters = sol.decision_parameters()
sol_parameters_dict = {key: list(sol_parameters[key][0]) for key in sol_parameters.keys()}
# def update_parameters(self, sol, parameters_keys):
# sol_parameters = sol.decision_parameters()
# sol_parameters_dict = {key: list(sol_parameters[key][0]) for key in sol_parameters.keys()}
# return

def update_objective(self, sol):
return

def update_stim(self, sol):
previous_time = self.time[0]
if "pulse_apparition_time" in sol.decision_parameters():
sol_time_stim_parameters = sol.decision_parameters()["pulse_apparition_time"]

if self.previous_stim:
stim_prev = list(sol_time_stim_parameters - self.time[-1])
update_previous_stim = list(np.array(self.previous_stim) - self.time[-1])
self.previous_stim = update_previous_stim + stim_prev
# TODO: check if correct did some modification since last time
# TODO: wrong, will take the last previous_stim from n_simultaneous_cycles at each iteration
# TODO: wrong, time is pushed to many times
else:
stim_prev = list(sol_time_stim_parameters - self.time[-1])
self.previous_stim = stim_prev
stimulation_time = sol.decision_parameters()["pulse_apparition_time"]
else:
stimulation_time = [0] + list(np.cumsum(sol.ocp.phase_time[:self.n_stim-1]))

list(self.previous_stim)
stim_prev = list(np.array(stimulation_time) - self.final_time)
if self.previous_stim:
update_previous_stim = list(np.array(self.previous_stim) - self.final_time)
self.previous_stim = update_previous_stim + stim_prev

else:
self.previous_stim = stim_prev

for j in range(len(sol.ocp.nlp)):
sol.ocp.nlp[j].model.set_pass_pulse_apparition_time(self.previous_stim) #TODO: Does not seem to impact the model estimation

def store_results(self, sol_time, sol_states, sol_parameters, index, merge=False):
if self.cycle_to_keep == "middle":
# Get the middle phase index to keep
phase_to_keep = int(math.ceil(self.n_simultaneous_cycles / 2))
first_node_in_phase = self.n_stim * (phase_to_keep - 1)
last_node_in_phase = self.n_stim * phase_to_keep
self.first_node_in_phase = self.n_stim * (phase_to_keep - 1)
self.last_node_in_phase = self.n_stim * phase_to_keep

# Initialize the dict if it's the first iteration
if index == 0:
@@ -277,45 +224,34 @@ def store_results(self, sol_time, sol_states, sol_parameters, index, merge=False
# Store the results
phase_size = np.array(sol_time).shape[0]
node_size = np.array(sol_time).shape[1]
sol_time = list(np.array(sol_time).reshape(phase_size*node_size))[first_node_in_phase*node_size:last_node_in_phase*node_size]
sol_time = list(np.array(sol_time).reshape(phase_size*node_size))[self.first_node_in_phase*node_size:self.last_node_in_phase*node_size]
sol_time = list(dict.fromkeys(sol_time)) # Remove duplicate time
if index == 0:
sol_time = [t - sol_time[0] for t in sol_time]
updated_sol_time = [t - sol_time[0] for t in sol_time]
else:
sol_time = [t - sol_time[0] + self.result["time"][index-1][-1] for t in sol_time]
# time_diff = sol_time[first_node_in_phase]
# sol_time[first_node_in_phase:last_node_in_phase] = sol_time[first_node_in_phase:last_node_in_phase] - self.time[-1][-1]
self.result["time"][index] = sol_time # Todo remove last node
updated_sol_time = [t - sol_time[0] + self.temp_last_node_time for t in sol_time]
self.temp_last_node_time = updated_sol_time[-1]
self.result["time"][index] = updated_sol_time[:-1]

for state_key in list(sol_states[0].keys()):
middle_states_values = sol_states[first_node_in_phase:last_node_in_phase]
middle_states_values = [list(middle_states_values[i][state_key][0]) for i in range(len(middle_states_values))]
middle_states_values = sol_states[self.first_node_in_phase:self.last_node_in_phase]
middle_states_values = [list(middle_states_values[i][state_key][0])[:-1] for i in range(len(middle_states_values))] # Remove the last node duplicate
middle_states_values = [j for sub in middle_states_values for j in sub]
self.result["states"][state_key][index] = middle_states_values# Todo might be wrong

for key_parameter in list(sol_parameters.keys()):
self.result["parameters"][key_parameter][index] = sol_parameters[key_parameter][first_node_in_phase:last_node_in_phase]
self.result["parameters"][key_parameter][index] = sol_parameters[key_parameter][self.first_node_in_phase:self.last_node_in_phase]
return

def solve(self):
state_keys = list(self.ocp.nlp[0].states.keys())
parameters_keys = list(self.ocp.nlp[0].parameters.keys())
for i in range(self.n_total_cycles // self.n_cycle_to_advance):
sol = self.ocp.solve()

sol_states = sol.decision_states(to_merge=[SolutionMerge.NODES])
self.update_states_bounds(sol_states, state_keys)

self.update_states_bounds(sol_states)
sol_time = sol.decision_time(to_merge=SolutionMerge.NODES, time_alignment=TimeAlignment.STATES)
# sol_time = self.update_time(sol, index=i)

sol_parameters = sol.decision_parameters()
# self.update_parameters(sol, parameters_keys)
# self.update_stim(sol)

# sol.graphs()

self.store_results(sol_time, sol_states, sol_parameters, i)
# return self.ocp.solve()
self.update_stim(sol)

@staticmethod
def _set_objective(n_stim, n_shooting, force_fourier_coefficient, end_node_tracking, custom_objective, time_min, time_max,
@@ -389,6 +325,8 @@ def _mhe_sanity_check(self):

if self.cycle_to_keep not in ["first", "middle", "last"]:
raise ValueError("cycle_to_keep must be either 'first', 'middle' or 'last'")
if self.cycle_to_keep != "middle":
raise NotImplementedError("Only 'middle' cycle_to_keep is implemented")

# if self.cycle_to_keep == "middle" and self.n_simultaneous_cycles % 2 == 0:
# raise ValueError("The number of n_total_cycles must be an odd number if cycle_to_keep is 'middle'")
if self.n_simultaneous_cycles != 3:
raise NotImplementedError("Only 3 simultaneous cycles are implemented yet work in progress") # todo add more simultaneous cycles