Skip to content

Commit

Permalink
Cleans composite curve.
Browse files Browse the repository at this point in the history
  • Loading branch information
TobiaMarcucci committed Sep 3, 2024
1 parent 65400b4 commit 45fd4ab
Showing 1 changed file with 60 additions and 49 deletions.
109 changes: 60 additions & 49 deletions pybezier/composite_bezier_curve.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,80 @@
import numpy as np
from typing import List, Self
from pybezier.bezier_curve import BezierCurve

class CompositeBezierCurve:

def __init__(self, beziers):
for curve1, curve2 in zip(beziers[:-1], beziers[1:]):
assert np.isclose(curve1.b, curve2.a)
assert curve1.dimension == curve2.dimension
self.beziers = beziers
self.N = len(self.beziers)
self.dimension = beziers[0].dimension
self.a = beziers[0].a
self.b = beziers[-1].b
self.duration = self.b - self.a
self.transition_times = [self.a] + [bez.b for bez in beziers]
def __init__(self, curves):
initial_times = [curve.initial_time for curve in curves[1:]]
final_times = [curve.final_time for curve in curves[:-1]]
if not np.allclose(initial_times, final_times):
raise ValueError("Initial and final times don't match.")
dimensions = [curve.dimension for curve in curves]
if len(set(dimensions)) != 1:
raise ValueError("All the curves must have the same dimension.")
self.curves = curves
self.dimension = curves[0].dimension
self.initial_time = curves[0].initial_time
self.final_time = curves[-1].final_time
self.duration = self.final_time - self.initial_time
self.knot_times = [self.initial_time] + [curve.final_time for curve in curves]

def __iter__(self):
return iter(self.beziers)
def __iter__(self) -> List[BezierCurve]:
return self.curves

def __getitem__(self, i):
return self.beziers[i]
def __getitem__(self, i : int) -> BezierCurve:
return self.curves[i]

def __call__(self, t):
i = self.find_segment(t)
return self[i](t)
def __call__(self, time : float) -> np.array:
segment = self.find_segment(time)
return self[segment](time)

def __len__(self):
return(len(self.beziers))
def __len__(self) -> int:
return(len(self.curves))

def start_point(self):
return self[0].start_point()
def initial_point(self) -> np.array:
return self[0].initial_point()

def end_point(self):
return self[-1].end_point()
def final_point(self) -> np.array:
return self[-1].final_point()

def knot_points(self):
knots = [bez.points[0] for bez in self]
return np.array(knots + [self[-1].points[-1]])
def knot_points(self) -> np.array:
knots = [curve.points[0] for curve in self]
knots.append(self.final_point())
return np.array(knots)

def durations(self):
return np.array([bez.duration for bez in self])
def durations(self) -> np.array:
return np.array([curve.duration for curve in self])

def concatenate(self, curve):
shifted_beziers = [BezierCurve(b.points, b.a + self.duration, b.b + self.duration) for b in curve]
return CompositeBezierCurve(self.beziers + shifted_beziers)
def concatenate(self, composite_curve : Self) -> Self:
shifted_curves = []
for curve in composite_curve:
initial_time = curve.initial_time + self.duration
final_time = curve.final_time + self.duration
shifted_curve = BezierCurve(curve.points, initial_time, final_time)
shifted_curves.append(shifted_curve)
return CompositeBezierCurve(self.curves + shifted_curves)

def derivative(self):
return CompositeBezierCurve([b.derivative() for b in self])
def derivative(self) -> Self:
return CompositeBezierCurve([curve.derivative() for curve in self])

def l2_squared(self):
return sum(bez.l2_squared() for bez in self)
def l2_squared(self) -> float:
return sum(curve.l2_squared() for curve in self)

def plot_components(self, samples=51, **kwargs):
for i, bez in enumerate(self):
legend = True if i ==0 else False
bez.plot_components(samples, legend, **kwargs)
def plot_components(self, n : int = 51, legend : bool = True, **kwargs):
for i, curve in enumerate(self):
curve.plot_components(n, legend, **kwargs)
if i == 0:
legend = False

def plot2d(self, **kwargs):
for bez in self:
bez.plot2d(**kwargs)
def plot_trace_2d(self, **kwargs):
for curve in self:
curve.plot_trace_2d(**kwargs)

def scatter2d(self, **kwargs):
for bez in self:
bez.scatter2d(**kwargs)
def scatter_points_2d(self, **kwargs):
for curve in self:
curve.scatter_points_2d(**kwargs)

def plot_2dpolygons(self, **kwargs):
for bez in self:
bez.plot_2dpolygon(**kwargs)
def plot_control_polytopes_2d(self, **kwargs):
for curve in self:
curve.plot_control_polytope_2d(**kwargs)

0 comments on commit 45fd4ab

Please sign in to comment.