-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Plot constraint violations #218
Changes from 6 commits
5c2cbef
87ad67e
ce2e61f
933368a
1a509e7
c3f1eae
53b5661
71487b1
49e73fe
a0f2ba7
458909e
7e7921c
1af48c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
#!/usr/bin/env python | ||
|
||
# | ||
from functools import wraps | ||
import logging | ||
|
||
|
@@ -422,6 +422,8 @@ def plot_trajectories(self, vector, axes=None): | |
|
||
return axes | ||
|
||
|
||
|
||
@_optional_plt_dep | ||
def plot_constraint_violations(self, vector, axes=None): | ||
"""Returns an axis with the state constraint violations plotted versus | ||
|
@@ -448,6 +450,35 @@ def plot_constraint_violations(self, vector, axes=None): | |
- s : number of unknown time intervals | ||
|
||
""" | ||
def extract_mantissa_exponent(number): | ||
"""Returns the mantissa and exponent of a number.""" | ||
mantissa, exponent = np.frexp(number) | ||
exponent *= np.log10(2) | ||
mantissa *= 10**(exponent % 1) | ||
exponent = int(exponent) | ||
return mantissa, exponent | ||
|
||
bars_per_plot = 10 | ||
rotation = -90 | ||
|
||
# ensure that len(axes) is correct, raise ValuError otherwise | ||
if axes is not None: | ||
warner = False | ||
for i in range(len(axes.ravel())): | ||
if axes.ravel()[i]._sharex is not None: | ||
warner = True | ||
if warner == True: | ||
print('Set sharex=False or remove, it makes no sense here') | ||
len_axes = len(axes.ravel()) | ||
len_constr = len(self.collocator.instance_constraints) | ||
if (len_constr <= bars_per_plot) and (len_axes < 2): | ||
raise ValueError('len(axes) must be equal to 2') | ||
elif ((len_constr > bars_per_plot) and | ||
(len_axes < len_constr // bars_per_plot + 2)): | ||
raise ValueError(f'len(axes) must be equal to {len_constr//bars_per_plot+2}') | ||
else: | ||
pass | ||
|
||
N = self.collocator.num_collocation_nodes | ||
con_violations = self.con(vector) | ||
state_violations = con_violations[ | ||
|
@@ -458,23 +489,64 @@ def plot_constraint_violations(self, vector, axes=None): | |
con_nodes = range(1, self.collocator.num_collocation_nodes) | ||
|
||
plot_inst_viols = self.collocator.instance_constraints is not None | ||
num_inst_viols = len(instance_violations) | ||
if num_inst_viols == bars_per_plot: | ||
num_plots = 1 | ||
else: | ||
num_plots = num_inst_viols // bars_per_plot + 1 | ||
|
||
if axes is None: | ||
fig, axes = plt.subplots(1 + plot_inst_viols, squeeze=False, | ||
layout='compressed') | ||
fig, axes = plt.subplots(1 + num_plots, squeeze=False, | ||
layout='compressed', figsize=(8, 2.0*(num_plots+1))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is better to leave the fig size for the user to create manually. Setting figsize inside a function like this can result in giant figures. |
||
|
||
axes = axes.ravel() | ||
|
||
axes[0].plot(con_nodes, state_violations.T) | ||
axes[0].set_title('Constraint violations') | ||
axes[0].set_xlabel('Node Number') | ||
axes[0].set_ylabel('EoM violation') | ||
|
||
# reduce the instance constrtaints to 2 significant digits and print | ||
# in exponential form | ||
instance_constr_plot = [] | ||
for exp1 in self.collocator.instance_constraints: | ||
for a in sm.preorder_traversal(exp1): | ||
if isinstance(a, sm.Float): | ||
value = float(a) | ||
mantissa, exponent = extract_mantissa_exponent(value) | ||
mantissa = round(mantissa, 2) | ||
|
||
if exponent != 0: | ||
sympy_value = sm.Symbol(f'{mantissa} \cdot 10^{exponent}') | ||
else: | ||
sympy_value = sm.Symbol(f'{mantissa}') | ||
exp1 = exp1.subs(a, sympy_value) | ||
instance_constr_plot.append(exp1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Python has printing controls that work like this:
is there a reason to not use something like that instead of literally rounding the numerical values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Timo suggested this to me, so I used it, since it worked. I had no idea how to do such things. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I see now there may not be a straight forward way to use Python printing methods on SymPy Floats in an expression. I thought there should be a setting on the StrPrinter that would control display of decimals, but there doesn't seem to be one. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe Timo already showed this (I struggle to find his comment), but this is a way to control printing of floats:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hah, I found Timo's solution which was the same. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
He did show this or something similar, but I liked the simple round(a, 3) much better! :-) |
||
|
||
if plot_inst_viols: | ||
axes[-1].bar( | ||
range(len(instance_violations)), instance_violations, | ||
tick_label=[sm.latex(s, mode='inline') | ||
for s in self.collocator.instance_constraints]) | ||
axes[-1].set_ylabel('Instance') | ||
axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10) | ||
for i in range(num_plots): | ||
num_ticks = bars_per_plot | ||
if i == num_plots - 1: | ||
beginn = i * bars_per_plot | ||
endd = num_inst_viols | ||
num_ticks = num_inst_viols % bars_per_plot | ||
if num_inst_viols == bars_per_plot: | ||
num_ticks = bars_per_plot | ||
else: | ||
endd = (i + 1) * bars_per_plot | ||
beginn = i * bars_per_plot | ||
|
||
inst_viol = instance_violations[beginn: endd] | ||
inst_constr = instance_constr_plot[beginn: endd] | ||
|
||
width = [0.06*num_ticks for _ in range(num_ticks)] | ||
axes[i+1].bar( | ||
range(num_ticks), inst_viol, | ||
tick_label=[sm.latex(s, mode='inline') | ||
for s in inst_constr], width=width) | ||
axes[i+1].set_ylabel('Instance') | ||
axes[i+1].set_xticklabels(axes[i+1].get_xticklabels(), | ||
rotation=rotation) | ||
|
||
return axes | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We typically don't have print statements in library code because their display can't be controlled. Best to avoid warnings and just raise errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't think we need to warn about things like this. The best approach here is to explain in the docstring what the user must do if they want to pass in their own axes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If sharex = True, the program will not interrupt, but the plots are meaningless. Do you mean, I just let this happen, and explain in the method, that it should be set to False or removed? Should this be mentioned in the opty Documentation?
If YES, I have to find out how to do this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but I don't think you need to explain anything about sharex.
plt.subplots()
has a very long list of arguments and it is not our job to explain them all here. You simply have to state to the user the minimal information they need to know to pass in a correct axes object.