Skip to content

Commit

Permalink
multi envelope
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed May 30, 2024
1 parent b015027 commit ae6f310
Show file tree
Hide file tree
Showing 22 changed files with 159 additions and 97 deletions.
2 changes: 1 addition & 1 deletion chalk/backend/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def visit_path(
style.fill_opacity_ = 0
for loc_trail in path.loc_trails:
p = loc_trail.location
ctx.move_to(p[0, 0, 0], p[0, 1,0])
#ctx.move_to(p[0, 0, 0], p[0, 1,0])
segments = loc_trail.located_segments()
self.render_segment(segments, ctx)
if loc_trail.trail.closed:
Expand Down
2 changes: 1 addition & 1 deletion chalk/backend/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def visit_apply_name(
self, diagram: ApplyName, style: Style = EMPTY_STYLE
) -> BaseElement:
g = self.dwg.g()
g.add(diagram.diagram.accept(self, style))
g.add(diagram.diagram.accept(self, style).data)
return Maybe(g)


Expand Down
6 changes: 6 additions & 0 deletions chalk/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
def with_envelope(self: Diagram, other: Diagram) -> Diagram:
return self.compose(other.get_envelope())

def close_envelope(self: Diagram) -> Diagram:
from chalk.core import Primitive

env = self.get_envelope()
return self.compose(Envelope.from_bounding_box(env.to_bounding_box()))


# with_trace, phantom,

Expand Down
8 changes: 7 additions & 1 deletion chalk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ def _style(self, style: Style) -> Diagram:
def compose(
self, envelope: Envelope, other: Optional[Diagram] = None
) -> Diagram:
other = other if other is not None else Empty()
if other is None and isinstance(self, Compose):
return Compose(envelope, self.diagrams)
if other is None and isinstance(self, Compose):
return Compose(envelope, [self])

other = other if other is not None else Empty()
if isinstance(self, Empty):
return other
if isinstance(other, Empty):
Expand All @@ -90,6 +95,7 @@ def named(self, name: Name) -> Diagram:

# Combinators
with_envelope = chalk.combinators.with_envelope
close_envelope = chalk.combinators.close_envelope
juxtapose = chalk.combinators.juxtapose
juxtapose_snug = chalk.combinators.juxtapose_snug
beside_snug = chalk.combinators.beside_snug
Expand Down
54 changes: 36 additions & 18 deletions chalk/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
Affine,
BoundingBox,
Transformable,
Scalar
Scalars
)
import jax.numpy as np
import chalk.transform as tx
from chalk.visitor import DiagramVisitor

Expand All @@ -22,52 +21,66 @@
from chalk.types import Diagram


SignedDistance = tx.Scalar
# quantize = tx.np.linspace(-100, 100, 1000)
# mult = tx.np.array([1000, 1, 0])[None]


class Envelope(Transformable, Monoid):
total_env = 0
def __init__(
self, f: Callable[[V2_t], SignedDistance], is_empty: bool = False
self, f: Callable[[V2_t], Scalars], is_empty: bool = False
):
self.f = f
self.is_empty = is_empty
self.cache = {}

def __call__(self, direction: V2_t) -> SignedDistance:
def __call__(self, direction: V2_t) -> Scalars:
Envelope.total_env += 1
assert not self.is_empty
return self.f(direction)
# v = (mult @ tx.np.digitize(direction, quantize)).reshape(1)[0]
# if v not in self.cache:
# self.cache[v] = self.f(direction)
# return self.cache[v]


# Monoid
@staticmethod
def empty() -> Envelope:
return Envelope(lambda v: np.array(0.0), is_empty=True)
return Envelope(lambda v: tx.np.array(0.0), is_empty=True)

def __add__(self, other: Envelope) -> Envelope:
if self.is_empty:
return other
if other.is_empty:
return self
return Envelope(
lambda direction: np.maximum(self(direction), other(direction))
lambda direction: tx.np.maximum(self(direction), other(direction))
)

all_dir = tx.np.concatenate([tx.unit_x, -tx.unit_x, tx.unit_y, -tx.unit_y], axis=0)
@property
def center(self) -> P2_t:
if self.is_empty:
return tx.origin
# Get all the directions
d = self(Envelope.all_dir)
return P2(
(-self(-tx.unit_x) + self(tx.unit_x)) / 2,
(-self(-tx.unit_y) + self(tx.unit_y)) / 2,
(-d[1] + d[0]) / 2,
(-d[3] + d[2]) / 2,
)

@property
def width(self) -> Scalar:
assert not self.is_empty
return (self(tx.unit_x) + self(-tx.unit_x)).reshape(())
d = self(Envelope.all_dir[:2])
return d.sum()

@property
def height(self) -> Scalar:
assert not self.is_empty
return (self(tx.unit_y) + self(-tx.unit_y)).reshape(())
d = self(Envelope.all_dir[2:])
return d.sum()

