Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin CO committed Aug 29, 2024
1 parent 2da02c2 commit 6df6b3c
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,5 @@
weight="bold",
)
axs.text(0.75, 96.3, "Time (s)", ha="center", va="center", fontsize=18, weight="bold")
axs.legend(title='Cost function', fontsize="medium", loc="upper right", ncol=1)
axs.legend(title="Cost function", fontsize="medium", loc="upper right", ncol=1)
plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,30 @@

# --- Build target force --- #
target_time = np.linspace(0, 1, 100)
target_force = abs(np.sin(target_time*np.pi)) * 200
target_force = abs(np.sin(target_time * np.pi)) * 200
force_tracking = [target_time, target_force]

# --- Build nmpc cyclic --- #
n_total_cycles = 8
minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
nmpc = OcpFesNmpcCyclic(model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=30,
n_shooting=5,
final_time=1,
pulse_duration={
"min": minimum_pulse_duration,
"max": 0.0006,
"bimapping": False,
},
objective={"force_tracking": force_tracking},
n_total_cycles=n_total_cycles,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="middle",
use_sx=True,
ode_solver=OdeSolver.COLLOCATION())
nmpc = OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=30,
n_shooting=5,
final_time=1,
pulse_duration={
"min": minimum_pulse_duration,
"max": 0.0006,
"bimapping": False,
},
objective={"force_tracking": force_tracking},
n_total_cycles=n_total_cycles,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="middle",
use_sx=True,
ode_solver=OdeSolver.COLLOCATION(),
)

nmpc.prepare_nmpc()
nmpc.solve()
Expand All @@ -46,22 +48,22 @@
force = [j for sub in nmpc.result["states"]["F"] for j in sub]

ax1 = plt.subplot(221)
ax1.plot(time, fatigue, label='A', color='green')
ax1.set_title('Fatigue', weight='bold')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Force scaling factor (-)')
ax1.plot(time, fatigue, label="A", color="green")
ax1.set_title("Fatigue", weight="bold")
ax1.set_xlabel("Time (s)")
ax1.set_ylabel("Force scaling factor (-)")
plt.legend()

ax2 = plt.subplot(222)
ax2.plot(time, force, label='F', color='red', linewidth=4)
ax2.plot(time, force, label="F", color="red", linewidth=4)
for i in range(n_total_cycles):
if i == 0:
ax2.plot(target_time, target_force, label='Target', color='purple')
ax2.plot(target_time, target_force, label="Target", color="purple")
else:
ax2.plot(target_time + i, target_force, color='purple')
ax2.set_title('Force', weight='bold')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Force (N)')
ax2.plot(target_time + i, target_force, color="purple")
ax2.set_title("Force", weight="bold")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Force (N)")
plt.legend()

barWidth = 0.25 # set width of bar
Expand All @@ -76,11 +78,10 @@

ax3 = plt.subplot(212)
for i in range(n_total_cycles):
ax3.bar(bar[i], cycles[i], width = barWidth,
edgecolor ='grey', label =f'cycle n°{i+1}')
ax3.set_xticks([np.mean(r) for r in bar], [str(i+1) for i in range(n_total_cycles)])
ax3.set_xlabel('Cycles')
ax3.set_ylabel('Pulse duration (s)')
ax3.bar(bar[i], cycles[i], width=barWidth, edgecolor="grey", label=f"cycle n°{i+1}")
ax3.set_xticks([np.mean(r) for r in bar], [str(i + 1) for i in range(n_total_cycles)])
ax3.set_xlabel("Cycles")
ax3.set_ylabel("Pulse duration (s)")
plt.legend()
ax3.set_title('Pulse duration', weight='bold')
ax3.set_title("Pulse duration", weight="bold")
plt.show()
2 changes: 1 addition & 1 deletion cocofest/models/ding2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
model_name: str = "ding2003",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None
stim_prev: list[float] = None,
):
super().__init__()
self._model_name = model_name
Expand Down
8 changes: 7 additions & 1 deletion cocofest/models/ding2007.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ class DingModelPulseDurationFrequency(DingModelFrequency):
Muscle & Nerve: Official Journal of the American Association of Electrodiagnostic Medicine, 36(2), 214-222.
"""

def __init__(self, model_name: str = "ding_2007", muscle_name: str = None, sum_stim_truncation: int = None, stim_prev: list[float] = None):
def __init__(
self,
model_name: str = "ding_2007",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None,
):
super(DingModelPulseDurationFrequency, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev
)
Expand Down
6 changes: 5 additions & 1 deletion cocofest/models/ding2007_with_fatigue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ class DingModelPulseDurationFrequencyWithFatigue(DingModelPulseDurationFrequency
"""

