Skip to content

Commit

Permalink
build expressions using sympy.printing.pretty.stringpict
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Nov 10, 2024
1 parent f91168f commit d92fffb
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 89 deletions.
2 changes: 1 addition & 1 deletion mathics/eval/makeboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def make_output_form(expr, evaluation, form):
evaluation.definitions.get_ownvalues("System`$Use2DOutputForm")[0].replace
is SymbolTrue
)
text2d = expression_to_2d_text(expr, evaluation, form, **{"2d": use_2d}).text
text2d = str(expression_to_2d_text(expr, evaluation, form, **{"2d": use_2d}))

if "\n" in text2d:
text2d = "\n" + text2d
Expand Down
162 changes: 103 additions & 59 deletions mathics/format/pane_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,50 @@

from typing import List, Optional, Union

from sympy.printing.pretty.pretty_symbology import line_width, vobj
from sympy.printing.pretty.stringpict import prettyForm, stringPict

class TextBlock:

class TextBlock(prettyForm):
def __init__(self, text, base=0, padding=0, height=1, width=0):
super().__init__(text, base)
assert padding == 0
assert height == 1
assert width == 0

def root(self, n=None):
"""Produce a nice root symbol.
Produces ugly results for big n inserts.
"""
# XXX not used anywhere
# XXX duplicate of root drawing in pretty.py
# put line over expression
result = TextBlock(*self.above("_" * self.width()))
# construct right half of root symbol
height = self.height()
slash = "\n".join(" " * (height - i - 1) + "/" + " " * i for i in range(height))
slash = stringPict(slash, height - 1)
# left half of root symbol
if height > 2:
downline = stringPict("\\ \n \\", 1)
else:
downline = stringPict("\\")
# put n on top, as low as possible
if n is not None and n.width() > downline.width():
downline = downline.left(" " * (n.width() - downline.width()))
downline = downline.above(n)
# build root symbol
root = TextBlock(*downline.right(slash))
# glue it on at the proper height
# normally, the root symbel is as high as self
# which is one less than result
# this moves the root symbol one down
# if the root became higher, the baseline has to grow too
root.baseline = result.baseline - result.height() + root.height()
return result.left(root)


class OldTextBlock:
lines: List[str]
width: int
height: int
Expand Down Expand Up @@ -37,7 +79,7 @@ def _build_attributes(lines, width=0, height=0, base=0):

return (lines, width, height, base)

def __init__(self, text, padding=0, base=0, height=1, width=0):
def __init__(self, text, base=0, padding=0, height=1, width=0):
if isinstance(text, str):
if text == "":
lines = []
Expand All @@ -63,6 +105,9 @@ def text(self):
def text(self, value):
raise TypeError("TextBlock is inmutable")

def __str__(self):
return self.text

def __repr__(self):
return self.text

Expand Down Expand Up @@ -166,45 +211,23 @@ def stack(self, top, align: str = "c"):


