From ec22881f850deb5dd0d112e33f374c153903b5b3 Mon Sep 17 00:00:00 2001 From: haykam821 <24855774+haykam821@users.noreply.github.com> Date: Fri, 15 Mar 2024 09:44:42 -0400 Subject: [PATCH] Allow obtaining the mean of a distribution Fixes #284 --- core/src/ast.rs | 1 + core/src/error.rs | 2 ++ core/src/num/dist.rs | 22 +++++++++++++++++++++- core/src/num/unit.rs | 7 +++++++ core/src/value.rs | 1 + core/src/value/built_in_function.rs | 2 ++ core/tests/integration_tests.rs | 12 ++++++++++++ 7 files changed, 46 insertions(+), 1 deletion(-) diff --git a/core/src/ast.rs b/core/src/ast.rs index 2cc2e2b9..76060be1 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -696,6 +696,7 @@ pub(crate) fn resolve_identifier( "true" => Value::Bool(true), "false" => Value::Bool(false), "sample" | "roll" => Value::BuiltInFunction(BuiltInFunction::Sample), + "mean" | "average" => Value::BuiltInFunction(BuiltInFunction::Mean), "sqrt" => evaluate_to_value("x: x^(1/2)", scope, attrs, context, int)?, "cbrt" => evaluate_to_value("x: x^(1/3)", scope, attrs, context, int)?, "real" | "re" | "Re" => Value::BuiltInFunction(BuiltInFunction::Real), diff --git a/core/src/error.rs b/core/src/error.rs index f2a50b91..009ee590 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -24,6 +24,7 @@ pub(crate) enum FendError { }, NegativeNumbersNotAllowed, ProbabilityDistributionsNotAllowed, + EmptyDistribution, FractionToInteger, ModuloByZero, RandomNumbersNotAvailable, @@ -173,6 +174,7 @@ impl fmt::Display for FendError { "probability distributions are not allowed (consider using `sample`)" ) } + Self::EmptyDistribution => write!(f, "there must be at least one part in a dist"), Self::ParseDateError(s) => write!(f, "failed to convert '{s}' to a date"), Self::ExpectedAString => write!(f, "expected a string"), Self::UnableToInvertFunction(name) => write!(f, "unable to invert function {name}"), diff --git a/core/src/num/dist.rs b/core/src/num/dist.rs index cd32f20d..e559050e 100644 --- a/core/src/num/dist.rs +++ b/core/src/num/dist.rs @@ -96,7 +96,27 @@ impl Dist { } res = Some(Self::from(k)); } - Ok(res.expect("there must be at least one part in a dist")) + res.ok_or(FendError::EmptyDistribution) + } + + pub(crate) fn mean(self, int: &I) -> FResult { + if self.parts.is_empty() { + return Err(FendError::EmptyDistribution); + } + + let len = self.parts.len(); + + if self.parts.len() == 1 { + return Ok(self); + } + + let mut result = Exact::new(Complex::from(0), true); + for (k, _v) in self.parts { + result = result.add(Exact::new(k, true), int)?; + } + + result = result.div(Exact::new(Complex::from(len as u64), true), int)?; + Ok(Self::from(result.value)) } #[allow( diff --git a/core/src/num/unit.rs b/core/src/num/unit.rs index 3b90fcf0..5c4de166 100644 --- a/core/src/num/unit.rs +++ b/core/src/num/unit.rs @@ -537,6 +537,13 @@ impl Value { }) } + pub(crate) fn mean(self, int: &I) -> FResult { + Ok(Self { + value: self.value.mean(int)?, + ..self + }) + } + fn convert_angle_to_rad( self, scope: Option>, diff --git a/core/src/value.rs b/core/src/value.rs index b7117005..117dd50f 100644 --- a/core/src/value.rs +++ b/core/src/value.rs @@ -349,6 +349,7 @@ impl Value { return Ok(Self::Base(Base::from_plain_base(n)?)); } BuiltInFunction::Sample => arg.expect_num()?.sample(context, int)?, + BuiltInFunction::Mean => arg.expect_num()?.mean(int)?, BuiltInFunction::Not => return Ok(Self::Bool(!arg.as_bool()?)), BuiltInFunction::Conjugate => arg.expect_num()?.conjugate()?, BuiltInFunction::Real => arg.expect_num()?.real()?, diff --git a/core/src/value/built_in_function.rs b/core/src/value/built_in_function.rs index 16427b8f..18f71e1c 100644 --- a/core/src/value/built_in_function.rs +++ b/core/src/value/built_in_function.rs @@ -31,6 +31,7 @@ pub(crate) enum BuiltInFunction { Log10, Base, Sample, + Mean, Not, Conjugate, Real, @@ -96,6 +97,7 @@ impl BuiltInFunction { Self::Log10 => "log10", Self::Base => "base", Self::Sample => "sample", + Self::Mean => "mean", Self::Not => "not", Self::Conjugate => "conjugate", Self::Real => "real", diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index b8e0d0db..e805de49 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5936,3 +5936,15 @@ fn test_roman() { test_eval_simple("1452 to roman", "MCDLII"); test_eval_simple("20002 to roman", "MMMMMMMMMMMMMMMMMMMMII"); } + +#[test] +fn test_mean() { + test_eval("mean d1", "1"); + test_eval("mean d2", "1.5"); + test_eval("mean d500", "250.5"); + + test_eval("mean (d1 + d1)", "2"); + test_eval("mean (d2 + d500)", "252"); + + test_eval("average d500", "250.5"); +}