From 5b8eb0dc57ab8dec14e7235a8509b896332acfe9 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Mon, 18 Nov 2019 13:19:34 -0500 Subject: [PATCH 1/7] Added silhouette plot --- Project.toml | 1 + src/StatsPlots.jl | 3 ++ src/silhouetteplot.jl | 83 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 src/silhouetteplot.jl diff --git a/Project.toml b/Project.toml index 0eb1f9d..f2237f5 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.14.13" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DataValues = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" diff --git a/src/StatsPlots.jl b/src/StatsPlots.jl index 0cb14d7..4cb7170 100644 --- a/src/StatsPlots.jl +++ b/src/StatsPlots.jl @@ -16,7 +16,9 @@ using Widgets, Observables import Observables: AbstractObservable, @map, observe import Widgets: @nodeps import DataStructures: OrderedDict +using Distances: pairwise, Euclidean import Clustering: Hclust, nnodes +using Clustering: ClusteringResult, silhouettes, assignments, counts using Interpolations import MultivariateStats: MDS, eigvals, projection, principalvars, principalratio, transform @@ -45,5 +47,6 @@ include("dendrogram.jl") include("andrews.jl") include("ordinations.jl") include("covellipse.jl") +include("silhouetteplot.jl") end # module diff --git a/src/silhouetteplot.jl b/src/silhouetteplot.jl new file mode 100644 index 0000000..88d1a67 --- /dev/null +++ b/src/silhouetteplot.jl @@ -0,0 +1,83 @@ +""" + silhouetteplot(C,X[,D];...) +Make a silhouette plot to assess the quality of a clustering. `C` must be a `ClusteringResult` (see the `Clustering` package), and `X` is a matrix in which each column represents a data point. If supplied, `D` should be a distance matrix (as in `Distances`); otherwise, pairwise Euclidean distances are used. + +Each data point has a silhouette score between -1 and 1 indicating how unambiguously the point belongs to its assigned cluster. These are sorted within each cluster and portrayed using horizontal bars. Also shown is a dashed line at the average score. Typically a high-quality clustering has significant numbers of bars within each cluster that cross the line, and few negative scores overall. + +See also: [`Clustering.silhouettes`](@ref), [`Distances`](@ref). + +# Examples + +``` +using Clustering, Distances, Plots +# random dataset with 3-ish clusters in 5 dimensions +X = hcat([rand(5,1) .+ 0.2*randn(5, 200) for _=1:3]...) +D = pairwise(Euclidean(),X,dims=2) +R = kmeans(D, 3; maxiter=200, display=:iter) + +silhouetteplot(R,X,D) +``` +""" +silhouetteplot + +@userplot SilhouettePlot +@recipe function f(h::SilhouettePlot)#R::ClusteringResult,X::AbstractArray,D::AbstractMatrix=[];distance=Euclidean()) + narg = length(h.args) + @assert narg > 1 "At least two arguments are required." + R = h.args[1] + @assert R isa ClusteringResult "First argument must be a ClusteringResult." + X = h.args[2] + @assert X isa AbstractArray "Second argument must be an array." + if narg > 2 + D = h.args[3] + @assert D isa AbstractMatrix "Third argument must be a distance matrix." + else + D = pairwise(Euclidean(),X,dims=2) + end + + a = assignments(R) # assignments to clusters + c = counts(R) # cluster sizes + k = length(c) # number of clusters + n = sum(c) # number of points overall + + s = silhouettes(R,D) + + # Settings for the axes + legend --> false + yflip := true + xlims := [min(-0.1,minimum(s)),1] + # y ticks used to show cluster boundaries, and labels to show the sizes + yticks := cumsum([0;c]),["0",["+$z" for z in c]...] + + # Generate the polygons for each cluster. + offset = 0; + plt = plot([],label="") + for i in 1:k + idx = (a.==i) # members of cluster i + si = sort(s[idx],rev=true) + @series begin + linealpha --> 0 + seriestype := :shape + label := "$i" + x = [0;repeat(si,inner=(2));0] + y = offset .+ repeat(0:c[i],inner=(2)) + x,y + end + # text label to the left of the bars + @series begin + linealpha := 0 + series_annotations := [ Plots.text("$i",:center,:middle,9) ] + [-0.04], [offset+c[i]/2] + end + offset += c[i]; + end + + # Dashed line for overall average. + savg = sum(s)/n + @series begin + linecolor := :black + linestyle := :dash + label := "" + [savg,savg], [0,n] + end +end From b4500a9df1a7ded832de2d402e3120bcc4d343b9 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Tue, 19 Nov 2019 09:17:04 -0500 Subject: [PATCH 2/7] Addressing some concerns raised in #269. --- src/silhouetteplot.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/silhouetteplot.jl b/src/silhouetteplot.jl index 88d1a67..268f802 100644 --- a/src/silhouetteplot.jl +++ b/src/silhouetteplot.jl @@ -45,7 +45,7 @@ silhouetteplot # Settings for the axes legend --> false yflip := true - xlims := [min(-0.1,minimum(s)),1] + xlims --> [min(-0.1,minimum(s)),1] # y ticks used to show cluster boundaries, and labels to show the sizes yticks := cumsum([0;c]),["0",["+$z" for z in c]...] @@ -65,7 +65,8 @@ silhouetteplot end # text label to the left of the bars @series begin - linealpha := 0 + primary := false + seriesalpha := 0 series_annotations := [ Plots.text("$i",:center,:middle,9) ] [-0.04], [offset+c[i]/2] end @@ -75,9 +76,9 @@ silhouetteplot # Dashed line for overall average. savg = sum(s)/n @series begin + primary := false linecolor := :black linestyle := :dash - label := "" [savg,savg], [0,n] end end From 474165c92b732c49fda0ee1b15759445e59a0740 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Mon, 25 Nov 2019 14:08:16 -0500 Subject: [PATCH 3/7] Made yflip a settable option. --- src/silhouetteplot.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/silhouetteplot.jl b/src/silhouetteplot.jl index 268f802..8fcf85d 100644 --- a/src/silhouetteplot.jl +++ b/src/silhouetteplot.jl @@ -21,7 +21,7 @@ silhouetteplot(R,X,D) silhouetteplot @userplot SilhouettePlot -@recipe function f(h::SilhouettePlot)#R::ClusteringResult,X::AbstractArray,D::AbstractMatrix=[];distance=Euclidean()) +@recipe function f(h::SilhouettePlot;yflip=true)#R::ClusteringResult,X::AbstractArray,D::AbstractMatrix=[];distance=Euclidean()) narg = length(h.args) @assert narg > 1 "At least two arguments are required." R = h.args[1] @@ -44,7 +44,7 @@ silhouetteplot # Settings for the axes legend --> false - yflip := true + yflip := yflip xlims --> [min(-0.1,minimum(s)),1] # y ticks used to show cluster boundaries, and labels to show the sizes yticks := cumsum([0;c]),["0",["+$z" for z in c]...] @@ -54,7 +54,7 @@ silhouetteplot plt = plot([],label="") for i in 1:k idx = (a.==i) # members of cluster i - si = sort(s[idx],rev=true) + si = yflip ? sort(s[idx],rev=true) : sort(s[idx]) @series begin linealpha --> 0 seriestype := :shape From 1e9c936bbb2c6c9ddc7090afa20d457772791526 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Tue, 15 Sep 2020 13:52:47 -0400 Subject: [PATCH 4/7] updated README --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index e99f306..c789ca0 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ This package is a drop-in replacement for Plots.jl that contains many statistica - corrplot/cornerplot - andrewsplot - MDS plot + - silhouetteplot It is thus slightly less lightweight, but has more functionality. Main documentation is found in the Plots.jl documentation (https://juliaplots.github.io). @@ -421,3 +422,21 @@ covellipse([0,2], [2 1; 1 4], n_std=2, aspect_ratio=1, label="cov1") covellipse!([1,0], [1 -0.5; -0.5 3], showaxes=true, label="cov2") ``` ![covariance ellipses](https://user-images.githubusercontent.com/4170948/84170978-f0c2f380-aa82-11ea-95de-ce2fe14e16ec.png) + +## Silhouette plots + +Silhouette plots are used to gauge the quality of a clustering. Each data point has a silhouette score between -1 and 1 indicating how unambiguously the point belongs to its assigned cluster. These are sorted within each cluster and portrayed using horizontal bars. Also shown is a dashed line at the average score. Typically a high-quality clustering has significant numbers of bars within each cluster that cross the line, and few negative scores overall. See [Rousseeuw 1987](https://doi.org/10.1016/0377-0427(87)90125-7) for details. + +```julia +using Clustering, LinearAlgebra, Random +Random.seed(1234); +# Make three good clusters in 5 dimensions +X = reduce(hcat,[3*randn(5,1) .+ randn(5,100+100n) for n=1:3]); # X is 5x900 +D = [ norm(x-y) for x in eachcol(X), y in eachcol(X) ]; # D is 900x900 +R = kmeans(D, 3; maxiter=200, display=:iter); +silhouetteplot(R,X,D) +``` + +![silhouetteplot](https://user-images.githubusercontent.com/3577518/93245856-eb4c0800-f759-11ea-8938-fa9c7fa20720.png) + +Requires a ClusteringResult as in the [Clustering](https://github.com/JuliaStats/Clustering.jl) package, and a distance matrix, such as one generated by the [Distances](https://github.com/JuliaStats/Distances.jl) package. From b6359872c7309b1d2d4f9cbf0677073a23529449 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Tue, 15 Sep 2020 15:26:18 -0400 Subject: [PATCH 5/7] updated example --- README.md | 10 +++++----- src/silhouetteplot.jl | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c789ca0..1170561 100644 --- a/README.md +++ b/README.md @@ -429,14 +429,14 @@ Silhouette plots are used to gauge the quality of a clustering. Each data point ```julia using Clustering, LinearAlgebra, Random -Random.seed(1234); -# Make three good clusters in 5 dimensions -X = reduce(hcat,[3*randn(5,1) .+ randn(5,100+100n) for n=1:3]); # X is 5x900 +Random.seed!(123); + +X = reduce(hcat,[2*randn(5,1) .+ randn(5,100+100n) for n=1:3]); # X is 5x900 D = [ norm(x-y) for x in eachcol(X), y in eachcol(X) ]; # D is 900x900 R = kmeans(D, 3; maxiter=200, display=:iter); silhouetteplot(R,X,D) ``` -![silhouetteplot](https://user-images.githubusercontent.com/3577518/93245856-eb4c0800-f759-11ea-8938-fa9c7fa20720.png) +![silhouetteplot](https://user-images.githubusercontent.com/3577518/93255055-551ede80-f767-11ea-843b-f5e2eb58a9b9.png) -Requires a ClusteringResult as in the [Clustering](https://github.com/JuliaStats/Clustering.jl) package, and a distance matrix, such as one generated by the [Distances](https://github.com/JuliaStats/Distances.jl) package. +Requires a ClusteringResult, as in the [Clustering](https://github.com/JuliaStats/Clustering.jl) package, and a distance matrix, such as one generated by the [Distances](https://github.com/JuliaStats/Distances.jl) package. diff --git a/src/silhouetteplot.jl b/src/silhouetteplot.jl index 8fcf85d..04ab132 100644 --- a/src/silhouetteplot.jl +++ b/src/silhouetteplot.jl @@ -9,12 +9,12 @@ See also: [`Clustering.silhouettes`](@ref), [`Distances`](@ref). # Examples ``` -using Clustering, Distances, Plots -# random dataset with 3-ish clusters in 5 dimensions -X = hcat([rand(5,1) .+ 0.2*randn(5, 200) for _=1:3]...) -D = pairwise(Euclidean(),X,dims=2) -R = kmeans(D, 3; maxiter=200, display=:iter) +using Clustering, LinearAlgebra, Random +Random.seed!(123); +X = reduce(hcat,[2*randn(5,1) .+ randn(5,100+100n) for n=1:3]); # X is 5x900 +D = [ norm(x-y) for x in eachcol(X), y in eachcol(X) ]; # D is 900x900 +R = kmeans(D, 3; maxiter=200); silhouetteplot(R,X,D) ``` """ From c8329bf92a31564c8f4f2417d53394f69db100d1 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Tue, 15 Sep 2020 15:26:30 -0400 Subject: [PATCH 6/7] removed Distances dependence --- src/StatsPlots.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/StatsPlots.jl b/src/StatsPlots.jl index 4cb7170..5486639 100644 --- a/src/StatsPlots.jl +++ b/src/StatsPlots.jl @@ -16,7 +16,6 @@ using Widgets, Observables import Observables: AbstractObservable, @map, observe import Widgets: @nodeps import DataStructures: OrderedDict -using Distances: pairwise, Euclidean import Clustering: Hclust, nnodes using Clustering: ClusteringResult, silhouettes, assignments, counts using Interpolations From f92c8b83aaa7b7b90cf28e68a307b0d9dda7697d Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Tue, 15 Sep 2020 15:26:30 -0400 Subject: [PATCH 7/7] removed Distances dependence --- Project.toml | 1 - src/StatsPlots.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index f2237f5..0eb1f9d 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.14.13" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DataValues = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" -Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" diff --git a/src/StatsPlots.jl b/src/StatsPlots.jl index 4cb7170..5486639 100644 --- a/src/StatsPlots.jl +++ b/src/StatsPlots.jl @@ -16,7 +16,6 @@ using Widgets, Observables import Observables: AbstractObservable, @map, observe import Widgets: @nodeps import DataStructures: OrderedDict -using Distances: pairwise, Euclidean import Clustering: Hclust, nnodes using Clustering: ClusteringResult, silhouettes, assignments, counts using Interpolations