def __init__(
self, model_name: str = "ding_2007_with_fatigue", muscle_name: str = None, sum_stim_truncation: int = None, stim_prev: list[float] = None
self,
model_name: str = "ding_2007_with_fatigue",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None,
):
super(DingModelPulseDurationFrequencyWithFatigue, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev
Expand Down
101 changes: 71 additions & 30 deletions cocofest/optimization/fes_ocp_nmpc_cyclic.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
import math

import numpy as np
from bioptim import SolutionMerge, ObjectiveList, ObjectiveFcn, OdeSolver, Node, OptimalControlProgram, ControlType, TimeAlignment
from bioptim import (
SolutionMerge,
ObjectiveList,
ObjectiveFcn,
OdeSolver,
Node,
OptimalControlProgram,
ControlType,
TimeAlignment,
)

from .fes_ocp import OcpFes
from ..models.fes_model import FesModel
from ..custom_objectives import CustomObjective


class OcpFesNmpcCyclic:
def __init__(self,
model: FesModel = None,
n_stim: int = None,
n_shooting: int = None,
final_time: int | float = None,
pulse_event: dict = None,
pulse_duration: dict = None,
pulse_intensity: dict = None,
n_total_cycles: int = None,
n_simultaneous_cycles: int = None,
n_cycle_to_advance: int = None,
cycle_to_keep: str = None,
objective: dict = None,
use_sx: bool = True,
ode_solver: OdeSolver = OdeSolver.RK4(n_integration_steps=1),
n_threads: int = 1,
def __init__(
self,
model: FesModel = None,
n_stim: int = None,
n_shooting: int = None,
final_time: int | float = None,
pulse_event: dict = None,
pulse_duration: dict = None,
pulse_intensity: dict = None,
n_total_cycles: int = None,
n_simultaneous_cycles: int = None,
n_cycle_to_advance: int = None,
cycle_to_keep: str = None,
objective: dict = None,
use_sx: bool = True,
ode_solver: OdeSolver = OdeSolver.RK4(n_integration_steps=1),
n_threads: int = 1,
):
super(OcpFesNmpcCyclic, self).__init__()
self.model = model
Expand Down Expand Up @@ -104,7 +114,9 @@ def prepare_nmpc(self):
n_threads=self.n_threads,
)

OcpFes._sanity_check_frequency(n_stim=self.n_stim, final_time=self.final_time, frequency=frequency, round_down=round_down)
OcpFes._sanity_check_frequency(

Check warning on line 117 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L117

Added line #L117 was not covered by tests
n_stim=self.n_stim, final_time=self.final_time, frequency=frequency, round_down=round_down
)

force_fourier_coefficient = (

Check warning on line 121 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L121

Added line #L121 was not covered by tests
None if force_tracking is None else OcpFes._build_fourier_coefficient(force_tracking)
Expand Down Expand Up @@ -146,7 +158,14 @@ def prepare_nmpc(self):
x_bounds, x_init = OcpFes._set_bounds(self.model, self.n_stim * self.n_simultaneous_cycles)
one_cycle_shooting = [self.n_shooting] * self.n_stim
objective_functions = self._set_objective(

Check warning on line 160 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L157-L160

Added lines #L157 - L160 were not covered by tests
self.n_stim, one_cycle_shooting, force_fourier_coefficient, end_node_tracking, custom_objective, time_min, time_max, self.n_simultaneous_cycles
self.n_stim,
one_cycle_shooting,
force_fourier_coefficient,
end_node_tracking,
custom_objective,
time_min,
time_max,
self.n_simultaneous_cycles,
)
all_cycle_n_shooting = [self.n_shooting] * self.n_stim * self.n_simultaneous_cycles
self.ocp = OptimalControlProgram(

Check warning on line 171 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L170-L171

Added lines #L170 - L171 were not covered by tests
Expand Down Expand Up @@ -183,7 +202,7 @@ def update_stim(self, sol):
if "pulse_apparition_time" in sol.decision_parameters():
stimulation_time = sol.decision_parameters()["pulse_apparition_time"]

Check warning on line 203 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L202-L203

Added lines #L202 - L203 were not covered by tests
else:
stimulation_time = [0] + list(np.cumsum(sol.ocp.phase_time[:self.n_stim-1]))
stimulation_time = [0] + list(np.cumsum(sol.ocp.phase_time[: self.n_stim - 1]))

Check warning on line 205 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L205

Added line #L205 was not covered by tests

stim_prev = list(np.array(stimulation_time) - self.final_time)
if self.previous_stim:
Expand All @@ -206,14 +225,22 @@ def store_results(self, sol_time, sol_states, sol_parameters, index):

# Initialize the dict if it's the first iteration
if index == 0:
self.result["time"] = [None]*self.n_total_cycles
[self.result["states"].update({state_key: [None]*self.n_total_cycles}) for state_key in list(sol_states[0].keys())]
[self.result["parameters"].update({key_parameter: [None]*self.n_total_cycles}) for key_parameter in list(sol_parameters.keys())]
self.result["time"] = [None] * self.n_total_cycles
[

Check warning on line 229 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L227-L229

Added lines #L227 - L229 were not covered by tests
self.result["states"].update({state_key: [None] * self.n_total_cycles})
for state_key in list(sol_states[0].keys())
]
[

Check warning on line 233 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L233

Added line #L233 was not covered by tests
self.result["parameters"].update({key_parameter: [None] * self.n_total_cycles})
for key_parameter in list(sol_parameters.keys())
]

# Store the results
phase_size = np.array(sol_time).shape[0]
node_size = np.array(sol_time).shape[1]
sol_time = list(np.array(sol_time).reshape(phase_size*node_size))[self.first_node_in_phase*node_size:self.last_node_in_phase*node_size]
sol_time = list(np.array(sol_time).reshape(phase_size * node_size))[

Check warning on line 241 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L239-L241

Added lines #L239 - L241 were not covered by tests
self.first_node_in_phase * node_size : self.last_node_in_phase * node_size
]
sol_time = list(dict.fromkeys(sol_time)) # Remove duplicate time
if index == 0:
updated_sol_time = [t - sol_time[0] for t in sol_time]

Check warning on line 246 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L244-L246

Added lines #L244 - L246 were not covered by tests
Expand All @@ -223,13 +250,17 @@ def store_results(self, sol_time, sol_states, sol_parameters, index):
self.result["time"][index] = updated_sol_time[:-1]

Check warning on line 250 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L248-L250

Added lines #L248 - L250 were not covered by tests

for state_key in list(sol_states[0].keys()):
middle_states_values = sol_states[self.first_node_in_phase:self.last_node_in_phase]
middle_states_values = [list(middle_states_values[i][state_key][0])[:-1] for i in range(len(middle_states_values))] # Remove the last node duplicate
middle_states_values = sol_states[self.first_node_in_phase : self.last_node_in_phase]
middle_states_values = [

Check warning on line 254 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L252-L254

Added lines #L252 - L254 were not covered by tests
list(middle_states_values[i][state_key][0])[:-1] for i in range(len(middle_states_values))
] # Remove the last node duplicate
middle_states_values = [j for sub in middle_states_values for j in sub]
self.result["states"][state_key][index] = middle_states_values

Check warning on line 258 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L257-L258

Added lines #L257 - L258 were not covered by tests

for key_parameter in list(sol_parameters.keys()):
self.result["parameters"][key_parameter][index] = sol_parameters[key_parameter][self.first_node_in_phase:self.last_node_in_phase]
self.result["parameters"][key_parameter][index] = sol_parameters[key_parameter][

Check warning on line 261 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L260-L261

Added lines #L260 - L261 were not covered by tests
self.first_node_in_phase : self.last_node_in_phase
]
return

Check warning on line 264 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L264

Added line #L264 was not covered by tests

def solve(self):
Expand All @@ -244,13 +275,23 @@ def solve(self):
# Todo uncomment when the model is updated to take into account the past stimulation

@staticmethod
def _set_objective(n_stim, n_shooting, force_fourier_coefficient, end_node_tracking, custom_objective, time_min, time_max,
n_simultaneous_cycles):
def _set_objective(
n_stim,
n_shooting,
force_fourier_coefficient,
end_node_tracking,
custom_objective,
time_min,
time_max,
n_simultaneous_cycles,
):
# Creates the objective for our problem
objective_functions = ObjectiveList()
if custom_objective:
if len(custom_objective) != n_stim:
raise ValueError("The number of custom objective must be equal to the stimulation number of a single cycle")
raise ValueError(

Check warning on line 292 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L289-L292

Added lines #L289 - L292 were not covered by tests
"The number of custom objective must be equal to the stimulation number of a single cycle"
)
for i in range(len(custom_objective)):
for j in range(n_simultaneous_cycles):
objective_functions.add(custom_objective[i + j * n_stim][0])

Check warning on line 297 in cocofest/optimization/fes_ocp_nmpc_cyclic.py

View check run for this annotation

Codecov / codecov/patch

cocofest/optimization/fes_ocp_nmpc_cyclic.py#L295-L297

Added lines #L295 - L297 were not covered by tests
Expand Down

0 comments on commit 6df6b3c

Please sign in to comment.