diff --git a/src/states.jl b/src/states.jl index 145d5c73..9f78d993 100644 --- a/src/states.jl +++ b/src/states.jl @@ -119,6 +119,11 @@ julia> new_mat = states(test_mat, fill(3.0f0, 3)) """ struct ExtendedStates <: AbstractStates end +function (states_type::ExtendedStates)(mat::AbstractMatrix, inp::AbstractMatrix) + results = states_type.(eachcol(mat), eachcol(inp)) + return hcat(results...) +end + function (states_type::ExtendedStates)(mat::AbstractMatrix, inp::AbstractVector) results = Vector{Vector{eltype(mat)}}(undef, size(mat, 2)) for (idx, col) in enumerate(eachcol(mat)) diff --git a/test/test_states.jl b/test/test_states.jl index 2eaab5fb..1a191c4f 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -12,12 +12,10 @@ nlas = [(NLADefault(), test_array), pes = [(StandardStates(), test_array), (PaddedStates(; padding=padding), - reshape(vcat(padding, test_array), length(test_array) + 1, 1)), + vcat(test_array, padding)), (PaddedExtendedStates(; padding=padding), - reshape(vcat(padding, extension, test_array), - length(test_array) + length(extension) + 1, - 1)), - (ExtendedStates(), vcat(extension, test_array))] + vcat(test_array, padding, extension)), + (ExtendedStates(), vcat(test_array, extension))] @testset "States Testing" for T in test_types @testset "Nonlinear Algorithms Testing: $algo $T" for (algo, expected_output) in nlas