-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmarginal_plot.R
116 lines (106 loc) · 5.5 KB
/
marginal_plot.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#' Scatterplot with marginal density curves
#'
#' Draws a 2 x 2 layout: a scatterplot of `y` against `x` (bottom left), the
#' kernel density of `x` above it, the kernel density of `y` to its right and,
#' for grouped data, a legend in the top-right cell. Data may be grouped or
#' ungrouped; grouped data get one colour (from `rainbow()`) and one density
#' curve per group.
#'
#' @param x,y Numeric vectors of equal length, or, when `data` is supplied,
#'   the names of columns in `data` (bare or quoted).
#' @param group Optional grouping variable (vector, or column name when `data`
#'   is supplied). `NULL` (the default) plots ungrouped data. Groups with 5 or
#'   fewer complete rows are dropped.
#' @param data Optional data frame (or object coercible by `as.data.frame()`)
#'   holding the columns named by `x`, `y` and `group`.
#' @param lm_show If `TRUE`, a fit of `lm_formula` is drawn for each group
#'   over that group's range of `x`. Hidden by default.
#' @param lm_formula Model formula for the fit, in terms of `x` and `y`.
#' @param bw,adjust Bandwidth settings passed to `density()`; see `?density`.
#' @param alpha Opacity of the points, between 0 and 1. Lower it for large
#'   datasets.
#' @param plot_legend Draw the group legend? Ignored for ungrouped data.
#' @param ... Further graphical parameters passed to the main `plot()` call
#'   (axis labels, limits, ...). `col` is discarded, `type` is forced to
#'   `"p"`. `log` (`"x"`, `"y"`, `"xy"` or `"yx"`) is handled here rather than
#'   by `plot()`: rows with non-positive values on a logged axis are removed
#'   and the data are log10-transformed before plotting, so the axes (and any
#'   `xlim`/`ylim` supplied) are on the log10 scale.
#'
#' @return Invisibly, the value of the final `axis(1)` call of the right-hand
#'   density panel (kept so existing callers see the same result as before).
#'   The function is called for its side effect of drawing the plot; the
#'   graphical parameters in effect on entry are restored on exit.
marginal_plot <- function(x, y, group = NULL, data = NULL, lm_show = FALSE,
                          lm_formula = y ~ x, bw = "nrd0", adjust = 1,
                          alpha = 1, plot_legend = TRUE, ...) {
  # scales::alpha() provides the point transparency. Fail early with a clear
  # message instead of require()'s silent FALSE.
  if (!requireNamespace("scales", quietly = TRUE)) {
    stop("Package 'scales' is required by marginal_plot()", call. = FALSE)
  }

  moreargs <- list(...)

  # Capture the argument expressions once: they provide the column names
  # (when 'data' is given), the default axis labels and the legend title.
  x_expr <- substitute(x)
  y_expr <- substitute(y)
  group_expr <- substitute(group)
  expr_label <- function(expr) {
    if (is.character(expr)) expr[1L] else paste(deparse(expr), collapse = " ")
  }

  # 'group' and 'data' default to NULL, so an explicit NULL must behave the
  # same as leaving the argument out. With 'data' supplied, 'group' is a
  # column name and must not be evaluated.
  has_data <- !missing(data) && !is.null(data)
  has_group <- !missing(group) && !is.null(group_expr)
  if (has_group && !has_data) {
    has_group <- !is.null(group)
  }

  # ---- prepare a consistent data frame with columns x, y (and group) ----
  if (has_data) {
    source_df <- as.data.frame(data)
    pull_column <- function(expr) {
      col_name <- expr_label(expr)
      if (!col_name %in% names(source_df)) {
        stop(sprintf("Column '%s' not found in 'data'", col_name),
             call. = FALSE)
      }
      source_df[[col_name]]
    }
    x_values <- pull_column(x_expr)
    y_values <- pull_column(y_expr)
    group_values <- if (has_group) pull_column(group_expr) else NULL
  } else {
    lengths_differ <- length(x) != length(y) ||
      (has_group && length(x) != length(group))
    if (lengths_differ) {
      stop("Length of arguments not equal")
    }
    x_values <- x
    y_values <- y
    group_values <- if (has_group) group else NULL
  }
  plot_data <- data.frame(x = as.numeric(x_values), y = as.numeric(y_values))
  if (has_group) {
    plot_data$group <- as.factor(group_values)
  }

  incomplete <- !complete.cases(plot_data)
  if (any(incomplete)) {
    warning(sprintf("Removed %i rows with missing data", sum(incomplete)))
    plot_data <- plot_data[!incomplete, , drop = FALSE]
  }

  if (has_group) {
    # Only groups with more than 5 rows are kept.
    group_sizes <- table(plot_data$group)
    kept_levels <- names(group_sizes)[group_sizes > 5]
    plot_data <- plot_data[plot_data$group %in% kept_levels, , drop = FALSE]
    plot_data$group <- droplevels(plot_data$group)
  }

  # ---- log-transform data (needed for correct plotting of the densities) ----
  log_arg <- moreargs[["log"]]
  if (!is.null(log_arg)) {
    log_is_valid <- is.character(log_arg) && length(log_arg) == 1L &&
      log_arg %in% c("y", "x", "yx", "xy")
    if (!log_is_valid) {
      warning("Ignoring invalid 'log' argument. Use 'y', 'x', 'yx' or 'xy.")
    } else {
      log_axes <- strsplit(log_arg, "")[[1L]]
      all_positive <- rowSums(plot_data[log_axes] <= 0) == 0
      plot_data <- plot_data[all_positive, , drop = FALSE]
      plot_data[log_axes] <- lapply(plot_data[log_axes], log10)
    }
    moreargs[["log"]] <- NULL # remove to prevent double logarithm when plotting
  }

  if (has_group) {
    # The log filter may have emptied a group; drop it so that split() below
    # does not hand an empty data set to density().
    plot_data$group <- droplevels(plot_data$group)
    group_colors <- rainbow(nlevels(plot_data$group))
  } else {
    group_colors <- "black"
  }
  if (nrow(plot_data) == 0L) {
    stop("No data left to plot after filtering", call. = FALSE)
  }

  # ---- plotting arguments ----
  # Colours are controlled by the grouping, and only points make sense here.
  moreargs[["col"]] <- NULL
  if (!is.null(moreargs[["type"]])) {
    moreargs[["type"]] <- "p"
  }
  defaults <- list(
    xlim = range(plot_data$x),
    ylim = range(plot_data$y),
    xlab = expr_label(x_expr),
    ylab = expr_label(y_expr),
    las = 1
  )
  for (arg_name in names(defaults)) {
    if (is.null(moreargs[[arg_name]])) {
      moreargs[[arg_name]] <- defaults[[arg_name]]
    }
  }

  # ---- density estimates ----
  # Each marginal density is estimated once, with both 'bw' and 'adjust', and
  # reused for the axis limits and for the curve, so the limits always fit
  # the curve that is drawn and its x/y grids belong to the same estimate.
  data_split <- if (has_group) {
    split(plot_data, plot_data$group)
  } else {
    list(plot_data)
  }
  estimate <- function(values) density(values, bw = bw, adjust = adjust)
  dens_x <- lapply(data_split, function(group_set) estimate(group_set$x))
  dens_y <- lapply(data_split, function(group_set) estimate(group_set$y))
  peak <- function(dens) max(vapply(dens, function(d) max(d$y), numeric(1)))

  # ---- plotting ----
  # Restore the caller's graphical parameters however the function exits.
  orig_par <- par(no.readonly = TRUE)
  on.exit(par(orig_par), add = TRUE)

  par(mar = c(0.25, 5, 1, 0))
  layout(matrix(1:4, nrow = 2, byrow = TRUE),
         widths = c(10, 3), heights = c(3, 10))

  # upper density plot
  plot(NULL, type = "n", xlim = moreargs[["xlim"]], ylab = "density",
       ylim = c(0, peak(dens_x)), main = NA, axes = FALSE)
  axis(2, las = 1)
  for (i in seq_along(dens_x)) {
    lines(dens_x[[i]], col = group_colors[i], lwd = 2)
  }

  # legend
  par(mar = c(0.25, 0.25, 0, 0))
  plot.new()
  if (has_group && plot_legend) {
    legend("center", levels(plot_data$group), fill = group_colors,
           border = group_colors, bty = "n", title = expr_label(group_expr),
           title.adj = 0.1)
  }

  # main plot
  par(mar = c(4, 5, 0, 0))
  base_colors <- if (has_group) group_colors[plot_data$group] else "black"
  point_colors <- scales::alpha(base_colors, alpha)
  # The data are passed as quoted expressions so that plot() does not have to
  # deparse the full vectors when it builds its default labels.
  do.call(plot, c(list(x = quote(plot_data$x), y = quote(plot_data$y),
                       col = quote(point_colors)), moreargs))
  axis(3, labels = FALSE, tck = 0.01)
  axis(4, labels = FALSE, tck = 0.01)
  box()

  if (isTRUE(as.logical(lm_show)) && !is.null(lm_formula)) {
    # Fit lines use a darkened (80 %) version of each group colour.
    fit_colors <- rgb(t(ceiling(col2rgb(group_colors) * 0.8)),
                      maxColorValue = 255)
    for (i in seq_along(data_split)) {
      group_set <- data_split[[i]]
      lm_tmp <- lm(lm_formula, data = group_set)
      x_coords <- seq(min(group_set$x), max(group_set$x), length.out = 100)
      y_coords <- predict(lm_tmp, newdata = data.frame(x = x_coords))
      lines(x = x_coords, y = y_coords, col = fit_colors[i], lwd = 2.5)
    }
  }

  # right density plot (density on the horizontal axis, values on the vertical)
  par(mar = c(4, 0.25, 0, 1))
  plot(NULL, type = "n", ylim = moreargs[["ylim"]],
       xlim = c(0, peak(dens_y)), main = NA, axes = FALSE, xlab = "density")
  for (i in seq_along(dens_y)) {
    lines(x = dens_y[[i]]$y, y = dens_y[[i]]$x, col = group_colors[i], lwd = 2)
  }
  density_ticks <- axis(1)

  invisible(density_ticks)
}