diff --git a/src/FlashWeave.jl b/src/FlashWeave.jl index 20b531f..fa430c1 100644 --- a/src/FlashWeave.jl +++ b/src/FlashWeave.jl @@ -10,7 +10,7 @@ using LightGraphs, SimpleWeightedGraphs using StatsBase, Distributions, Combinatorics # io -using JSON, HDF5 +using JSON, HDF5, FileIO # utilities import Base.show @@ -37,11 +37,11 @@ export learn_network, show, graph -function __init__() - warn_items = [(:FileIO, "JLD/JLD2")] - for (mod_symbol, format) in warn_items - isdefined(mod_symbol) && warn("Package $mod_symbol was loaded before importing FlashWeave. $format will not be available for FlashWeave's IO functions.") - end -end +# function __init__() +# warn_items = [(:FileIO, "JLD/JLD2")] +# for (mod_symbol, format) in warn_items +# isdefined(mod_symbol) && warn("Package $mod_symbol was loaded before importing FlashWeave. $format will not be available for FlashWeave's IO functions.") +# end +# end end diff --git a/src/io.jl b/src/io.jl index 889c778..fbf168a 100644 --- a/src/io.jl +++ b/src/io.jl @@ -62,8 +62,9 @@ function save_network(net_path::AbstractString, net_result::FWResult; detailed:: if isedgelist(file_ext) write_edgelist(net_path, graph(net_result)) elseif isjld(file_ext) - isdefined(:FileIO) || @eval using FileIO: save, load - Base.invokelatest(save, net_path, "results", net_result) + # isdefined(:FileIO) || @eval using FileIO: save, load + # Base.invokelatest(save, net_path, "results", net_result) + save(net_path, "results", net_result) else error("$(file_ext) not a valid output format. Choose one of $(valid_net_formats)") end @@ -88,10 +89,13 @@ function load_network(net_path::AbstractString) if isedgelist(file_ext) G = read_edgelist(net_path) net_result = FWResult(G) + elseif isjld(file_ext) - isdefined(:FileIO) || @eval using FileIO: save, load - d = Base.invokelatest(load, net_path) - net_result = d["results"] + # isdefined(:FileIO) || @eval using FileIO: save, load + # d = Base.invokelatest(load, net_path) + # net_result = d["results"] + net_result = load(net_path, "results") + else error("$(file_ext) not a valid network format. Valid formats are $(valid_net_formats)") end @@ -102,16 +106,17 @@ end ## Helper functions ## ###################### -function load_jld(data_path::AbstractString, data_key::AbstractString, header_key::AbstractString, - meta_key=nothing, meta_header_key=nothing; transposed::Bool=false) - isdefined(:FileIO) || @eval using FileIO: save, load - d = Base.invokelatest(load, data_path) +function load_jld(data_path::AbstractString, otu_data_key::AbstractString, otu_header_key::AbstractString, + meta_data_key=nothing, meta_header_key=nothing; transposed::Bool=false) + # isdefined(:FileIO) || @eval using FileIO: save, load + # d = Base.invokelatest(load, data_path) + d = load(data_path) - data = d[data_key] - header = d[header_key] + data = d[otu_data_key] + header = d[otu_header_key] - if meta_key != nothing - meta_data = d[meta_key] + if meta_data_key != nothing + meta_data = d[meta_data_key] meta_header = d[meta_header_key] else meta_data = meta_header = nothing diff --git a/src/learning.jl b/src/learning.jl index 50c740f..a76fe09 100644 --- a/src/learning.jl +++ b/src/learning.jl @@ -345,14 +345,14 @@ meta data table as an input (instead of a data matrix). """ -function learn_network(data_path::AbstractString, meta_data_path=nothing; otu_data_key::AbstractString="data", - otu_header_key::AbstractString="header", meta_data_key::AbstractString="meta_data", +function learn_network(data_path::AbstractString, meta_data_path=nothing; otu_data_key::AbstractString="otu_data", + otu_header_key::AbstractString="otu_header", meta_data_key::AbstractString="meta_data", meta_header_key::AbstractString="meta_header", verbose::Bool=true, transposed::Bool=false, kwargs...) verbose && println("\n### Loading data ###\n") data, header, meta_data, meta_header = load_data(data_path, meta_data_path, otu_data_key=otu_data_key, - otu_header_key=otu_header_key, meta_key=meta_data_key, + otu_header_key=otu_header_key, meta_data_key=meta_data_key, meta_header_key=meta_header_key, transposed=transposed) @@ -362,8 +362,8 @@ function learn_network(data_path::AbstractString, meta_data_path=nothing; otu_d else check_data(data, meta_data, header=header, meta_header=meta_header) data = hcat(data, meta_data) - header = hcat(header, meta_header) - meta_mask = hcat(falses(length(header)), trues(length(meta_header))) + meta_mask = vcat(falses(length(header)), trues(length(meta_header))) + header = vcat(header, meta_header) end learn_network(data; header=header, meta_mask=meta_mask, verbose=verbose, kwargs...) @@ -444,6 +444,14 @@ function learn_network(data::AbstractArray{ElType}; sensitive::Bool=true, data = data' end + if header == nothing + header = ["X" * string(i) for i in 1:size(data, 2)] + end + + if meta_mask == nothing + meta_mask = falses(length(header)) + end + check_data(data, header, meta_mask=meta_mask) if !issparse(data) && make_sparse @@ -451,10 +459,6 @@ function learn_network(data::AbstractArray{ElType}; sensitive::Bool=true, data = sparse(data) end - if header != nothing && meta_mask == nothing - meta_mask = falses(length(header)) - end - n_mvs = sum(meta_mask) if verbose println("""Inferring network with $(mode_string(heterogeneous, sensitive, max_k))\n diff --git a/src/misc.jl b/src/misc.jl index 3de1b0d..b84d48b 100644 --- a/src/misc.jl +++ b/src/misc.jl @@ -22,7 +22,7 @@ end function check_data(data::AbstractMatrix, header::AbstractVector; meta_mask=nothing) @assert size(data, 2) == length(header) "header does not fit data: $(size(data, 2)) vs. $(length(header))" - meta_mask != nothing && @assert size(data, 2) == length(meta_mask) "meta_mask does not fit data: $(length(data, 2)) vs. $(length(meta_mask))" + meta_mask != nothing && @assert size(data, 2) == length(meta_mask) "meta_mask does not fit data: $(size(data, 2)) vs. $(length(meta_mask))" end diff --git a/test/learning.jl b/test/learning.jl index 8bdd9bd..55358c0 100644 --- a/test/learning.jl +++ b/test/learning.jl @@ -6,8 +6,8 @@ nprocs() == 1 && addprocs(1) using FlashWeave using FileIO - -data = Matrix{Float64}(readdlm(joinpath("data", "HMP_SRA_gut", "HMP_SRA_gut_small.tsv"), '\t')[2:end, 2:end]) +data_path = joinpath("data", "HMP_SRA_gut", "HMP_SRA_gut_small.tsv") +data = Matrix{Float64}(readdlm(data_path, '\t')[2:end, 2:end]) data_sp = sparse(data) adj_exp_dict = load(joinpath("data", "learning_expected.jld2")) @@ -180,6 +180,29 @@ end end end end + + @testset "from file" begin + path_trunk = joinpath("data", "HMP_SRA_gut", "HMP_SRA_gut_tiny") + for (data_format, suff_pair, transp_suff_pair) in zip(["tsv", "jld"], + [(".tsv", "_ids_transposed.tsv"), + ("_plus_meta.jld", "_plus_meta_transposed.jld")], + [("_meta.tsv", "_meta_transposed.tsv"),("","")]) + @testset "$data_format" begin + path_pairs = [path_trunk * suff for suff in (suff_pair..., transp_suff_pair...)] + pred_graphs = [graph(learn_network(path_pairs[i], path_pairs[i_meta], sensitive=true, + heterogeneous=false, max_k=3, verbose=false, transposed=transp)) + for (i, i_meta, transp) in [(1, 3, false), (2, 4, true)]] + + for pred_graph in pred_graphs + @test compare_graph_results(pred_graphs..., + rtol=rtol, atol=atol, + approx=true, + approx_nbr_diff=approx_nbr_diff, + approx_weight_meandiff=approx_weight_meandiff) + end + end + end + end end # to create expected output