From dc47665b34329ce13782f7707409cd09e8b521cd Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 22 Jun 2021 17:34:11 -0400 Subject: [PATCH] Correct the case of GPU/Float32 (#48) * Update FixedEffectSolverGPU.jl * Update FixedEffectSolverCPU.jl * Update Project.toml --- Project.toml | 2 +- src/FixedEffectSolvers/FixedEffectSolverCPU.jl | 6 +++--- src/FixedEffectSolvers/FixedEffectSolverGPU.jl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index e1b0c67..5aeac6b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FixedEffects" uuid = "c8885935-8500-56a7-9867-7708b20db0eb" -version = "2.0.4" +version = "2.0.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FixedEffectSolvers/FixedEffectSolverCPU.jl b/src/FixedEffectSolvers/FixedEffectSolverCPU.jl index 95f0c85..b3aa4e2 100644 --- a/src/FixedEffectSolvers/FixedEffectSolverCPU.jl +++ b/src/FixedEffectSolvers/FixedEffectSolverCPU.jl @@ -154,14 +154,14 @@ function solve_residuals!(r::AbstractVector, feM::FixedEffectSolverCPU{T}; tol:: end copyto!(feM.b, feM.r) if length(feM.x.x) == 1 - mul!(feM.x, feM.m', feM.b, 1.0, 0.0) + mul!(feM.x, feM.m', feM.b, 1, 0) iter, converged = 1, true else - mul!(feM.x, feM.m', feM.b, 1.0, 0.0) + mul!(feM.x, feM.m', feM.b, 1, 0) x, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter) iter, converged = ch.mvps + 1, ch.isconverged end - mul!(feM.r, feM.m, feM.x, -1.0, 1.0) + mul!(feM.r, feM.m, feM.x, -1, 1) if !(feM.weights isa UnitWeights) feM.r ./= sqrt.(feM.weights) end diff --git a/src/FixedEffectSolvers/FixedEffectSolverGPU.jl b/src/FixedEffectSolvers/FixedEffectSolverGPU.jl index 9fa9b75..63561b8 100644 --- a/src/FixedEffectSolvers/FixedEffectSolverGPU.jl +++ b/src/FixedEffectSolvers/FixedEffectSolverGPU.jl @@ -188,9 +188,9 @@ function solve_residuals!(r::AbstractVector, feM::FixedEffectSolverGPU{T}; tol:: feM.r .*= sqrt.(feM.weights) end copyto!(feM.b, feM.r) - mul!(feM.x, feM.m', feM.b, 1.0, 0.0) + mul!(feM.x, feM.m', feM.b, 1, 0) x, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter) - mul!(feM.r, feM.m, feM.x, -1.0, 1.0) + mul!(feM.r, feM.m, feM.x, -1, 1) if !(feM.weights isa UnitWeights) feM.r ./= sqrt.(feM.weights) end