From 16b2c2fcdbd078126a2e3d9a8ded3de8f7a46a7b Mon Sep 17 00:00:00 2001 From: Evan Haas Date: Thu, 29 Feb 2024 23:54:42 -0800 Subject: [PATCH] Value: add complex division support --- src/aro/Value.zig | 29 ++++++++---- src/aro/annex_g.zig | 92 ++++++++++++++++++++++++++++--------- test/cases/complex values.c | 4 +- 3 files changed, 94 insertions(+), 31 deletions(-) diff --git a/src/aro/Value.zig b/src/aro/Value.zig index e290f09e..7a7c61e2 100644 --- a/src/aro/Value.zig +++ b/src/aro/Value.zig @@ -539,15 +539,26 @@ pub fn mul(res: *Value, lhs: Value, rhs: Value, ty: Type, comp: *Compilation) !b pub fn div(res: *Value, lhs: Value, rhs: Value, ty: Type, comp: *Compilation) !bool { const bits: usize = @intCast(ty.bitSizeof(comp).?); if (ty.isFloat()) { - const f: Interner.Key.Float = switch (bits) { - 16 => .{ .f16 = lhs.toFloat(f16, comp) / rhs.toFloat(f16, comp) }, - 32 => .{ .f32 = lhs.toFloat(f32, comp) / rhs.toFloat(f32, comp) }, - 64 => .{ .f64 = lhs.toFloat(f64, comp) / rhs.toFloat(f64, comp) }, - 80 => .{ .f80 = lhs.toFloat(f80, comp) / rhs.toFloat(f80, comp) }, - 128 => .{ .f128 = lhs.toFloat(f128, comp) / rhs.toFloat(f128, comp) }, - else => unreachable, - }; - res.* = try intern(comp, .{ .float = f }); + if (ty.isComplex()) { + const cf: Interner.Key.Complex = switch (bits) { + 64 => .{ .cf32 = annex_g.complexFloatDiv(f32, lhs.toFloat(f32, comp), lhs.imag(f32, comp), rhs.toFloat(f32, comp), rhs.imag(f32, comp)) }, + 128 => .{ .cf64 = annex_g.complexFloatDiv(f64, lhs.toFloat(f64, comp), lhs.imag(f64, comp), rhs.toFloat(f64, comp), rhs.imag(f64, comp)) }, + 160 => .{ .cf80 = annex_g.complexFloatDiv(f80, lhs.toFloat(f80, comp), lhs.imag(f80, comp), rhs.toFloat(f80, comp), rhs.imag(f80, comp)) }, + 256 => .{ .cf128 = annex_g.complexFloatDiv(f128, lhs.toFloat(f128, comp), lhs.imag(f128, comp), rhs.toFloat(f128, comp), rhs.imag(f128, comp)) }, + else => unreachable, + }; + res.* = try intern(comp, .{ .complex = cf }); + } else { + const f: Interner.Key.Float = switch (bits) { + 16 => .{ .f16 = lhs.toFloat(f16, comp) / rhs.toFloat(f16, comp) }, + 32 => .{ .f32 = lhs.toFloat(f32, comp) / rhs.toFloat(f32, comp) }, + 64 => .{ .f64 = lhs.toFloat(f64, comp) / rhs.toFloat(f64, comp) }, + 80 => .{ .f80 = lhs.toFloat(f80, comp) / rhs.toFloat(f80, comp) }, + 128 => .{ .f128 = lhs.toFloat(f128, comp) / rhs.toFloat(f128, comp) }, + else => unreachable, + }; + res.* = try intern(comp, .{ .float = f }); + } return false; } else { var lhs_space: BigIntSpace = undefined; diff --git a/src/aro/annex_g.zig b/src/aro/annex_g.zig index 089a9866..48e0d95d 100644 --- a/src/aro/annex_g.zig +++ b/src/aro/annex_g.zig @@ -1,6 +1,14 @@ //! Complex arithmetic algorithms from C99 Annex G const std = @import("std"); +const copysign = std.math.copysign; +const ilogb = std.math.ilogb; +const inf = std.math.inf; +const isFinite = std.math.isFinite; +const isInf = std.math.isInf; +const isNan = std.math.isNan; +const isPositiveZero = std.math.isPositiveZero; +const scalbn = std.math.scalbn; /// computes floating point z*w where a_param, b_param are real, imaginary parts of z and c_param, d_param are real, imaginary parts of w pub fn complexFloatMul(comptime T: type, a_param: T, b_param: T, c_param: T, d_param: T) [2]T { @@ -15,36 +23,71 @@ pub fn complexFloatMul(comptime T: type, a_param: T, b_param: T, c_param: T, d_p const bc = b * c; var x = ac - bd; var y = ad + bc; - if (std.math.isNan(x) and std.math.isNan(y)) { + if (isNan(x) and isNan(y)) { var recalc = false; - if (std.math.isInf(a) or std.math.isInf(b)) { + if (isInf(a) or isInf(b)) { // lhs infinite // Box the infinity and change NaNs in the other factor to 0 - a = std.math.copysign(if (std.math.isInf(a)) @as(T, 1.0) else @as(T, 0.0), a); - b = std.math.copysign(if (std.math.isInf(b)) @as(T, 1.0) else @as(T, 0.0), b); - if (std.math.isNan(c)) c = std.math.copysign(@as(T, 0.0), c); - if (std.math.isNan(d)) d = std.math.copysign(@as(T, 0.0), d); + a = copysign(if (isInf(a)) @as(T, 1.0) else @as(T, 0.0), a); + b = copysign(if (isInf(b)) @as(T, 1.0) else @as(T, 0.0), b); + if (isNan(c)) c = copysign(@as(T, 0.0), c); + if (isNan(d)) d = copysign(@as(T, 0.0), d); recalc = true; } - if (std.math.isInf(c) or std.math.isInf(d)) { + if (isInf(c) or isInf(d)) { // rhs infinite // Box the infinity and change NaNs in the other factor to 0 - c = std.math.copysign(if (std.math.isInf(c)) @as(T, 1.0) else @as(T, 0.0), c); - d = std.math.copysign(if (std.math.isInf(d)) @as(T, 1.0) else @as(T, 0.0), d); - if (std.math.isNan(a)) a = std.math.copysign(@as(T, 0.0), a); - if (std.math.isNan(b)) b = std.math.copysign(@as(T, 0.0), b); + c = copysign(if (isInf(c)) @as(T, 1.0) else @as(T, 0.0), c); + d = copysign(if (isInf(d)) @as(T, 1.0) else @as(T, 0.0), d); + if (isNan(a)) a = copysign(@as(T, 0.0), a); + if (isNan(b)) b = copysign(@as(T, 0.0), b); recalc = true; } - if (!recalc and (std.math.isInf(ac) or std.math.isInf(bd) or std.math.isInf(ad) or std.math.isInf(bc))) { + if (!recalc and (isInf(ac) or isInf(bd) or isInf(ad) or isInf(bc))) { // Recover infinities from overflow by changing NaN's to 0 - if (std.math.isNan(a)) a = std.math.copysign(@as(T, 0.0), a); - if (std.math.isNan(b)) b = std.math.copysign(@as(T, 0.0), b); - if (std.math.isNan(c)) c = std.math.copysign(@as(T, 0.0), c); - if (std.math.isNan(d)) d = std.math.copysign(@as(T, 0.0), d); + if (isNan(a)) a = copysign(@as(T, 0.0), a); + if (isNan(b)) b = copysign(@as(T, 0.0), b); + if (isNan(c)) c = copysign(@as(T, 0.0), c); + if (isNan(d)) d = copysign(@as(T, 0.0), d); } if (recalc) { - x = std.math.inf(T) * (a * c - b * d); - y = std.math.inf(T) * (a * d + b * c); + x = inf(T) * (a * c - b * d); + y = inf(T) * (a * d + b * c); + } + } + return .{ x, y }; +} + +/// computes floating point z / w where a_param, b_param are real, imaginary parts of z and c_param, d_param are real, imaginary parts of w +pub fn complexFloatDiv(comptime T: type, a_param: T, b_param: T, c_param: T, d_param: T) [2]T { + var a = a_param; + var b = b_param; + var c = c_param; + var d = d_param; + var denom_logb: i32 = 0; + const max_cd = @max(@abs(c), @abs(d)); + if (isFinite(max_cd)) { + denom_logb = ilogb(max_cd); + c = scalbn(c, -denom_logb); + d = scalbn(d, -denom_logb); + } + const denom = c * c + d * d; + var x = scalbn((a * c + b * d) / denom, -denom_logb); + var y = scalbn((b * c - a * d) / denom, -denom_logb); + if (isNan(x) and isNan(y)) { + if (isPositiveZero(denom) and (!isNan(a) or !isNan(b))) { + x = copysign(inf(T), c) * a; + y = copysign(inf(T), c) * b; + } else if ((isInf(a) or isInf(b)) and isFinite(c) and isFinite(d)) { + a = copysign(if (isInf(a)) @as(T, 1.0) else @as(T, 0.0), a); + b = copysign(if (isInf(b)) @as(T, 1.0) else @as(T, 0.0), b); + x = inf(T) * (a * c + b * d); + y = inf(T) * (b * c - a * d); + } else if (isInf(max_cd) and isFinite(a) and isFinite(b)) { + c = copysign(if (isInf(c)) @as(T, 1.0) else @as(T, 0.0), c); + d = copysign(if (isInf(d)) @as(T, 1.0) else @as(T, 0.0), d); + x = 0.0 * (a * c + b * d); + y = 0.0 * (b * c - a * d); } } return .{ x, y }; @@ -52,7 +95,14 @@ pub fn complexFloatMul(comptime T: type, a_param: T, b_param: T, c_param: T, d_p test complexFloatMul { // Naive algorithm would produce NaN + NaNi instead of inf + NaNi - const result = complexFloatMul(f64, std.math.inf(f64), std.math.nan(f64), 2, 0); - try std.testing.expect(std.math.isInf(result[0])); - try std.testing.expect(std.math.isNan(result[1])); + const result = complexFloatMul(f64, inf(f64), std.math.nan(f64), 2, 0); + try std.testing.expect(isInf(result[0])); + try std.testing.expect(isNan(result[1])); +} + +test complexFloatDiv { + // Naive algorithm would produce NaN + NaNi instead of inf + NaNi + const result = complexFloatDiv(f64, inf(f64), std.math.nan(f64), 2, 0); + try std.testing.expect(isInf(result[0])); + try std.testing.expect(isNan(result[1])); } diff --git a/test/cases/complex values.c b/test/cases/complex values.c index 5fff402a..301d4cd1 100644 --- a/test/cases/complex values.c +++ b/test/cases/complex values.c @@ -3,4 +3,6 @@ _Static_assert(2.0i * 2.0i == -4.0, ""); _Static_assert((double)2.0i == 0, ""); _Static_assert((double)(_Complex double)42 == 42, ""); _Static_assert(-2.0i - 2.0i == -4.0i, ""); -_Static_assert(~(2.0 + 4.0i) == 2.0 - 4.0i, ""); \ No newline at end of file +_Static_assert(~(2.0 + 4.0i) == 2.0 - 4.0i, ""); +_Static_assert((2.0 + 2.0i) / 2.0 == 1.0 + 1.0i, ""); +_Static_assert((2.0 + 4.0i) / 1.0i == 4.0 - 2.0i, ""); \ No newline at end of file