From cc92af7b8d2109e12cc15d5190c06ec25dfbe11e Mon Sep 17 00:00:00 2001
From: cfranken <fronge@gmail.com>
Date: Fri, 28 Jun 2024 15:24:42 -0700
Subject: [PATCH] Fix GPU code if CPU

---
 src/CoreRT/tools/gpu_batched.jl         | 10 ----------
 src/CoreRT/tools/rt_helper_functions.jl |  8 ++++----
 src/CoreRT/types.jl                     |  8 ++++----
 3 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/src/CoreRT/tools/gpu_batched.jl b/src/CoreRT/tools/gpu_batched.jl
index d314dcb..0805f26 100644
--- a/src/CoreRT/tools/gpu_batched.jl
+++ b/src/CoreRT/tools/gpu_batched.jl
@@ -43,13 +43,8 @@ end
 
 "Given 3D CuArray A, fill in X[:,:,k] = A[:,:,k] \\ I" 
 function batch_inv!(X::CuArray{FT,3}, A::CuArray{FT,3}, Xptrs, Aptrs ) where {FT<:Float32}
-    #CUBLAS.math_mode!(CUBLAS.handle(), CUDA.FAST_MATH)
-    # LU-factorize A
-    #@info "Batch Inv: Float32"
     n = size(A,1)
     lda = max(1,stride(A,2))
-    #@timeit "Pointer" Aptrs = CUBLAS.unsafe_strided_batch(A)
-    #Xptrs = CUBLAS.unsafe_strided_batch(X)
 
     batchSize = length(Aptrs)
     @timeit "info" info = CuArray{Cint}(undef, batchSize)
@@ -63,13 +58,8 @@ end
 
 "Given 3D CuArray A, fill in X[:,:,k] = A[:,:,k] \\ I" 
 function batch_inv!(X::CuArray{FT,3}, A::CuArray{FT,3}, Xptrs, Aptrs ) where {FT<:Float64}
-    #CUBLAS.math_mode!(CUBLAS.handle(), CUDA.FAST_MATH)
-    # LU-factorize A
-    #@info "Batch Inv: Float32"
     n = size(A,1)
     lda = max(1,stride(A,2))
-    #@timeit "Pointer" Aptrs = CUBLAS.unsafe_strided_batch(A)
-    #Xptrs = CUBLAS.unsafe_strided_batch(X)
 
     batchSize = length(Aptrs)
     @timeit "info" info = CuArray{Cint}(undef, batchSize)
diff --git a/src/CoreRT/tools/rt_helper_functions.jl b/src/CoreRT/tools/rt_helper_functions.jl
index 29f4fbb..ac134a7 100644
--- a/src/CoreRT/tools/rt_helper_functions.jl
+++ b/src/CoreRT/tools/rt_helper_functions.jl
@@ -103,10 +103,10 @@ default_J_matrix_rand(FT, arr_type, dims, nSpec) = arr_type(randn(FT, tuple(dims
 
 "Make an added layer, supplying all default matrices"
 function make_added_layer(RS_type::Union{noRS, noRS_plus}, FT, arr_type, dims, nSpec) 
-    t1 = default_matrix(FT, arr_type, dims, nSpec)
-    t2 = default_matrix(FT, arr_type, dims, nSpec)
-    t1_ptr = CUBLAS.unsafe_strided_batch(t1);
-    t2_ptr = CUBLAS.unsafe_strided_batch(t2);
+    t1 = arr_type == Array ? nothing : default_matrix(FT, arr_type, dims, nSpec)
+    t2 = arr_type == Array ? nothing : default_matrix(FT, arr_type, dims, nSpec)
+    t1_ptr = arr_type == Array ? nothing : CUBLAS.unsafe_strided_batch(t1);
+    t2_ptr = arr_type == Array ? nothing : CUBLAS.unsafe_strided_batch(t2);
     return AddedLayer(
                                                         default_matrix(FT, arr_type, dims, nSpec), 
                                                         default_matrix(FT, arr_type, dims, nSpec), 
diff --git a/src/CoreRT/types.jl b/src/CoreRT/types.jl
index b200185..51eb92a 100644
--- a/src/CoreRT/types.jl
+++ b/src/CoreRT/types.jl
@@ -132,13 +132,13 @@ Base.@kwdef struct AddedLayer{FT} <: AbstractLayer
     "Added layer source matrix J (in - direction)"
     j₀⁻::AbstractArray{FT,3}
     "Added layer temporary space to avoid allocations"
-    temp1::AbstractArray{FT,3}
+    temp1::Union{AbstractArray{FT,3}, Nothing}
     "Added layer temporary space to avoid allocations"
-    temp2::AbstractArray{FT,3}
+    temp2::Union{AbstractArray{FT,3}, Nothing}
     "Pointer to temporary space to avoid allocations"
-    temp1_ptr::CuArray{CuPtr{FT}, 1, CUDA.DeviceMemory}
+    temp1_ptr::Union{CuArray{CuPtr{FT}, 1, CUDA.DeviceMemory}, Nothing}
     "Pointer to temporary space to avoid allocations"
-    temp2_ptr::CuArray{CuPtr{FT}, 1, CUDA.DeviceMemory}
+    temp2_ptr::Union{CuArray{CuPtr{FT}, 1, CUDA.DeviceMemory}, Nothing}
 end
 
 "Composite Layer Matrices (`-/+` defined in τ coordinates, i.e. `-`=outgoing, `+`=incoming"