Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot constraint violations #218

Merged
merged 13 commits into from
Sep 4, 2024
90 changes: 81 additions & 9 deletions opty/direct_collocation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python

#
from functools import wraps
import logging

Expand Down Expand Up @@ -422,6 +422,8 @@ def plot_trajectories(self, vector, axes=None):

return axes



@_optional_plt_dep
def plot_constraint_violations(self, vector, axes=None):
"""Returns an axis with the state constraint violations plotted versus
Expand All @@ -448,6 +450,35 @@ def plot_constraint_violations(self, vector, axes=None):
- s : number of unknown time intervals

"""
def extract_mantissa_exponent(number):
"""Returns the mantissa and exponent of a number."""
mantissa, exponent = np.frexp(number)
exponent *= np.log10(2)
mantissa *= 10**(exponent % 1)
exponent = int(exponent)
return mantissa, exponent

bars_per_plot = 10
rotation = -90

# ensure that len(axes) is correct, raise ValuError otherwise
if axes is not None:
warner = False
for i in range(len(axes.ravel())):
if axes.ravel()[i]._sharex is not None:
warner = True
if warner == True:
print('Set sharex=False or remove, it makes no sense here')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We typically don't have print statements in library code because their display can't be controlled. Best to avoid warnings and just raise errors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't think we need to warn about things like this. The best approach here is to explain in the docstring what the user must do if they want to pass in their own axes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't think we need to warn about things like this. The best approach here is to explain in the docstring what the user must do if they want to pass in their own axes.

If sharex = True, the program will not interrupt, but the plots are meaningless. Do you mean, I just let this happen, and explain in the method, that it should be set to False or removed? Should this be mentioned in the opty Documentation?
If YES, I have to find out how to do this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean, I just let this happen, and explain in the method, that it should be set to False or removed?

Yes, but I don't think you need to explain anything about sharex. plt.subplots() has a very long list of arguments and it is not our job to explain them all here. You simply have to state to the user the minimal information they need to know to pass in a correct axes object.

