diff --git a/src/df.jl b/src/df.jl index b600c71..d754742 100644 --- a/src/df.jl +++ b/src/df.jl @@ -29,10 +29,11 @@ function df_helper(x) end function df_helper(d, x) - if isa(x, Expr) && x.head == :block - commands = [df_helper(d, xx) for xx in x.args if !(isa(xx, Expr) && xx.head == :line || isa(xx, LineNumberNode))] + if isa(x, Expr) && x.head == :block # meaning that there were multiple plot commands + commands = [df_helper(d, xx) for xx in x.args if !(isa(xx, Expr) && xx.head == :line || isa(xx, LineNumberNode))] # apply the helper recursively return Expr(:block, commands...) - elseif isa(x, Expr) && x.head == :call + + elseif isa(x, Expr) && x.head == :call # each function call is operated on alone syms = Any[] vars = Symbol[] plot_call = parse_table_call!(d, x, syms, vars) @@ -47,6 +48,7 @@ function df_helper(d, x) label_plot_call = Expr(:call, :(StatsPlots.add_label), argnames, plot_call.args...) end return Expr(:block, compute_vars, label_plot_call) + else error("Second argument ($x) can only be a block or function call") end @@ -133,6 +135,7 @@ compute_name(names, i::Int) = names[i] compute_name(names, i::Symbol) = i compute_name(names, i) = reshape([compute_name(names, ii) for ii in i], 1, :) +# This function ensures that labels are passed to the plottig command. function add_label(argnames, f, args...; kwargs...) i = findlast(t -> isa(t, Expr) || isa(t, AbstractArray), argnames) if (i === nothing) @@ -146,14 +149,32 @@ get_col(s::Int, col_nt, names) = col_nt[names[s]] get_col(s::Symbol, col_nt, names) = get(col_nt, s, s) get_col(syms, col_nt, names) = hcat((get_col(s, col_nt, names) for s in syms)...) +""" + extract_columns_and_names(df, syms...) + +Extracts columns and their names (if the column number is an integer) +into a slightly complex `Tuple`. + +The structure goes as `((columndata...), names)`. This is unpacked by the [`@df`](@ref) macro into `gensym`'ed variables, which are passed to the plotting function. + +!!! note + If you want to extend the [`@df`](@ref) macro + to work with your custom type, this is the + function you should overload! +""" function extract_columns_and_names(df, syms...) istable(df) || error("Only tables are supported") names = column_names(df) selected_cols = Symbol[] + + # get the appropriate name when passed an Integer add_sym!(s::Integer) = push!(selected_cols, names[s]) + # check for errors in Symbols add_sym!(s::Symbol) = s in names && push!(selected_cols, s) + # recursively extract column names add_sym!(s) = foreach(add_sym!, s) + foreach(add_sym!, syms) cols = columntable(select(df, unique(selected_cols)...))