Skip to content

Commit

Permalink
Add UnitStep .... (#1250)
Browse files Browse the repository at this point in the history
Also, refactor mathics.eval.arithmetic to remove eval functions belonging to mathics.builtin.arithfns, and mathics.builtin.numeric
  • Loading branch information
rocky authored Dec 29, 2024
1 parent 455b074 commit 55c86d3
Show file tree
Hide file tree
Showing 8 changed files with 1,112 additions and 600 deletions.
2 changes: 1 addition & 1 deletion mathics/builtin/arithfns/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
SymbolPattern,
SymbolSequence,
)
from mathics.eval.arithmetic import eval_Plus, eval_Times
from mathics.eval.arithfns.basic import eval_Plus, eval_Times
from mathics.eval.nevaluator import eval_N
from mathics.eval.numerify import numerify

Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
SymbolTable,
SymbolUndefined,
)
from mathics.eval.arithmetic import eval_Sign
from mathics.eval.inference import get_assumptions_list
from mathics.eval.nevaluator import eval_N
from mathics.eval.numeric import eval_Sign

# This tells documentation how to sort this module
sort_order = "mathics.builtin.mathematical-functions"
Expand Down
57 changes: 54 additions & 3 deletions mathics/builtin/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
A_HOLD_ALL,
A_LISTABLE,
A_NUMERIC_FUNCTION,
A_ORDERLESS,
A_PROTECTED,
)
from mathics.core.builtin import Builtin, MPMathFunction, SympyFunction
Expand All @@ -44,14 +45,16 @@
SymbolTrue,
)
from mathics.core.systemsymbols import SymbolPiecewise
from mathics.eval.arithmetic import (
from mathics.eval.inference import evaluate_predicate
from mathics.eval.nevaluator import eval_NValues
from mathics.eval.numeric import (
eval_Abs,
eval_negate_number,
eval_RealSign,
eval_Sign,
eval_UnitStep,
eval_UnitStep_multidimensional,
)
from mathics.eval.inference import evaluate_predicate
from mathics.eval.nevaluator import eval_NValues


def chop(expr, delta=10.0 ** (-10.0)):
Expand Down Expand Up @@ -787,3 +790,51 @@ def eval(self, x, evaluation: Evaluation):
def eval_error(self, x, seqs, evaluation: Evaluation):
"Sign[x_, seqs__]"
evaluation.message("Sign", "argx", Integer(len(seqs.get_sequence()) + 1))


class UnitStep(Builtin):
"""
<url>
:Heaviside step function:
https://en.wikipedia.org/wiki/Heaviside_step_function</url> (<url>
:WMA link:
https://reference.wolfram.com/language/ref/UnitStep.html</url>)
<dl>
<dt>'UnitStep[$x$]'
<dd>return 0 if $x$ < 0, and 1 if $x$ >= 0.
<dt>'UnitStep[$x1$, $x2$, ...]'
<dd>return the multidimensional unit step function which is 1 only if none of the $xi$ are negative.
</dl>
Evaluation numerically:
>> UnitStep[0.7]
= 1
We can use 'UnitStep' on irrational numbers and infinities:
>> Map[UnitStep, {Pi, Infinity, -Infinity}]
= {1, 1, 0}
>> Table[UnitStep[x], {x, -3, 3}]
= {0, 0, 0, 1, 1, 1, 1}
Plot in one dimension:
>> Plot[UnitStep[x], {x, -4, 4}]
= -Graphics-
## UnitStep is a piecewise function
## PiecewiseExpand[UnitStep[x]]
## = ...
"""

summary_text = "unit step function of a number"

attributes = A_LISTABLE | A_NUMERIC_FUNCTION | A_ORDERLESS | A_PROTECTED

def eval(self, x, evaluation: Evaluation):
"UnitStep[x_]"
return eval_UnitStep(x)

def eval_multidimenional(self, seqs, evaluation: Evaluation):
"UnitStep[seqs__]"
return eval_UnitStep_multidimensional(seqs)
3 changes: 3 additions & 0 deletions mathics/eval/arithfns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Module tracking eval functions under mathics.builtin.arithfns
"""
270 changes: 270 additions & 0 deletions mathics/eval/arithfns/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
# -*- coding: utf-8 -*-

"""
evaluation function for builtins in mathics.builtin.arithfns.basic
Many of these depend on the evaluation context. Conversions to SymPy are
used just as a last resource.
"""

from typing import Optional

import mpmath
import sympy

# Note: it is important *not* use: from mathics.eval.tracing import run_sympy
# but instead import the module and access below as tracing.run_sympy.
# This allows us change where tracing.run_sympy points at runtime.
from mathics.core.atoms import (
Integer,
Integer0,
Integer1,
Integer2,
IntegerM1,
Number,
Rational,
Real,
)
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.convert.sympy import from_sympy
from mathics.core.element import BaseElement, ElementsProperties
from mathics.core.expression import Expression
from mathics.core.number import min_prec
from mathics.core.symbols import SymbolPlus, SymbolPower, SymbolTimes
from mathics.core.systemsymbols import SymbolIndeterminate
from mathics.eval.arithmetic import (
eval_Power_number,
segregate_numbers_from_sorted_list,
)

RationalMOneHalf = Rational(-1, 2)
RealM0p5 = Real(-0.5)
RealOne = Real(1.0)


def eval_Plus(*items: BaseElement) -> BaseElement:
"evaluate Plus for general elements"
numbers, items_tuple = segregate_numbers_from_sorted_list(*items)
elements = []
last_item = last_count = None
number = eval_add_numbers(*numbers) if numbers else Integer0

# This reduces common factors
# TODO: Check if it possible to avoid the conversions back and forward to sympy.
def append_last():
if last_item is not None:
if last_count == 1:
elements.append(last_item)
else:
if last_item.has_form("Times", None):
elements.append(
Expression(
SymbolTimes, from_sympy(last_count), *last_item.elements
)
)
else:
elements.append(
Expression(SymbolTimes, from_sympy(last_count), last_item)
)

for item in items_tuple:
count = rest = None
if item.has_form("Times", None):
for element in item.elements:
if isinstance(element, Number):
count = element.to_sympy()
rest = item.get_mutable_elements()
rest.remove(element)
if len(rest) == 1:
rest = rest[0]
else:
rest.sort()
rest = Expression(SymbolTimes, *rest)
break
if count is None:
count = sympy.Integer(1)
rest = item
if last_item is not None and last_item == rest:
last_count = last_count + count
else:
append_last()
last_item = rest
last_count = count
append_last()

# now elements contains the symbolic terms which can not be simplified.
# by collecting common symbolic factors.
if not elements:
return number

if number is not Integer0:
elements.insert(0, number)
elif len(elements) == 1:
return elements[0]

elements.sort()
return Expression(
SymbolPlus,
*elements,
elements_properties=ElementsProperties(False, False, True),
)


def eval_Times(*items: BaseElement) -> Optional[BaseElement]:
elements = []
numbers = []
# find numbers and simplify Times -> Power
numbers, symbolic_items = segregate_numbers_from_sorted_list(*(items))
# This loop handles factors representing infinite quantities,
# and factors which are powers of the same basis.

for item in symbolic_items:
if item is SymbolIndeterminate:
return item
# Process powers
if elements:
previous_elem = elements[-1]
if item == previous_elem:
elements[-1] = Expression(SymbolPower, previous_elem, Integer2)
continue
elif item.has_form("Power", 2):
base, exp = item.elements
if previous_elem.has_form("Power", 2) and base.sameQ(
previous_elem.elements[0]
):
exp = eval_Plus(exp, previous_elem.elements[1])
elements[-1] = Expression(
SymbolPower,
base,
exp,
)
continue
if base.sameQ(previous_elem):
exp = eval_Plus(Integer1, exp)
elements[-1] = Expression(
SymbolPower,
base,
exp,
)
continue
elif previous_elem.has_form("Power", 2) and previous_elem.elements[0].sameQ(
item
):
exp = eval_Plus(Integer1, previous_elem.elements[1])
elements[-1] = Expression(
SymbolPower,
item,
exp,
)
continue
else:
item = item
# Otherwise, just append the element...
elements.append(item)

number = eval_multiply_numbers(*numbers) if numbers else Integer1

if len(elements) == 0 or number is Integer0:
return number

if number is IntegerM1 and elements and elements[0].has_form("Plus", None):
elements[0] = Expression(
elements[0].get_head(),
*[
Expression(SymbolTimes, IntegerM1, element)
for element in elements[0].elements
],
)
number = Integer1

if number is not Integer1:
elements.insert(0, number)

if len(elements) == 1:
return elements[0]

elements = sorted(elements)
items_elements = items
if len(elements) == len(items_elements) and all(
elem.sameQ(item) for elem, item in zip(elements, items_elements)
):
return None

return Expression(
SymbolTimes,
*elements,
elements_properties=ElementsProperties(False, False, True),
)


def eval_add_numbers(
*numbers: Number,
) -> BaseElement:
"""
Add the elements in ``numbers``.
"""
if len(numbers) == 0:
return Integer0
if len(numbers) == 1:
return numbers[0]

is_machine_precision = any(number.is_machine_precision() for number in numbers)
if is_machine_precision:
terms = (item.to_mpmath() for item in numbers)
number = mpmath.fsum(terms)
return from_mpmath(number)

prec = min_prec(*numbers)
if prec is not None:
# For a sum, what is relevant is the minimum accuracy of the terms
with mpmath.workprec(prec):
terms = (item.to_mpmath() for item in numbers)
number = mpmath.fsum(terms)
return from_mpmath(number, precision=prec)
else:
return from_sympy(sum(item.to_sympy() for item in numbers))


def eval_inverse_number(n: Number) -> Number:
"""
Eval 1/n
"""
if isinstance(n, Integer):
n_value = n.value
if n_value == 1 or n_value == -1:
return n
return Rational(-1, -n_value) if n_value < 0 else Rational(1, n_value)
if isinstance(n, Rational):
n_num, n_den = n.value.as_numer_denom()
if n_num < 0:
n_num, n_den = -n_num, -n_den
if n_num == 1:
return Integer(n_den)
return Rational(n_den, n_num)
# Otherwise, use power....
return eval_Power_number(n, IntegerM1)


def eval_multiply_numbers(*numbers: Number) -> Number:
"""
Multiply the elements in ``numbers``.
"""
if len(numbers) == 0:
return Integer1
if len(numbers) == 1:
return numbers[0]

is_machine_precision = any(number.is_machine_precision() for number in numbers)
if is_machine_precision:
factors = (item.to_mpmath() for item in numbers)
number = mpmath.fprod(factors)
return from_mpmath(number)

prec = min_prec(*numbers)
if prec is not None:
with mpmath.workprec(prec):
factors = (item.to_mpmath() for item in numbers)
number = mpmath.fprod(factors)
return from_mpmath(number, prec)
else:
return from_sympy(sympy.Mul(*(item.to_sympy() for item in numbers)))
Loading

0 comments on commit 55c86d3

Please sign in to comment.