Skip to content

Commit

Permalink
Implementing the first nmpc cyclic example
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin CO committed Aug 29, 2024
1 parent c9bfbec commit 2da02c2
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 117 deletions.
3 changes: 1 addition & 2 deletions cocofest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from .optimization.fes_ocp import OcpFes
from .optimization.fes_identification_ocp import OcpFesId
from .optimization.fes_ocp_dynamics import OcpFesMsk
# from .optimization.fes_ocp_nmpc import OcpFesNmpc
from .optimization.fes_ocp_mhe import OcpFesMhe
from .optimization.fes_ocp_nmpc_cyclic import OcpFesNmpcCyclic
from .integration.ivp_fes import IvpFes
from .fourier_approx import FourierSeries
from .identification.ding2003_force_parameter_identification import DingModelFrequencyForceParameterIdentification
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cocofest import PlotCyclingResult

# Plot the results
PlotCyclingResult("cycling_fes_driven_min_residual_torque_and_fatigue_results.pkl").plot(starting_location="E")
PlotCyclingResult("cycling_fes_driven_min_residual_torque.pkl").plot(starting_location="E")
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
)

fig, axs = plt.subplots(1, 1, figsize=(3, (1 / 3) * 7))
fig.suptitle("Muscle fatigue", fontsize=20, fontweight="bold", fontname="Times New Roman")
fig.suptitle("Muscle fatigue", fontsize=20, fontweight="bold")

axs.set_xlim(left=0, right=1.5)
plt.setp(
Expand All @@ -192,8 +192,8 @@
a_force_sum_percentage = (np.array(a_force_sum_list) / a_sum_base_line) * 100
a_fatigue_sum_percentage = (np.array(a_fatigue_sum_list) / a_sum_base_line) * 100

axs.plot(data_minimize_force["time"], a_force_sum_percentage, lw=5)
axs.plot(data_minimize_force["time"], a_fatigue_sum_percentage, lw=5)
axs.plot(data_minimize_force["time"], a_force_sum_percentage, lw=5, label="Minimize force production")
axs.plot(data_minimize_force["time"], a_fatigue_sum_percentage, lw=5, label="Maximize muscle capacity")

axs.set_xlim(left=0, right=1.5)

Expand All @@ -204,21 +204,16 @@
)