def apply_transform(self, t: Affine) -> Envelope:
if self.is_empty:
Expand All @@ -76,18 +89,17 @@ def apply_transform(self, t: Affine) -> Envelope:
inv_t = tx.inv(rt)
trans_t = tx.transpose_translation(rt)
u: V2_t = -tx.get_translation(t)
def wrapped(v: V2_t) -> SignedDistance:
def wrapped(v: V2_t) -> Scalars:
# Linear
vi = inv_t @ v
v_prim = tx.norm(trans_t @ v)

inner = self(v_prim)
d = tx.dot(v_prim, vi)
after_linear = inner / d

# Translation
diff = tx.dot((u / tx.dot(v, v)), v)
return (after_linear - diff).reshape(())
diff = tx.dot((u / tx.dot(v, v)[..., None, None]), v)
return after_linear - diff

return Envelope(wrapped)

Expand All @@ -100,12 +112,18 @@ def envelope_v(self, v: V2_t) -> V2_t:

@staticmethod
def from_bounding_box(box: BoundingBox) -> Envelope:
def wrapped(d: tx.V2_t) -> SignedDistance:
v = box.rotate_rad(tx.rad(d)).br[0]
return v / tx.length(d)
def wrapped(d: tx.V2_t) -> Scalars:
v = box.rotate_rad(tx.rad(d)).br[:, 0, 0]
r = v / tx.length(d)
return r

return Envelope(wrapped)

def to_bounding_box(self: Envelope) -> BoundingBox:
d = self(Envelope.all_dir)
return tx.BoundingBox(V2(-d[1], -d[3]), V2(d[0], d[2]))


# @staticmethod
# def from_circle(radius: tx.Floating) -> Envelope:
# def wrapped(d: V2_t) -> SignedDistance:
Expand Down
24 changes: 12 additions & 12 deletions chalk/shapes/arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def center(self) -> P2_t:
return self.t @ tx.P2(0, 0)

def seg(offset: V2_t) -> Trail:
return arc_seg(offset, 1e-3)
return arc_seg(offset, 1e-4)

def is_in_mod_360(x: Degrees, a: Degrees, b: Degrees) -> tx.Mask:
"""Checks if x ∈ [a, b] mod 360. See the following link for an
Expand Down Expand Up @@ -111,27 +111,27 @@ def arc_between(
ret = tx.translation(p) @ tx.rotation(-tx.rad(diff)) @ tx.translation(tx.V2(d / 2, dy)) @ tx.rotation(φ) @ tx.scale(tx.V2(r, r))
return Segment(ret, angles)

def arc_envelope(angle_offset: Float[Array, "#B 2"]) -> Callable[[V2_t], tx.Scalars]:
def arc_envelope(angle_offset: Float[Array, "#B 2"]):
"Trace is done as simple arc and transformed"
angle0_deg = angle_offset[..., 0]
angle1_deg = angle0_deg + angle_offset[..., 1]

is_circle = abs(angle0_deg - angle1_deg) >= 360
low = tx.np.minimum(angle0_deg, angle1_deg)
high = tx.np.maximum(angle0_deg, angle1_deg)
check = (low - high) % 360

v1 = tx.polar(angle0_deg)
v2 = tx.polar(angle1_deg)
def wrapped(d: V2_t) -> tx.Scalars:
is_circle = abs(angle0_deg - angle1_deg) >= 360
q = tx.np.where(
is_circle | is_in_mod_360(
tx.angle(d),
tx.np.minimum(angle0_deg, angle1_deg),
tx.np.maximum(angle0_deg, angle1_deg),
),

def wrapped(d):
return tx.np.where(
(is_circle | (((tx.angle_(d) - high) % 360) > check)),
# Case 1: P2 at arc
1 / tx.length(d),
1 / tx.length_(d),
# Case 2: P2 outside of arc
tx.np.maximum(tx.dot(d, v1), tx.dot(d, v2))
)
return q
return wrapped

def arc_seg(q: V2_t, height: tx.Floating) -> Trail:
Expand Down
27 changes: 15 additions & 12 deletions chalk/trail.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
Transformable,
Floating
)
import jax
import jax.numpy as np
from chalk.types import Diagram, Enveloped, Traceable, TrailLike


Expand Down Expand Up @@ -46,19 +44,22 @@ def get_envelope(self) -> Envelope:
inv_t = tx.inv(rt)
trans_t = tx.transpose_translation(rt)
u: V2_t = -tx.get_translation(t)
def wrapped(v: V2_t):
def wrapped(v: V2_t) -> Scalars:
# Linear
v = v[:, None, :, :]

vi = inv_t @ v
v_prim = tx.norm(trans_t @ v)
inp = (trans_t @ v)
v_prim = tx.norm_(inp)
inner = env(v_prim)
d = tx.dot(v_prim, vi)
after_linear = inner / d


# Translation
diff = tx.dot((u / tx.dot(v, v)), v)
diff = tx.dot((u / tx.dot(v, v)[..., None, None]), v)
out = after_linear - diff
return tx.np.max(out, axis=0)
return tx.np.max(out, axis=1)

return Envelope(wrapped)

Expand Down Expand Up @@ -92,7 +93,7 @@ class Trail(Monoid, Transformable, TrailLike):
# Monoid
@staticmethod
def empty() -> Trail:
return Trail(Segment(np.array([]), np.array([])), False)
return Trail(Segment(tx.np.array([]), tx.np.array([])), False)

def __add__(self, other: Trail) -> Trail:
assert not (self.closed or other.closed), "Cannot add closed trails"
Expand All @@ -115,7 +116,7 @@ def close(self) -> Trail:

def points(self) -> Float[Array, "B 3"]:
q = self.segments.q
return (np.cumsum(q, axis=0) - q).at[..., 2, 0].set(1)
return tx.to_point(tx.np.cumsum(q, axis=0) - q)

def at(self, p: P2_t) -> Located:
return Located(self, p)
Expand All @@ -131,10 +132,12 @@ def centered(self) -> Located:
return self.at(-sum(self.points(), tx.P2(0, 0)) / self.segments.t.shape[0])

# # Misc. Constructor
# @staticmethod
# def from_offsets(offsets: List[V2_t], closed: bool = False) -> Trail:
# return Trail([Segment(off) for off in offsets], closed)

@staticmethod
def from_offsets(offsets: List[V2_t], closed: bool = False) -> Trail:
trail = Trail.concat([arc.seg(off) for off in offsets])
if closed:
trail = trail.close()
return trail

@staticmethod
def hrule(length: Floating) -> Trail:
Expand Down
29 changes: 21 additions & 8 deletions chalk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
else:
from jax import ops
#import jax.numpy as np
import jax.numpy as np


Affine = Float[Array, "#B 3 3"]
Expand All @@ -21,7 +21,7 @@
Floating = Union[Scalars, Scalar, float, int]
Mask = Bool[Array, "#B"]
def ftos(f: Floating) -> Scalars:
return np.array(f, dtype=float).reshape(-1)
return np.array(f, dtype=np.double).reshape(-1)

def index_update(arr, index, values):
"""
Expand All @@ -36,7 +36,7 @@ def index_update(arr, index, values):
return new_arr
else:
# If the array is a JAX array
return ops.index_update(arr, index, values)
return arr.at[index].set(values)


