From 45fd4ab2f512e52bce27fe8715d287ba5973c9af Mon Sep 17 00:00:00 2001 From: Tobia Marcucci Date: Tue, 3 Sep 2024 11:21:08 -0700 Subject: [PATCH] Cleans composite curve. --- pybezier/composite_bezier_curve.py | 109 ++++++++++++++++------------- 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/pybezier/composite_bezier_curve.py b/pybezier/composite_bezier_curve.py index 0bb74da..74032ad 100644 --- a/pybezier/composite_bezier_curve.py +++ b/pybezier/composite_bezier_curve.py @@ -1,69 +1,80 @@ import numpy as np +from typing import List, Self from pybezier.bezier_curve import BezierCurve class CompositeBezierCurve: - def __init__(self, beziers): - for curve1, curve2 in zip(beziers[:-1], beziers[1:]): - assert np.isclose(curve1.b, curve2.a) - assert curve1.dimension == curve2.dimension - self.beziers = beziers - self.N = len(self.beziers) - self.dimension = beziers[0].dimension - self.a = beziers[0].a - self.b = beziers[-1].b - self.duration = self.b - self.a - self.transition_times = [self.a] + [bez.b for bez in beziers] + def __init__(self, curves): + initial_times = [curve.initial_time for curve in curves[1:]] + final_times = [curve.final_time for curve in curves[:-1]] + if not np.allclose(initial_times, final_times): + raise ValueError("Initial and final times don't match.") + dimensions = [curve.dimension for curve in curves] + if len(set(dimensions)) != 1: + raise ValueError("All the curves must have the same dimension.") + self.curves = curves + self.dimension = curves[0].dimension + self.initial_time = curves[0].initial_time + self.final_time = curves[-1].final_time + self.duration = self.final_time - self.initial_time + self.knot_times = [self.initial_time] + [curve.final_time for curve in curves] - def __iter__(self): - return iter(self.beziers) + def __iter__(self) -> List[BezierCurve]: + return self.curves - def __getitem__(self, i): - return self.beziers[i] + def __getitem__(self, i : int) -> BezierCurve: + return self.curves[i] - def __call__(self, t): - i = self.find_segment(t) - return self[i](t) + def __call__(self, time : float) -> np.array: + segment = self.find_segment(time) + return self[segment](time) - def __len__(self): - return(len(self.beziers)) + def __len__(self) -> int: + return(len(self.curves)) - def start_point(self): - return self[0].start_point() + def initial_point(self) -> np.array: + return self[0].initial_point() - def end_point(self): - return self[-1].end_point() + def final_point(self) -> np.array: + return self[-1].final_point() - def knot_points(self): - knots = [bez.points[0] for bez in self] - return np.array(knots + [self[-1].points[-1]]) + def knot_points(self) -> np.array: + knots = [curve.points[0] for curve in self] + knots.append(self.final_point()) + return np.array(knots) - def durations(self): - return np.array([bez.duration for bez in self]) + def durations(self) -> np.array: + return np.array([curve.duration for curve in self]) - def concatenate(self, curve): - shifted_beziers = [BezierCurve(b.points, b.a + self.duration, b.b + self.duration) for b in curve] - return CompositeBezierCurve(self.beziers + shifted_beziers) + def concatenate(self, composite_curve : Self) -> Self: + shifted_curves = [] + for curve in composite_curve: + initial_time = curve.initial_time + self.duration + final_time = curve.final_time + self.duration + shifted_curve = BezierCurve(curve.points, initial_time, final_time) + shifted_curves.append(shifted_curve) + return CompositeBezierCurve(self.curves + shifted_curves) - def derivative(self): - return CompositeBezierCurve([b.derivative() for b in self]) + def derivative(self) -> Self: + return CompositeBezierCurve([curve.derivative() for curve in self]) - def l2_squared(self): - return sum(bez.l2_squared() for bez in self) + def l2_squared(self) -> float: + return sum(curve.l2_squared() for curve in self) - def plot_components(self, samples=51, **kwargs): - for i, bez in enumerate(self): - legend = True if i ==0 else False - bez.plot_components(samples, legend, **kwargs) + def plot_components(self, n : int = 51, legend : bool = True, **kwargs): + for i, curve in enumerate(self): + curve.plot_components(n, legend, **kwargs) + if i == 0: + legend = False - def plot2d(self, **kwargs): - for bez in self: - bez.plot2d(**kwargs) + def plot_trace_2d(self, **kwargs): + for curve in self: + curve.plot_trace_2d(**kwargs) - def scatter2d(self, **kwargs): - for bez in self: - bez.scatter2d(**kwargs) + def scatter_points_2d(self, **kwargs): + for curve in self: + curve.scatter_points_2d(**kwargs) - def plot_2dpolygons(self, **kwargs): - for bez in self: - bez.plot_2dpolygon(**kwargs) + def plot_control_polytopes_2d(self, **kwargs): + for curve in self: + curve.plot_control_polytope_2d(**kwargs)