Quality vs k
Silhouette score versus number of clusters for different algorithms.
"""Generate `silhouette_vs_k.png` comparing CLARANS, FastCLARANS, and KMeans.

The chart plots the silhouette score against the number of clusters k for
each algorithm.
"""
3from pathlib import Path
4import matplotlib
5matplotlib.use("Agg")
6import matplotlib.pyplot as plt
7from sklearn.datasets import make_blobs
8from sklearn.cluster import KMeans
9from sklearn.metrics import silhouette_score
10from clarans import CLARANS, FastCLARANS
11
12
def main():
    """Plot silhouette score versus k for CLARANS, FastCLARANS and KMeans.

    Generates one synthetic blob dataset, fits each algorithm on it for
    k = 2..8, scores the resulting labels with ``silhouette_score``, draws one
    line per algorithm and saves the chart as ``silhouette_vs_k.png`` (a
    relative path, so it lands in the current working directory). Prints the
    output filename once the file has been written.
    """
    X, _ = make_blobs(n_samples=500, centers=4, cluster_std=0.60, random_state=42)
    # Materialize once: the same k values drive the fitting loop and the x-axis.
    ks = list(range(2, 9))
    # Factories rather than instances, so each k gets a freshly built estimator.
    methods = {
        "CLARANS": lambda k: CLARANS(n_clusters=k, numlocal=3, random_state=42),
        "FastCLARANS": lambda k: FastCLARANS(n_clusters=k, numlocal=3, random_state=42),
        "KMeans": lambda k: KMeans(n_clusters=k, random_state=42),
    }

    results = {name: [] for name in methods}

    for k in ks:
        for name, factory in methods.items():
            model = factory(k)
            model.fit(X)
            labels = model.labels_
            # Only score when at least two distinct labels were produced;
            # otherwise record NaN for this (algorithm, k) pair.
            if len(set(labels)) > 1:
                score = silhouette_score(X, labels)
            else:
                score = float("nan")
            results[name].append(score)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for name, scores in results.items():
            ax.plot(ks, scores, marker="o", label=name)
        ax.set_xlabel("k (n_clusters)")
        ax.set_ylabel("Silhouette score")
        ax.set_title("Silhouette score vs k")
        ax.legend()

        out = "silhouette_vs_k.png"
        fig.savefig(out, bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # release it so repeated calls to main() do not accumulate figures.
        plt.close(fig)
    print(f"Saved {out}")
46
47
# Script entry point: build the chart only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()