From b3ed0b6823f799ebb82ecb6268b95a1b7f9faa1b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 6 Feb 2025 21:45:08 +0000 Subject: [PATCH] feat: overload additional Base operators --- Project.toml | 2 +- src/ExpressionAlgebra.jl | 3 ++- test/test_expressions.jl | 45 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 262590b6..ef70fdb9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "1.9.3" +version = "1.9.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ExpressionAlgebra.jl b/src/ExpressionAlgebra.jl index 6b8d7b26..f5dcd7b4 100644 --- a/src/ExpressionAlgebra.jl +++ b/src/ExpressionAlgebra.jl @@ -154,7 +154,8 @@ end for op in ( :*, :/, :+, :-, :^, :÷, :mod, :log, :atan, :atand, :copysign, :flipsign, - :&, :|, :⊻, ://, :\, + :&, :|, :⊻, ://, :\, :rem, + :(>), :(<), :(>=), :(<=), :max, :min, ) @eval @declare_expression_operator Base.$(op) 2 end diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 54448001..a2927f9b 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -454,3 +454,48 @@ end @test get_variable_names(new_ex2, nothing) == ["x1"] @test get_operators(new_ex2, nothing) == new_operators end + +@testitem "New binary operators" begin + using DynamicExpressions + + operators = OperatorEnum(; + binary_operators=[+, -, *, /, >, <, >=, <=, max, min, rem], + unary_operators=[sin, cos], + ) + x1, x2 = [Node(Float64; feature=i) for i in 1:2] + + # Test comparison operators string representation + tree = x1 > x2 + @test string(tree) == "x1 > x2" + + tree = x1 < x2 + @test string(tree) == "x1 < x2" + + tree = x1 >= x2 + @test string(tree) == "x1 >= x2" + + tree = x1 <= x2 + @test string(tree) == "x1 <= x2" + + # Test max/min operators + tree = max(x1, x2) + X = [1.0 2.0; 3.0 1.0]' # Two points: (1,3) and (2,1) + @test tree(X, operators) ≈ [2.0, 3.0] + + tree = min(x1, x2) + @test tree(X, operators) ≈ [1.0, 1.0] + + # Test remainder operator + tree = rem(x1, x2) + X = [5.0 7.0; 3.0 2.0]' # Two points: (5,7) and (3,2) + @test tree(X, operators) ≈ [5.0, 1.0] + + # Test combinations string representation + tree = max(x1, 2.0) > min(x2, 3.0) + @test string(tree) == "max(x1, 2.0) > min(x2, 3.0)" + + # Test with constants + tree = rem(x1, 2.0) + X = [5.0 7.0] # Two points: 5 and 7 + @test tree(X, operators) ≈ [1.0, 1.0] +end