def V2(x: Floating, y: Floating) -> V2_t:
Expand All @@ -56,12 +56,23 @@ def norm(v: V2_t) -> V2_t:
def length(v: V2_t) -> Scalars:
return np.sqrt(length2(v))

# These are untyped variants that are needed for multibatch cases.
def length_(v):
return length(v.reshape(-1, 3, 1)).reshape(*v.shape[:2])

def angle_(v):
return angle(v.reshape(-1, 3, 1)).reshape(*v.shape[:2])

def norm_(v):
return norm(v.reshape(-1, 3, 1)).reshape(*v.shape)

def length2(v: V2_t) -> Scalars:
return (v * v)[..., :2, 0].sum(-1)

def angle(v: V2_t) -> Scalars:
return from_radians(rad(v))


def rad(v: V2_t) -> Scalars:
return np.arctan2(v[..., 1, 0], v[..., 0, 0])

Expand All @@ -78,13 +89,13 @@ def make_affine(a: Floating, b:Floating, c:Floating, d:Floating, e:Floating, f:F

ident = make_affine(1., 0., 0., 0., 1., 0.)

def dot(v1: V2_t, v2: V2_t) -> Scalars:
def dot(v1, v2):
return (v1 * v2).sum(-1).sum(-1)

def cross(v1: V2_t, v2: V2_t) -> Scalars:
return np.cross(v1, v2)

def to_point(v: V2_t) -> P2_t:
def to_point(v: Float[Array, "... 1"]) -> P2_t:
index = (Ellipsis, 2, 0)
return index_update(v, index, 1)

Expand All @@ -109,10 +120,12 @@ def get_translation(aff: Affine) -> V2_t:
aff[..., :2, 2])

def rotation(rad: Floating) -> Affine:
rad = ftos(rad)
ca, sa = np.cos(rad), np.sin(rad)
up = np.stack([ca, sa, -sa, ca], axis=-1).reshape(-1, 2, 2)
base = ident.repeat(rad.shape[0], axis=0)
index = (Ellipsis, slice(0, 2), slice(0, 2))
return index_update(ident, index, up)
return index_update(base, index, up)

def inv(aff: Affine) -> Affine:
det = np.linalg.det(aff)
Expand Down Expand Up @@ -213,11 +226,11 @@ def apply_transform(self, t: Affine) -> Self:

@property
def width(self) -> Scalar:
return (self.br - self.tl)[0]
return (self.br - self.tl)[0, 0, 0]

@property
def height(self) -> Scalar:
return (self.br - self.tl)[1]
return (self.br - self.tl)[0, 1, 0]

origin = P2(0, 0)
unit_x = V2(1, 0)
Expand Down
Loading

0 comments on commit ae6f310

Please sign in to comment.