labels = axs.get_xticklabels() + axs.get_yticklabels()
[label.set_fontname("Times New Roman") for label in labels]
[label.set_fontsize(14) for label in labels]
fig.text(
0.05,
0.5,
"Fatigue percentage (%)",
"Muscle capacity (%)",
ha="center",
va="center",
rotation="vertical",
fontsize=18,
weight="bold",
font="Times New Roman",
)
axs.text(0.75, 96.3, "Time (s)", ha="center", va="center", fontsize=18, weight="bold", font="Times New Roman")
plt.legend(
["Force", "Fatigue"], loc="upper right", ncol=1, prop={"family": "Times New Roman", "size": 14, "weight": "bold"}
)
axs.text(0.75, 96.3, "Time (s)", ha="center", va="center", fontsize=18, weight="bold")
axs.legend(title='Cost function', fontsize="medium", loc="upper right", ncol=1)
plt.show()
38 changes: 0 additions & 38 deletions cocofest/examples/getting_started/mhe_try.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@
This example showcases a moving time horizon simulation problem of cyclic muscle force tracking.
The FES model used here is Ding's 2007 pulse duration and frequency model with fatigue.
Only the pulse duration is optimized, frequency is fixed.
The mhe problem is composed of 3 cycles and will move forward 1 cycle at each step.
Only the middle cycle is kept in the optimization problem, the mhe problem stops once the last 6th cycle is reached.
The nmpc cyclic problem is composed of 3 cycles and will move forward 1 cycle at each step.
Only the middle cycle is kept in the optimization problem, the nmpc cyclic problem stops once the last 6th cycle is reached.
"""

import numpy as np
import matplotlib.pyplot as plt

from bioptim import OdeSolver
from cocofest import OcpFesMhe, DingModelPulseDurationFrequencyWithFatigue
from cocofest import OcpFesNmpcCyclic, DingModelPulseDurationFrequencyWithFatigue

# --- Build target force --- #
target_time = np.linspace(0, 1, 100)
target_force = abs(np.sin(target_time*np.pi)) * 200
force_tracking = [target_time, target_force]

# --- Build mhe --- #
# --- Build nmpc cyclic --- #
n_total_cycles = 8
minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
mhe = OcpFesMhe(model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
nmpc = OcpFesNmpcCyclic(model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=30,
n_shooting=5,
final_time=1,
Expand All @@ -29,51 +30,55 @@
"bimapping": False,
},
objective={"force_tracking": force_tracking},
n_total_cycles=8,
n_total_cycles=n_total_cycles,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="middle",
use_sx=True,
ode_solver=OdeSolver.COLLOCATION())

mhe.prepare_mhe()
mhe.solve()
nmpc.prepare_nmpc()
nmpc.solve()

# --- Show results --- #
time = [j for sub in mhe.result["time"] for j in sub]
fatigue = [j for sub in mhe.result["states"]["A"] for j in sub]
force = [j for sub in mhe.result["states"]["F"] for j in sub]
time = [j for sub in nmpc.result["time"] for j in sub]
fatigue = [j for sub in nmpc.result["states"]["A"] for j in sub]
force = [j for sub in nmpc.result["states"]["F"] for j in sub]

ax1 = plt.subplot(221)
ax1.plot(time, fatigue, label='A', color='green')
ax1.set_title('Fatigue', weight='bold')
ax1.set_xlabel('Time')
ax1.set_ylabel('Force scaling factor')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Force scaling factor (-)')
plt.legend()

ax2 = plt.subplot(222)
ax2.plot(time, force, label='F', color='red', linewidth=2)
ax2.plot(target_time, target_force, label='Target', color='purple')
ax2.plot(time, force, label='F', color='red', linewidth=4)
for i in range(n_total_cycles):
if i == 0:
ax2.plot(target_time, target_force, label='Target', color='purple')
else:
ax2.plot(target_time + i, target_force, color='purple')
ax2.set_title('Force', weight='bold')
ax2.set_xlabel('Time')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Force (N)')
plt.legend()

barWidth = 0.25 # set width of bar
cycles = mhe.result["parameters"]["pulse_duration"] # set height of bar
cycles = nmpc.result["parameters"]["pulse_duration"] # set height of bar
bar = [] # Set position of bar on X axis
for i in range(6):
for i in range(n_total_cycles):
if i == 0:
br = [barWidth * (x + 1) for x in range(len(cycles[i]))]
else:
br = [bar[-1][-1] + barWidth * (x + 1) for x in range(len(cycles[i]))]
bar.append(br)

ax3 = plt.subplot(212)
for i in range(6):
for i in range(n_total_cycles):
ax3.bar(bar[i], cycles[i], width = barWidth,
edgecolor ='grey', label =f'cycle n°{i+1}')
ax3.set_xticks([np.mean(r) for r in bar], ['1', '2', '3', '4', '5', '6'])
ax3.set_xticks([np.mean(r) for r in bar], [str(i+1) for i in range(n_total_cycles)])
ax3.set_xlabel('Cycles')
ax3.set_ylabel('Pulse duration (s)')
plt.legend()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,18 @@
ticks=[1e-12, 1e-10, 1e-8, 1e-6, 1e-4, 1e-2, 1, max_error],
cmap=cmap,
)
cbar1.set_label(label="Force absolute error (N)", size=25, fontname="Times New Roman")
cbar1.set_label(label="Muscle force absolute error (N)", size=25, fontname="Times New Roman")

cbar1.ax.set_yticklabels(
[
"{:.0e}".format(float(1e-12)),
"{:.0e}".format(float(1e-10)),
"{:.0e}".format(float(1e-8)),
"{:.0e}".format(float(1e-6)),
"{:.0e}".format(float(1e-4)),
"{:.0e}".format(float(1e-2)),
"{:.0e}".format(float(1)),
"{:.1e}".format(float(round(max_error))),
"1e⁻¹²",
"1e⁻¹⁰",
"1e⁻⁸",
"1e⁻⁶",
"1e⁻⁴",
"1e⁻²",
"1e⁰",
"5.3e⁺¹",
],
size=25,
fontname="Times New Roman",
Expand All @@ -184,14 +184,14 @@
y_beneath_1e_8 = []
for j in range(len((all_mode_list_error_beneath_1e_8[i]))):
y_beneath_1e_8.append(parameter_list[i][all_mode_list_error_beneath_1e_8[i][j]][1])
axs.plot(x_beneath_1e_8, y_beneath_1e_8, color="darkred", label="Calcium absolute error < 1e-8", linewidth=3)
axs.plot(x_beneath_1e_8, y_beneath_1e_8, color="darkred", label=r"Calcium absolute error < 1e⁻⁸", linewidth=3)

axs.scatter(0, 0, color="white", label="OCP (s) | 100 Integrations (s)", marker="+", s=0, lw=0)
axs.scatter(0, 0, color="white", label="Initialization (s) | 100 Integrations (s)", marker="+", s=0, lw=0)
axs.scatter(
1,
1,
color="blue",
label=" " + str(round(a_ocp_time, 3)) + " " + str(round(a_integration_time, 3)),
label=" " + str(round(a_ocp_time, 3)) + " " + str(round(a_integration_time, 3)),
marker="^",
s=200,
lw=5,
Expand All @@ -200,7 +200,7 @@
100,
39,
color="black",
label=" " + str(round(b_ocp_time, 3)) + " " + str(round(b_integration_time, 3)),
label=" " + str(round(b_ocp_time, 3)) + " " + str(round(b_integration_time, 3)),
marker="+",
s=500,
lw=5,
Expand All @@ -209,15 +209,15 @@
100,
100,
color="green",
label=" " + str(round(c_ocp_time, 3)) + " " + str(round(c_integration_time, 3)),
label=" " + str(round(c_ocp_time, 3)) + " " + str(round(c_integration_time, 3)),
marker=",",
s=200,
lw=5,
)

axs.set_xlabel("Frequency (Hz)", fontsize=25, fontname="Times New Roman")
axs.xaxis.set_major_locator(MaxNLocator(integer=True))
axs.set_ylabel("Past stimulation kept for computation (n)", fontsize=25, fontname="Times New Roman")
axs.set_ylabel("Past stimulations kept for computation (n)", fontsize=25, fontname="Times New Roman")
axs.yaxis.set_major_locator(MaxNLocator(integer=True))

ticks = np.arange(1, 101, 1).tolist()
Expand Down
5 changes: 1 addition & 4 deletions cocofest/optimization/fes_ocp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np

from bioptim import (
BiMapping,
# BiMappingList, parameter mapping not yet implemented
BoundsList,
ConstraintList,
ControlType,
Expand All @@ -29,7 +27,6 @@
from ..models.ding2007 import DingModelPulseDurationFrequency
from ..models.ding2007_with_fatigue import DingModelPulseDurationFrequencyWithFatigue
from ..models.ding2003 import DingModelFrequency
from ..models.ding2003_with_fatigue import DingModelFrequencyWithFatigue
from ..models.hmed2018 import DingModelIntensityFrequency
from ..models.hmed2018_with_fatigue import DingModelIntensityFrequencyWithFatigue

Expand Down Expand Up @@ -158,7 +155,7 @@ def prepare_ocp(
force_fourier_coefficient = (
None if force_tracking is None else OcpFes._build_fourier_coefficient(force_tracking)
)
end_node_tracking = end_node_tracking

models = [model] * n_stim
n_shooting = [n_shooting] * n_stim

Expand Down
Empty file.
Loading

0 comments on commit 2da02c2

Please sign in to comment.