Quality vs k

Silhouette score vs k

Silhouette score versus number of clusters for different algorithms.

 1"""Generate `silhouette_vs_k.png` comparing CLARANS, FastCLARANS, and KMeans.
 2"""
 3from pathlib import Path
 4import matplotlib
 5matplotlib.use("Agg")
 6import matplotlib.pyplot as plt
 7from sklearn.datasets import make_blobs
 8from sklearn.cluster import KMeans
 9from sklearn.metrics import silhouette_score
10from clarans import CLARANS, FastCLARANS
11
12
def main():
    """Plot silhouette score against k for CLARANS, FastCLARANS and KMeans.

    Fits each algorithm on one synthetic blob dataset (500 samples drawn
    around 4 centers) for every k from 2 to 8, scores the resulting labels
    with ``silhouette_score`` and saves the three curves to
    ``silhouette_vs_k.png`` relative to the current working directory.
    """
    X, _ = make_blobs(n_samples=500, centers=4, cluster_std=0.60, random_state=42)
    # Materialised once: the same k values drive the fits and the x-axis.
    ks = list(range(2, 9))
    # Factories rather than instances so every k gets a freshly built model.
    methods = {
        "CLARANS": lambda k: CLARANS(n_clusters=k, numlocal=3, random_state=42),
        "FastCLARANS": lambda k: FastCLARANS(n_clusters=k, numlocal=3, random_state=42),
        "KMeans": lambda k: KMeans(n_clusters=k, random_state=42),
    }

    results = {name: [] for name in methods}

    for k in ks:
        for name, factory in methods.items():
            model = factory(k)
            model.fit(X)
            labels = model.labels_
            # A result that collapsed to a single label is recorded as NaN
            # instead of being scored, which leaves a gap in the plotted line.
            if len(set(labels)) > 1:
                score = silhouette_score(X, labels)
            else:
                score = float("nan")
            results[name].append(score)

    out = "silhouette_vs_k.png"
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for name, scores in results.items():
            ax.plot(ks, scores, marker="o", label=name)
        ax.set_xlabel("k (n_clusters)")
        ax.set_ylabel("Silhouette score")
        ax.set_title("Silhouette score vs k")
        ax.legend()
        fig.savefig(out, bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps figures it created registered until they are closed;
        # release this one even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved {out}")
46
47
# Generate the plot only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()