Skip to content

Commit

Permalink
Add GraphML output option
Browse files Browse the repository at this point in the history
  • Loading branch information
jtackm committed Jul 27, 2018
1 parent dbbe189 commit 55b0cea
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 43 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ julia> # data_path = "/my/example/data.jld2"
julia> # netw_results = learn_network(data_path, otu_data_key="otu_data", otu_header_key="otu_header", meta_data_key="meta_data", meta_header_key="meta_header", sensitive=true, heterogeneous=false)
```

Results can currently be saved in JLD/2, fast also for large networks, or as traditional edgelist (".edgelist") format:
Results can currently be saved in JLD/2 (fast, also for large networks), or in the traditional GML (".gml") or edgelist (".edgelist") formats:

```julia
julia> save_network("/my/example/network_output.jld2", netw_results)
julia> ## or: save_network("/my/example/network_output.edgelist", netw_results)
julia> ## or: save_network("/my/example/network_output.gml", netw_results)
```

For output of additional information (such as discarding sets, if available) in separate files you can specify the "detailed" flag:
Expand Down
5 changes: 3 additions & 2 deletions src/FlashWeave.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using StatsBase, Distributions, Combinatorics
using JSON, HDF5, FileIO

# utilities
import Base.show
import Base.show, Base.names, Base.==


include("types.jl")
Expand All @@ -35,7 +35,8 @@ export learn_network,
load_network,
load_data,
show,
graph
graph,
meta_variable_mask

# function __init__()
# warn_items = [(:FileIO, "JLD/JLD2")]
Expand Down
2 changes: 1 addition & 1 deletion src/interleaved.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function interleaved_backend(target_vars::AbstractVector{Int}, data::AbstractMat


if verbose
println("\nPreparing workers for conditional search..")
println("\nPreparing workers..")
tic()
end

Expand Down
167 changes: 138 additions & 29 deletions src/io.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# note: needed lots of @eval and Base.invokelatest hacks for conditional
# module loading

const valid_net_formats = (".edgelist", ".jld2", ".jld")
const valid_net_formats = (".edgelist", ".gml", ".jld2", ".jld")
const valid_data_formats = (".tsv", ".csv", ".biom", ".jld2", ".jld")

# File-extension predicates used to select a reader/writer. Each one expects
# the extension with its leading dot (e.g. the second element of `splitext`).
isjld(ext::AbstractString) = ext == ".jld2" || ext == ".jld"
isdlm(ext::AbstractString) = ext == ".tsv" || ext == ".csv"
isbiom(ext::AbstractString) = ".biom" == ext
isedgelist(ext::AbstractString) = ".edgelist" == ext
isgml(ext::AbstractString) = ".gml" == ext


"""
Expand Down Expand Up @@ -49,7 +50,7 @@ end
"""
save_network(net_path::AbstractString, net_result::FWResult) -> Void
Save network results to disk. Available formats are '.tsv', '.csv', '.jld' and '.jld2'.
Save network results to disk. Available formats are '.edgelist', '.gml', '.jld' and '.jld2'.
- `net_path` - output path for the network
Expand All @@ -60,10 +61,10 @@ Save network results to disk. Available formats are '.tsv', '.csv', '.jld' and '
function save_network(net_path::AbstractString, net_result::FWResult; detailed::Bool=false)
file_ext = splitext(net_path)[2]
if isedgelist(file_ext)
write_edgelist(net_path, graph(net_result))
write_edgelist(net_path, net_result)
elseif isgml(file_ext)
write_gml(net_path, net_result)
elseif isjld(file_ext)
# isdefined(:FileIO) || @eval using FileIO: save, load
# Base.invokelatest(save, net_path, "results", net_result)
save(net_path, "results", net_result)
else
error("$(file_ext) not a valid output format. Choose one of $(valid_net_formats)")
Expand All @@ -80,22 +81,18 @@ end
"""
load_network(net_path::AbstractString) -> FWResult{Int}
Load network results from disk. Available formats are '.tsv', '.csv', '.jld' and '.jld2'.
Load network results from disk. Available formats are '.edgelist', '.gml', '.jld' and '.jld2'. For GML ('.gml'), only files with a structure identical to the output of `save_network` for a '.gml' path can currently be loaded.
- `net_path` - path from which to load the network results
"""
function load_network(net_path::AbstractString)
file_ext = splitext(net_path)[2]
if isedgelist(file_ext)
G = read_edgelist(net_path)
net_result = FWResult(G)

net_result = read_edgelist(net_path)
elseif isgml(file_ext)
net_result = read_gml(net_path)
elseif isjld(file_ext)
# isdefined(:FileIO) || @eval using FileIO: save, load
# d = Base.invokelatest(load, net_path)
# net_result = d["results"]
net_result = load(net_path, "results")

else
error("$(file_ext) not a valid network format. Valid formats are $(valid_net_formats)")
end
Expand All @@ -108,14 +105,12 @@ end

function load_jld(data_path::AbstractString, otu_data_key::AbstractString, otu_header_key::AbstractString,
meta_data_key=nothing, meta_header_key=nothing; transposed::Bool=false)
# isdefined(:FileIO) || @eval using FileIO: save, load
# d = Base.invokelatest(load, data_path)
d = load(data_path)

data = d[otu_data_key]
header = d[otu_header_key]

if meta_data_key != nothing
if meta_data_key != nothing && haskey(d, meta_data_key) && haskey(d, meta_header_key)
meta_data = d[meta_data_key]
meta_header = d[meta_header_key]
else
Expand Down Expand Up @@ -250,8 +245,15 @@ end



function write_edgelist(out_path::AbstractString, G::SimpleWeightedGraph; header=nothing)
function write_edgelist(out_path::AbstractString, net_result::FWResult)
G = graph(net_result)
meta_mask = net_result.meta_variable_mask
header = names(net_result)

open(out_path, "w") do out_f
write(out_f, "# header\t", join(header, ","), "\n")
write(out_f, "# meta mask\t", join(meta_mask, ","), "\n")

for e in edges(G)
if header == nothing
e1 = e.src
Expand All @@ -266,31 +268,138 @@ function write_edgelist(out_path::AbstractString, G::SimpleWeightedGraph; header
end


function read_edgelist(in_path::AbstractString; header=nothing)
function read_edgelist(in_path::AbstractString)
srcs = Int[]
dsts = Int[]
ws = Float64[]

if header != nothing
header, meta_mask = open(in_path, "r") do in_f
header_items = split(readline(in_f), "\t")[end]
header = Vector{String}(split(header_items, ","))
inv_header = Dict{eltype(header), Int}(zip(header, 1:length(header)))
end

open(in_path, "r") do in_f
meta_items = split(readline(in_f), "\t")[end]
meta_mask = BitVector(map(x->parse(Bool, x), split(meta_items, ",")))

for line in eachline(in_f)
line_items = split(chomp(line), '\t')

if header != nothing
src = inv_header[line_items[1]]
dst = inv_header[line_items[2]]
else
src = parse(Int, line_items[1])
dst = parse(Int, line_items[2])
end
src = inv_header[line_items[1]]
dst = inv_header[line_items[2]]

push!(srcs, src)
push!(dsts, dst)
push!(ws, parse(Float64, line_items[end]))
end

header, meta_mask
end
G = SimpleWeightedGraph(srcs, dsts, ws)
net_result = FWResult(G; variable_ids=header, meta_variable_mask=meta_mask)
end


"""
    write_gml(out_path::AbstractString, net_result::FWResult)

Write the graph of `net_result` to `out_path` as a tab-indented `graph [ ... ]`
block: one `node [ ... ]` entry per vertex (`id`, quoted `label` and a 0/1 `mv`
flag taken from the meta variable mask), followed by one `edge [ ... ]` entry
per edge (`source`, `target`, `weight`). Returns `nothing`.
"""
function write_gml(out_path::AbstractString, net_result::FWResult)
    net_graph = graph(net_result)
    labels = names(net_result)
    is_meta = net_result.meta_variable_mask

    open(out_path, "w") do io
        println(io, "graph [")
        println(io, "\tdirected 0")

        # vertex section: numeric id, display label and meta-variable flag
        for v in vertices(net_graph)
            println(io, "\tnode [")
            println(io, "\t\tid " * string(v))
            println(io, "\t\tlabel \"" * labels[v] * "\"")
            println(io, "\t\tmv " * string(Int(is_meta[v])))
            println(io, "\t]")
        end

        # edge section: endpoints refer to the node ids written above
        for edge in edges(net_graph)
            src, dst, w = edge.src, edge.dst, edge.weight
            println(io, "\tedge [")
            println(io, "\t\tsource " * string(src))
            println(io, "\t\ttarget " * string(dst))
            println(io, "\t\tweight " * string(w))
            println(io, "\t]")
        end

        println(io, "]")
    end

    return nothing
end


"""
    parse_gml_field(in_f::IO) -> Vector{Tuple}

Read one `node [ ... ]` or `edge [ ... ]` block from `in_f`.

The next line of `in_f` is consumed. If, after stripping surrounding
whitespace, it does not start with `node` or `edge`, an empty vector is
returned. Otherwise every non-blank line up to (and excluding) the closing `]`
line is converted into a tuple holding the key and the remainder of the line,
so values with embedded whitespace (e.g. quoted labels) stay in one piece. The
first tuple is the opening line itself, e.g. `("node", "[")`.

Throws an error if the input ends before the closing `]` is found.
"""
function parse_gml_field(in_f::IO)
    line = strip(readline(in_f))
    info_pairs = Tuple[]

    # anything other than a node/edge opener (the graph's closing "]", EOF, ...)
    # means there is no further field to read
    if !(startswith(line, "node") || startswith(line, "edge"))
        return info_pairs
    end

    while !startswith(line, "]")
        # split off the key only; blank lines carry no key/value pair
        isempty(line) || push!(info_pairs, Tuple(split(line, r"\s+"; limit=2)))

        # readline keeps returning "" at EOF, which would never match "]"
        eof(in_f) && error("Unexpected end of input: missing closing ']'")
        line = strip(readline(in_f))
    end

    return info_pairs
end


"""
    read_gml(in_path::AbstractString) -> FWResult

Load a network from a '.gml' file laid out like the output of `write_gml`.

The first two lines are skipped, then `node [ ... ]` blocks are read until a
non-node block appears, then `edge [ ... ]` blocks until none remain. Values are
taken by position within each block (line 2 = id/source, line 3 = label/target,
line 4 = mv/weight), so files using a different field order are not supported.
"""
function read_gml(in_path::AbstractString)
    node_dict = Dict{Int,Vector{Tuple}}()

    srcs = Int[]
    dsts = Int[]
    ws = Float64[]

    header, meta_mask = open(in_path, "r") do in_f
        # skip the two preamble lines ("graph [" and "directed 0" in write_gml output)
        readline(in_f)
        readline(in_f)

        # node blocks come first; collect them keyed by their numeric id
        node_info = parse_gml_field(in_f)
        while !isempty(node_info) && node_info[1][1] == "node"
            node_id = parse(Int, node_info[2][2])
            node_dict[node_id] = node_info
            node_info = parse_gml_field(in_f)
        end

        n_nodes = isempty(node_dict) ? 0 : maximum(keys(node_dict))
        ids = fill("", n_nodes)
        mask = falses(n_nodes)
        for (node_id, n_inf) in node_dict
            # label values are written quoted; drop the surrounding quotes
            ids[node_id] = n_inf[3][2][2:end-1]
            mask[node_id] = Bool(parse(Int, n_inf[4][2]))
        end

        # the block that terminated the node loop is the first edge (if any)
        edge_info = node_info
        while !isempty(edge_info) && edge_info[1][1] == "edge"
            push!(srcs, parse(Int, edge_info[2][2]))
            push!(dsts, parse(Int, edge_info[3][2]))
            push!(ws, parse(Float64, edge_info[4][2]))
            edge_info = parse_gml_field(in_f)
        end

        ids, mask
    end

    # NOTE(review): the graph is built from the edge list alone; presumably its
    # vertex count then follows the largest endpoint id, so files with no edges
    # or whose highest-numbered nodes are isolated may not match
    # `length(header)` -- confirm against the SimpleWeightedGraph constructor.
    G = SimpleWeightedGraph(srcs, dsts, ws)
    return FWResult(G; variable_ids=header, meta_variable_mask=meta_mask)
end
4 changes: 2 additions & 2 deletions src/learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function infer_conditional_neighbors(target_vars::Vector{Int}, data::AbstractMat
end

if verbose
println("Starting conditioning search..")
println("\nStarting conditioning search..")
tic()
end

Expand Down Expand Up @@ -391,7 +391,7 @@ Learn an interaction network from a data table (including OTUs and optionally me
- `alpha` - threshold used to determine statistical significance
- `conv` - convergence threshold, i.e. if `conv=0.01` assume convergence if the number of edges increased by only 1% after doubling the runtime
- `conv` - convergence threshold, i.e. if `conv=0.01` assume convergence if the number of edges increased by only 1% after 100% more runtime (checked in intervals)
- `feed_forward` - enable feed-forward heuristic
Expand Down
16 changes: 10 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,11 @@ struct FWResult{T<:Integer}
parameters::Dict{Symbol,Any}
end

function FWResult(inf_results::LGLResult{T}, params=nothing, variable_ids=nothing, meta_variable_mask=nothing) where T<:Integer
function FWResult(inf_results::LGLResult{T}; variable_ids=nothing, meta_variable_mask=nothing,
parameters=nothing) where T<:Integer
n_vars = nv(inf_results.graph)
if params == nothing
params = Dict{Symbol,Any}()
if parameters == nothing
parameters = Dict{Symbol,Any}()
end

if variable_ids == nothing
Expand All @@ -189,10 +190,10 @@ function FWResult(inf_results::LGLResult{T}, params=nothing, variable_ids=nothin
@assert n_vars == length(variable_ids) "variable_ids do not fit number of variables"
@assert n_vars == length(meta_variable_mask) "meta_variable_mask does not fit number of variables"

FWResult(inf_results, variable_ids, meta_variable_mask, params)
FWResult(inf_results, variable_ids, meta_variable_mask, parameters)
end

FWResult(G::SimpleWeightedGraph) = FWResult(LGLResult(G))
FWResult(G::SimpleWeightedGraph; kwargs...) = FWResult(LGLResult(G); kwargs...)


"""
Expand All @@ -216,8 +217,11 @@ parameters(result::FWResult{T}) where T<:Integer = result.parameters
Extract the IDs/names of all variables (nodes) in the network.
"""
variable_ids(result::FWResult{T}) where T<:Integer = result.variable_ids
# Return the `variable_ids` field of `result`.
names(result::FWResult{T}) where {T<:Integer} = result.variable_ids

# Return the `meta_variable_mask` field of `result`.
meta_variable_mask(result::FWResult{T}) where {T<:Integer} = result.meta_variable_mask

# True when the `unfinished_states` collection of the inference results is
# non-empty.
# NOTE(review): the name suggests the opposite polarity (converged == nothing
# left unfinished) -- confirm the intended semantics.
converged(result::FWResult{T}) where {T<:Integer} =
    !isempty(result.inference_results.unfinished_states)

# Two results are equal when their graph, variable names and meta variable mask
# all compare equal; no other field takes part in the comparison.
function ==(result1::FWResult{T}, result2::FWResult{S}) where {T<:Integer, S<:Integer}
    return all(f(result1) == f(result2) for f in (graph, names, meta_variable_mask))
end

function unchecked_statistics(result::FWResult)
unf_states_dict = unfinished_states(result)
Expand Down
2 changes: 1 addition & 1 deletion test/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ net_result = load_network(joinpath("data", "io_expected.jld2"))
@testset "networks" begin
tmp_path = tempname()

for net_format in ["edgelist", "jld2", "jld"]
for net_format in ["edgelist", "gml", "jld2", "jld"]
@testset "$net_format" begin
tmp_net_path = tmp_path * "." * net_format
save_network(tmp_net_path, net_result)
Expand Down

0 comments on commit 55b0cea

Please sign in to comment.