len_axes = len(axes.ravel())
len_constr = len(self.collocator.instance_constraints)
if (len_constr <= bars_per_plot) and (len_axes < 2):
raise ValueError('len(axes) must be equal to 2')
elif ((len_constr > bars_per_plot) and
(len_axes < len_constr // bars_per_plot + 2)):
raise ValueError(f'len(axes) must be equal to {len_constr//bars_per_plot+2}')
else:
pass

N = self.collocator.num_collocation_nodes
con_violations = self.con(vector)
state_violations = con_violations[
Expand All @@ -458,23 +489,64 @@ def plot_constraint_violations(self, vector, axes=None):
con_nodes = range(1, self.collocator.num_collocation_nodes)

plot_inst_viols = self.collocator.instance_constraints is not None
num_inst_viols = len(instance_violations)
if num_inst_viols == bars_per_plot:
num_plots = 1
else:
num_plots = num_inst_viols // bars_per_plot + 1

if axes is None:
fig, axes = plt.subplots(1 + plot_inst_viols, squeeze=False,
layout='compressed')
fig, axes = plt.subplots(1 + num_plots, squeeze=False,
layout='compressed', figsize=(8, 2.0*(num_plots+1)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to leave the fig size for the user to create manually. Setting figsize inside a function like this can result in giant figures.


axes = axes.ravel()

axes[0].plot(con_nodes, state_violations.T)
axes[0].set_title('Constraint violations')
axes[0].set_xlabel('Node Number')
axes[0].set_ylabel('EoM violation')

# reduce the instance constrtaints to 2 significant digits and print
# in exponential form
instance_constr_plot = []
for exp1 in self.collocator.instance_constraints:
for a in sm.preorder_traversal(exp1):
if isinstance(a, sm.Float):
value = float(a)
mantissa, exponent = extract_mantissa_exponent(value)
mantissa = round(mantissa, 2)

if exponent != 0:
sympy_value = sm.Symbol(f'{mantissa} \cdot 10^{exponent}')
else:
sympy_value = sm.Symbol(f'{mantissa}')
exp1 = exp1.subs(a, sympy_value)
instance_constr_plot.append(exp1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python has printing controls that work like this:

In [7]: '{:1.3f}'.format(1.2394359593939)
Out[7]: '1.239'

is there a reason to not use something like that instead of literally rounding the numerical values?

Copy link
Contributor Author

@Peter230655 Peter230655 Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python has printing controls that work like this:

In [7]: '{:1.3f}'.format(1.2394359593939)
Out[7]: '1.239'

is there a reason to not use something like that instead of literally rounding the numerical values?

Timo suggested this to me, so I used it, since it worked. I had no idea how to do such things.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see now there may not be a straight forward way to use Python printing methods on SymPy Floats in an expression. I thought there should be a setting on the StrPrinter that would control display of decimals, but there doesn't seem to be one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Timo already showed this (I struggle to find his comment), but this is a way to control printing of floats:

In [47]: from mpmath.libmp import prec_to_dps, to_str as mlib_to_str

In [48]: class MyPrinter(sm.StrPrinter):
    ...:     _default_settings = {
    ...:         "order": None,
    ...:         "full_prec": "auto",
    ...:         "sympy_integers": False,
    ...:         "abbrev": False,
    ...:         "perm_cyclic": True,
    ...:         "min": None,
    ...:         "max": None,
    ...:         "dps": None,
    ...:     }
    ...: 
    ...:     def _print_Float(self, expr):
    ...:         prec = expr._prec
    ...:         if prec < 5:
    ...:             dps = 0
    ...:         else:
    ...:             dps = prec_to_dps(expr._prec)
    ...:         if self._settings['dps']:
    ...:             dps = self._settings['dps']
    ...:         if self._settings["full_prec"] is True:
    ...:             strip = False
    ...:         elif self._settings["full_prec"] is False:
    ...:             strip = True
    ...:         elif self._settings["full_prec"] == "auto":
    ...:             strip = self._print_level > 1
    ...:         low = self._settings["min"] if "min" in self._settings else None
    ...:         high = self._settings["max"] if "max" in self._settings else None
    ...:         rv = mlib_to_str(expr._mpf_, dps, strip_zeros=strip, min_fixed=low, max_fixed=high)
    ...:         if rv.startswith('-.0'):
    ...:             rv = '-0.' + rv[3:]
    ...:         elif rv.startswith('.0'):
    ...:             rv = '0.' + rv[2:]
    ...:         if rv.startswith('+'):
    ...:             # e.g., +inf -> inf
    ...:             rv = rv[1:]
    ...:         return rv
    ...: 

In [49]: MyPrinter(settings={'dps': 3}).doprint(f(1.329294))
Out[49]: 'f(1.33)'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, I found Timo's solution which was the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Timo already showed this (I struggle to find his comment), but this is a way to control printing of floats:

In [47]: from mpmath.libmp import prec_to_dps, to_str as mlib_to_str

In [48]: class MyPrinter(sm.StrPrinter):
    ...:     _default_settings = {
    ...:         "order": None,
    ...:         "full_prec": "auto",
    ...:         "sympy_integers": False,
    ...:         "abbrev": False,
    ...:         "perm_cyclic": True,
    ...:         "min": None,
    ...:         "max": None,
    ...:         "dps": None,
    ...:     }
    ...: 
    ...:     def _print_Float(self, expr):
    ...:         prec = expr._prec
    ...:         if prec < 5:
    ...:             dps = 0
    ...:         else:
    ...:             dps = prec_to_dps(expr._prec)
    ...:         if self._settings['dps']:
    ...:             dps = self._settings['dps']
    ...:         if self._settings["full_prec"] is True:
    ...:             strip = False
    ...:         elif self._settings["full_prec"] is False:
    ...:             strip = True
    ...:         elif self._settings["full_prec"] == "auto":
    ...:             strip = self._print_level > 1
    ...:         low = self._settings["min"] if "min" in self._settings else None
    ...:         high = self._settings["max"] if "max" in self._settings else None
    ...:         rv = mlib_to_str(expr._mpf_, dps, strip_zeros=strip, min_fixed=low, max_fixed=high)
    ...:         if rv.startswith('-.0'):
    ...:             rv = '-0.' + rv[3:]
    ...:         elif rv.startswith('.0'):
    ...:             rv = '0.' + rv[2:]
    ...:         if rv.startswith('+'):
    ...:             # e.g., +inf -> inf
    ...:             rv = rv[1:]
    ...:         return rv
    ...: 

In [49]: MyPrinter(settings={'dps': 3}).doprint(f(1.329294))
Out[49]: 'f(1.33)'

He did show this or something similar, but I liked the simple round(a, 3) much better! :-)


if plot_inst_viols:
axes[-1].bar(
range(len(instance_violations)), instance_violations,
tick_label=[sm.latex(s, mode='inline')
for s in self.collocator.instance_constraints])
axes[-1].set_ylabel('Instance')
axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10)
for i in range(num_plots):
num_ticks = bars_per_plot
if i == num_plots - 1:
beginn = i * bars_per_plot
endd = num_inst_viols
num_ticks = num_inst_viols % bars_per_plot
if num_inst_viols == bars_per_plot:
num_ticks = bars_per_plot
else:
endd = (i + 1) * bars_per_plot
beginn = i * bars_per_plot

inst_viol = instance_violations[beginn: endd]
inst_constr = instance_constr_plot[beginn: endd]

width = [0.06*num_ticks for _ in range(num_ticks)]
axes[i+1].bar(
range(num_ticks), inst_viol,
tick_label=[sm.latex(s, mode='inline')
for s in inst_constr], width=width)
axes[i+1].set_ylabel('Instance')
axes[i+1].set_xticklabels(axes[i+1].get_xticklabels(),
rotation=rotation)

return axes

Expand Down
Loading