From f4308297a4039ac7509fd730283b18a3e0503457 Mon Sep 17 00:00:00 2001 From: Toby Driscoll Date: Mon, 18 Nov 2019 13:19:34 -0500 Subject: [PATCH] Added silhouette plot --- Project.toml | 1 + src/StatsPlots.jl | 3 ++ src/silhouetteplot.jl | 83 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 src/silhouetteplot.jl diff --git a/Project.toml b/Project.toml index b469fa2..1c3bb0f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.12.0" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DataValues = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" diff --git a/src/StatsPlots.jl b/src/StatsPlots.jl index 4efeefe..7026a1c 100644 --- a/src/StatsPlots.jl +++ b/src/StatsPlots.jl @@ -13,7 +13,9 @@ using Widgets, Observables import Observables: AbstractObservable, @map, observe import Widgets: @nodeps import DataStructures: OrderedDict +using Distances: pairwise, Euclidean import Clustering: Hclust, nnodes +using Clustering: ClusteringResult, silhouettes, assignments, counts using Interpolations import MultivariateStats: MDS, eigvals, projection, principalvars, principalratio, transform @@ -40,5 +42,6 @@ include("bar.jl") include("dendrogram.jl") include("andrews.jl") include("ordinations.jl") +include("silhouetteplot.jl") end # module diff --git a/src/silhouetteplot.jl b/src/silhouetteplot.jl new file mode 100644 index 0000000..88d1a67 --- /dev/null +++ b/src/silhouetteplot.jl @@ -0,0 +1,83 @@ +""" + silhouetteplot(C,X[,D];...) +Make a silhouette plot to assess the quality of a clustering. `C` must be a `ClusteringResult` (see the `Clustering` package), and `X` is a matrix in which each column represents a data point. 
If supplied, `D` should be a distance matrix (as in `Distances`); otherwise, pairwise Euclidean distances are used.
+
+Each data point has a silhouette score between -1 and 1 indicating how unambiguously the point belongs to its assigned cluster. These are sorted within each cluster and portrayed using horizontal bars. Also shown is a dashed line at the average score. Typically a high-quality clustering has significant numbers of bars within each cluster that cross the line, and few negative scores overall.
+
+See also: [`Clustering.silhouettes`](@ref), [`Distances`](@ref).
+
+# Examples
+
+```
+using Clustering, Distances, StatsPlots
+# random dataset with 3-ish clusters in 5 dimensions
+X = hcat([rand(5,1) .+ 0.2*randn(5, 200) for _=1:3]...)
+D = pairwise(Euclidean(),X,dims=2)
+R = kmeans(X, 3; maxiter=200, display=:iter)
+
+silhouetteplot(R,X,D)
+```
+"""
+silhouetteplot
+
+@userplot SilhouettePlot
+@recipe function f(h::SilhouettePlot)
+    # Validate user input with real exceptions: `@assert` may be disabled, so it cannot guard input.
+    narg = length(h.args)
+    narg > 1 || throw(ArgumentError("At least two arguments are required."))
+    R = h.args[1]
+    R isa ClusteringResult || throw(ArgumentError("First argument must be a ClusteringResult."))
+    X = h.args[2]
+    X isa AbstractArray || throw(ArgumentError("Second argument must be an array."))
+    if narg > 2
+        D = h.args[3]
+        D isa AbstractMatrix || throw(ArgumentError("Third argument must be a distance matrix."))
+    else
+        D = pairwise(Euclidean(), X, dims=2)  # distances between the columns of X
+    end
+
+    a = assignments(R)  # assignments to clusters
+    c = counts(R)       # cluster sizes
+    k = length(c)       # number of clusters
+    n = sum(c)          # number of points overall
+
+    s = silhouettes(R, D)  # one silhouette score per data point
+
+    # Settings for the axes
+    legend --> false
+    yflip := true
+    xlims := (min(-0.1, minimum(s)), 1)  # lower limit reaches past the labels drawn at x = -0.04
+    # y ticks used to show cluster boundaries, and labels to show the sizes
+    yticks := (cumsum([0; c]), ["0"; ["+$z" for z in c]])
+
+    # Generate the polygons for each cluster.
+    offset = 0  # number of points in the clusters drawn so far
+    for i in 1:k
+        si = sort(s[a .== i], rev=true)  # scores of the members of cluster i, largest first
+        @series begin
+            linealpha --> 0
+            seriestype := :shape
+            label := "$i"
+            x = [0; repeat(si, inner=2); 0]        # every score supplies two polygon corners
+            y = offset .+ repeat(0:c[i], inner=2)  # one unit of height per point
+            x, y
+        end
+        # text label to the left of the bars
+        @series begin
+            linealpha := 0
+            label := ""  # keep this helper series out of the legend
+            series_annotations := [Plots.text("$i", :center, :middle, 9)]
+            [-0.04], [offset + c[i] / 2]
+        end
+        offset += c[i]
+    end
+
+    # Dashed line for overall average.
+    savg = sum(s) / n
+    @series begin
+        linecolor := :black
+        linestyle := :dash
+        label := ""
+        [savg, savg], [0, n]
+    end
+end