def _draw_integral_symbol(height: int) -> TextBlock:
return TextBlock(
(" /+ \n" + "\n".join(height * [" | "]) + "\n+/ "), base=int((height + 1) / 2)
)
if height % 2 == 0:
height = height + 1
result = TextBlock(vobj("int", height), (height - 1) // 2)
return result


def bracket(inner: Union[str, TextBlock]) -> TextBlock:
if isinstance(inner, str):
inner = TextBlock(inner)
height = inner.height
if height == 1:
left_br, right_br = TextBlock("["), TextBlock("]")
else:
left_br = TextBlock(
"+-\n" + "\n".join((height) * ["| "]) + "\n+-", base=inner.base + 1
)
right_br = TextBlock(
"-+ \n" + "\n".join((height) * [" |"]) + "\n-+", base=inner.base + 1
)
return left_br + inner + right_br

return TextBlock(*inner.parens("[", "]"))


def curly_braces(inner: Union[str, TextBlock]) -> TextBlock:
if isinstance(inner, str):
inner = TextBlock(inner)
height = inner.height
if height == 1:
left_br, right_br = TextBlock("{"), TextBlock("}")
else:
half_height = max(1, int((height - 3) / 2))
half_line = "\n".join(half_height * [" |"])
left_br = TextBlock(
"\n".join([" /", half_line, "< ", half_line, " \\"]), base=half_height + 1
)
half_line = "\n".join(half_height * ["| "])
right_br = TextBlock(
"\n".join(["\\ ", half_line, " >", half_line, "/ "]), base=half_height + 1
)

return left_br + inner + right_br
return TextBlock(*inner.parens("{", "}"))


def draw_vertical(
Expand Down Expand Up @@ -233,11 +256,7 @@ def fraction(a: Union[TextBlock, str], b: Union[TextBlock, str]) -> TextBlock:
a = TextBlock(a)
if isinstance(b, str):
b = TextBlock(b)
width = max(b.width, a.width)
frac_bar = TextBlock(width * "-")
result = frac_bar.stack(a)
result = b.stack(result)
result.base = b.height
return a / b
return result


Expand Down Expand Up @@ -359,8 +378,8 @@ def integral_indefinite(
if isinstance(integrand, str):
integrand = TextBlock(integrand)

int_symb: TextBlock = _draw_integral_symbol(integrand.height)
return int_symb + integrand + " d" + var
int_symb: TextBlock = _draw_integral_symbol(integrand.height())
return TextBlock(*TextBlock.next(int_symb, integrand, TextBlock(" d"), var))


def integral_definite(
Expand All @@ -380,24 +399,20 @@ def integral_definite(
if isinstance(b, str):
b = TextBlock(b)

int_symb = _draw_integral_symbol(integrand.height)
return subsuperscript(int_symb, a, b) + " " + integrand + " d" + var
h_int = integrand.height()
symbol_height = h_int
# for ascii, symbol_height +=2
int_symb = _draw_integral_symbol(symbol_height)
orig_baseline = int_symb.baseline
int_symb = subsuperscript(int_symb, a, b)
return TextBlock(*TextBlock.next(int_symb, integrand, TextBlock(" d"), var))


def parenthesize(inner: Union[str, TextBlock]) -> TextBlock:
if isinstance(inner, str):
inner = TextBlock(inner)
height = inner.height
if height == 1:
left_br, right_br = TextBlock("("), TextBlock(")")
else:
left_br = TextBlock(
"/ \n" + "\n".join((height - 2) * ["| "]) + "\n\\ ", base=inner.base
)
right_br = TextBlock(
" \\ \n" + "\n".join((height - 2) * [" |"]) + "\n /", base=inner.base
)
return left_br + inner + right_br

return TextBlock(*inner.parens())


def sqrt_block(
Expand All @@ -408,9 +423,13 @@ def sqrt_block(
"""
if isinstance(a, str):
a = TextBlock(a)
if index is None:
index = ""
if isinstance(index, str):
index = TextBlock(index)

return TextBlock(*a.root(index))

a_height = a.height
result_2 = TextBlock(
"\n".join("|" + line for line in a.text.split("\n")), base=a.base
Expand All @@ -433,33 +452,58 @@ def sqrt_block(


def subscript(base: Union[TextBlock, str], a: Union[TextBlock, str]) -> TextBlock:
"""
Join b with a as a subscript.
"""
if isinstance(a, str):
a = TextBlock(a)
if isinstance(base, str):
base = TextBlock(base)

text2 = a.stack(TextBlock(base.height * [""], base=base.base), align="l")
text2.base = base.base + a.height
return base + text2
a = TextBlock(*TextBlock.next(TextBlock(base.width() * " "), a))
base = TextBlock(*TextBlock.next(base, TextBlock(a.width() * " ")))
result = TextBlock(*TextBlock.below(base, a))
return result


def subsuperscript(
base: Union[TextBlock, str], a: Union[TextBlock, str], b: Union[TextBlock, str]
) -> TextBlock:
"""
Join base with a as a superscript and b as a subscript
"""
if isinstance(base, str):
base = TextBlock(base)
if isinstance(a, str):
a = TextBlock(a)
if isinstance(b, str):
b = TextBlock(b)

text2 = a.stack((base.height - 1) * "\n", align="l").stack(b, align="l")
text2.base = base.base + a.height
return base + text2
# Ensure that a and b have the same width
width_diff = a.width() - b.width()
if width_diff < 0:
a = TextBlock(*TextBlock.next(a, TextBlock((-width_diff) * " ")))
elif width_diff > 0:
b = TextBlock(*TextBlock.next(b, TextBlock((width_diff) * " ")))

indx_spaces = b.width() * " "
base_spaces = base.width() * " "
a = TextBlock(*TextBlock.next(TextBlock(base_spaces), a))
b = TextBlock(*TextBlock.next(TextBlock(base_spaces), b))
base = TextBlock(*TextBlock.next(base, TextBlock(base_spaces)))
result = TextBlock(*TextBlock.below(base, a))
result = TextBlock(*TextBlock.above(result, b))
return result


def superscript(base: Union[TextBlock, str], a: Union[TextBlock, str]) -> TextBlock:
if isinstance(a, str):
a = TextBlock(a)
if isinstance(base, str):
base = TextBlock(base)
text2 = TextBlock((base.height - 1) * "\n", base=base.base).stack(a, align="l")
return base + text2

base_width, a_width = base.width(), a.width()
a = TextBlock(*TextBlock.next(TextBlock(base_width * " "), a))
base = TextBlock(*TextBlock.next(base, TextBlock(a_width * " ")))
result = TextBlock(*TextBlock.above(base, a))
return result
Loading

0 comments on commit d92fffb

Please sign in to comment.