Skip to content

Commit

Permalink
fix to work with pure numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed May 29, 2024
1 parent 4175891 commit b015027
Show file tree
Hide file tree
Showing 17 changed files with 113 additions and 64 deletions.
20 changes: 11 additions & 9 deletions chalk/backend/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,17 @@ def render_cairo_prims(
shape_renderer = ToCairoShape()
for prim in base.accept(ToList(), Ident):
# apply transformation
matrix = tx_to_cairo(prim.transform)
ctx.transform(matrix)
prim.shape.accept(shape_renderer, ctx=ctx, style=prim.style)

# undo transformation
matrix.invert()
ctx.transform(matrix)
prim.style.render(ctx)
ctx.stroke()
for i in range(prim.transform.shape[0]):
matrix = tx_to_cairo(prim.transform[i:i+1])
ctx.transform(matrix)
prim.shape.accept(shape_renderer, ctx=ctx, style=prim.style)

# undo transformation
matrix.invert()
ctx.transform(matrix)

prim.style.render(ctx)
ctx.stroke()


def render(
Expand Down
27 changes: 17 additions & 10 deletions chalk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ def compose(
self, envelope: Envelope, other: Optional[Diagram] = None
) -> Diagram:
other = other if other is not None else Empty()
if isinstance(self, Empty):
return other
if isinstance(other, Empty):
return self
if isinstance(self, Compose) and isinstance(other, Compose):
return Compose(envelope, self.diagrams + other.diagrams)
elif isinstance(self, Compose):
if isinstance(self, Compose) and isinstance(other, Compose):
return Compose(envelope, self.diagrams + [other])
elif isinstance(other, Compose):
return Compose(envelope, [self] + other.diagrams)
Expand Down Expand Up @@ -211,14 +215,6 @@ def from_shape(cls, shape: Shape) -> Primitive:
return cls(shape, Style.empty(), Ident)

def apply_transform(self, t: Affine) -> Primitive:
"""Applies a transform and returns a primitive.
Args:
t (Transform): A transform object.
Returns:
Primitive
"""
new_transform = t @ self.transform
return Primitive(self.shape, self.style, new_transform)

Expand Down Expand Up @@ -246,6 +242,11 @@ class Empty(BaseDiagram):
def accept(self, visitor: DiagramVisitor[A, Any], args: Any) -> A:
return visitor.visit_empty(self, args)

def apply_transform(self, t: Affine) -> Empty:
return Empty()

def apply_style(self, t: Affine) -> Empty:
return Empty()

@dataclass
class Compose(BaseDiagram):
Expand All @@ -268,7 +269,10 @@ class ApplyTransform(BaseDiagram):
def accept(self, visitor: DiagramVisitor[A, Any], args: Any) -> A:
return visitor.visit_apply_transform(self, args)


def apply_transform(self, t: Affine) -> ApplyTransform:
new_transform = t @ self.transform
return ApplyTransform(new_transform, self.diagram)

@dataclass
class ApplyStyle(BaseDiagram):
"""ApplyStyle class."""
Expand All @@ -279,6 +283,9 @@ class ApplyStyle(BaseDiagram):
def accept(self, visitor: DiagramVisitor[A, Any], args: Any) -> A:
return visitor.visit_apply_style(self, args)

def apply_style(self, style: Style) -> ApplyStyle:
new_style = style.merge(self.style)
return ApplyStyle(new_style, self.diagram)

@dataclass
class ApplyName(BaseDiagram):
Expand Down
6 changes: 3 additions & 3 deletions chalk/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def center(self) -> P2_t:
@property
def width(self) -> Scalar:
assert not self.is_empty
return (self(tx.unit_x) + self(-tx.unit_x)).reshape()
return (self(tx.unit_x) + self(-tx.unit_x)).reshape(())

@property
def height(self) -> Scalar:
assert not self.is_empty
return (self(tx.unit_y) + self(-tx.unit_y)).reshape()
return (self(tx.unit_y) + self(-tx.unit_y)).reshape(())

def apply_transform(self, t: Affine) -> Envelope:
if self.is_empty:
Expand All @@ -87,7 +87,7 @@ def wrapped(v: V2_t) -> SignedDistance:

# Translation
diff = tx.dot((u / tx.dot(v, v)), v)
return (after_linear - diff).reshape()
return (after_linear - diff).reshape(())

return Envelope(wrapped)

Expand Down
6 changes: 3 additions & 3 deletions chalk/shapes/arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def apply_transform(self, t: Affine) -> Segment:
return Segment(t @ self.transform, self.angles)

def __add__(self, other: Segment) -> Segment:
return Segment(tx.np.concat([self.transform, other.transform], axis=0),
tx.np.concat([self.angles, other.angles], axis=0))
return Segment(tx.np.concatenate([self.transform, other.transform], axis=0),
tx.np.concatenate([self.angles, other.angles], axis=0))

@property
def t(self) -> Affine:
Expand Down Expand Up @@ -91,7 +91,7 @@ def arc_between(
# return LocatedSegment(q - p, p)
d = tx.length(q - p)
# Determine the arc's angle θ and its radius r
θ = tx.np.acos((d**2 - 4.0 * h**2) / (d**2 + 4.0 * h**2))
θ = tx.np.arccos((d**2 - 4.0 * h**2) / (d**2 + 4.0 * h**2))
r = d / (2 * tx.np.sin(θ))

if height > 0:
Expand Down
42 changes: 26 additions & 16 deletions chalk/trail.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,13 @@ def hrule(length: Floating) -> Trail:
def vrule(length: Floating) -> Trail:
return arc.seg(length * tx.unit_y)

_rectangle = None
@staticmethod
def rectangle(width: Floating, height: Floating) -> Trail:
t = arc.seg(tx.unit_x * width) + arc.seg(tx.unit_y * height)
return (t + t.rotate_by(0.5)).close()
if Trail._rectangle is None:
t = arc.seg(tx.unit_x) + arc.seg(tx.unit_y)
Trail._rectangle = (t + t.rotate_by(0.5)).close()
return Trail._rectangle.scale_x(width).scale_y(height)

@staticmethod
def rounded_rectangle(width: Floating, height: Floating, radius: Floating) -> Trail:
Expand All @@ -161,28 +164,35 @@ def rounded_rectangle(width: Floating, height: Floating, radius: Floating) -> Tr
) + arc.seg(0.01 * tx.unit_y)
return trail.close()

_circle = {}
@staticmethod
def circle(radius: Floating = 1.0, clockwise: bool = True) -> Trail:
sides = 4
dangle = -90
rotate_by = 1
if not clockwise:
dangle = 90
rotate_by *= -1
return (
Trail.concat(
if clockwise in Trail._circle:
return Trail._circle[clockwise]
else:
sides = 4
dangle = -90
rotate_by = 1
if not clockwise:
dangle = 90
rotate_by *= -1
Trail._circle[clockwise] = Trail.concat(
[
arc.arc_seg_angle(0, dangle).rotate_by(rotate_by * i / sides)
for i in range(sides)
]
)
.close()
).close()
return (
Trail._circle[clockwise]
.scale(radius)
)

_polygon = {}
@staticmethod
def regular_polygon(sides: int, side_length: Floating) -> Trail:
edge = Trail.hrule(side_length)
return Trail.concat(
edge.rotate_by(i / sides) for i in range(sides)
).close()
if sides not in Trail._polygon:
edge = Trail.hrule(1)
Trail._polygon[sides] = Trail.concat(
edge.rotate_by(i / sides) for i in range(sides)
).close()
return Trail._polygon[sides].scale(side_length)
55 changes: 43 additions & 12 deletions chalk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from dataclasses import dataclass

from jaxtyping import Float, Bool, Array
# import numpy as np
import jax.numpy as np
if True:
ops = None
import numpy as np
else:
from jax import ops
#import jax.numpy as np


Affine = Float[Array, "#B 3 3"]
Expand All @@ -19,6 +23,21 @@
def ftos(f: Floating) -> Scalars:
return np.array(f, dtype=float).reshape(-1)

def index_update(arr, index, values):
"""
Update the array `arr` at the given `index` with `values` and return the updated array.
Supports both NumPy and JAX arrays.
"""

if ops is None:
# If the array is a NumPy array
new_arr = arr.copy()
new_arr[index] = values
return new_arr
else:
# If the array is a JAX array
return ops.index_update(arr, index, values)


def V2(x: Floating, y: Floating) -> V2_t:
x, y, o = ftos(x), ftos(y), ftos(0.)
Expand All @@ -44,7 +63,7 @@ def angle(v: V2_t) -> Scalars:
return from_radians(rad(v))

def rad(v: V2_t) -> Scalars:
return np.atan2(v[..., 1, 0], v[..., 0, 0])
return np.arctan2(v[..., 1, 0], v[..., 0, 0])

def perpendicular(v: V2_t) -> V2_t:
return np.hstack([-v[..., 1, 0], v[..., 0, 0], v[..., 2, 0]])
Expand All @@ -66,31 +85,38 @@ def cross(v1: V2_t, v2: V2_t) -> Scalars:
return np.cross(v1, v2)

def to_point(v: V2_t) -> P2_t:
return v.at[..., 2, 0].set(1)
index = (Ellipsis, 2, 0)
return index_update(v, index, 1)

def polar(angle: Floating, length: Floating = 1.0) -> V2_t:
rad = to_radians(angle)
x, y = np.cos(rad), np.sin(rad)
return V2(x * length, y * length)

def scale(vec: V2_t) -> Affine:
return ident.at[..., np.arange(2), np.arange(2)].set(vec[..., :2, 0])
index = (Ellipsis, np.arange(0, 2), np.arange(0, 2))
return index_update(ident, index, vec[..., :2, 0])

def translation(vec: V2_t) -> Affine:
return ident.repeat(vec.shape[0], axis=0).at[..., :2, 2].set(vec[..., :2, 0])
index = (Ellipsis, slice(0, 2), 2)
base = ident.repeat(vec.shape[0], axis=0)
return index_update(base, index, vec[..., :2, 0])
#return .at[..., :2, 2].set(vec[..., :2, 0])

def get_translation(aff: Affine) -> V2_t:
return np.zeros([aff.shape[0], 3, 1]).at[..., :2, 0].set(aff[..., :2, 2])
index = (Ellipsis, slice(0, 2), 0)
return index_update(np.zeros([aff.shape[0], 3, 1]), index,
aff[..., :2, 2])

def rotation(rad: Floating) -> Affine:
ca, sa = np.cos(rad), np.sin(rad)
up = np.stack([ca, sa, -sa, ca], axis=-1).reshape(-1, 2, 2)
return ident.at[..., :2, :2].set(up)
index = (Ellipsis, slice(0, 2), slice(0, 2))
return index_update(ident, index, up)

def inv(aff: Affine) -> Affine:
det = np.linalg.det(aff)
assert np.all(np.abs(det) > 1e-5), f"{det} {aff}"
print(aff.shape)
idet = 1.0 / det
sa, sb, sc = aff[..., 0, 0], aff[..., 0, 1], aff[..., 0, 2]
sd, se, sf = aff[..., 1, 0], aff[..., 1, 1], aff[..., 1, 2]
Expand All @@ -108,13 +134,18 @@ def to_radians(θ: Floating) -> Scalars:
return (ftos(θ) / 180) * math.pi

def remove_translation(aff: Affine) -> Affine:
return aff.at[..., :1, 2].set(0)
#aff.at[..., :1, 2].set(0)
index = (Ellipsis, slice(0, 1), 2)
return index_update(aff, index, 0)

def remove_linear(aff: Affine) -> Affine:
return aff.at[..., :2, :2].set(np.eye(2))
# aff.at[..., :2, :2].set(np.eye(2))
index = (Ellipsis, slice(0, 2), slice(0, 2))
return index_update(aff, index, np.eye(2))

def transpose_translation(aff: Affine) -> Affine:
return aff.at[..., :2, :2].set(aff[..., :2, :2].transpose(0, 2, 1))
index = (Ellipsis, slice(0, 2), slice(0, 2))
return index_update(aff, index, aff[..., :2, :2].transpose(0, 2, 1))

class Transformable:
"""Transformable class."""
Expand Down
11 changes: 5 additions & 6 deletions examples/intro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jaxtyping import install_import_hook
with install_import_hook("chalk", "typeguard.typechecked"):
import chalk # Any module imported inside this `with` block, whose
# from jaxtyping import install_import_hook
# with install_import_hook("chalk", "typeguard.typechecked"):
# import chalk # Any module imported inside this `with` block, whose

from colour import Color
from chalk import *
Expand All @@ -9,11 +9,10 @@
papaya = Color("#ff9700")
blue = Color("#005FDB")


path = "examples/output/intro-01.png"
d = circle(0.5).fill_color(papaya)
d.render(path, height=64)

print("first")

# # Alternative, render as svg
path = "examples/output/intro-01.svg"
Expand Down Expand Up @@ -46,7 +45,7 @@
# d.render_pdf(path)

path = "examples/output/intro-04.png"

print("sierpinsky")
def sierpinski(n: int, size: int) -> Diagram:
if n <= 1:
return triangle(size)
Expand Down
Binary file modified examples/output/hanoi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/output/hanoi.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/output/intro-01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/output/intro-01.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/output/intro-02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/output/intro-02.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/output/intro-03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit b015027

Please sign in to comment.