Я використовую Spark Mlib для кластеризації kmeans. У мене є набір векторів, з яких я хочу визначити найімовірніший центр кластера. Тож я проведу тренування кластеризації kmeans на цьому наборі та виберу кластер із найбільшою кількістю присвоєних йому векторів.
Тому мені потрібно знати кількість векторівприсвоюється кожному кластеру після тренування (тобто KMeans.run (...)). Але я не можу знайти спосіб отримати цю інформацію з результату KMeanModel. Мені, мабуть, потрібно бігти predict
на всіх навчальних векторах і підраховуйте мітку, яка виявляється найбільше.
Чи є інший спосіб зробити це?
Дякую
Відповіді:
2 для відповіді № 1Ви маєте рацію, ця інформація не надається моделлю, і вам потрібно запустити predict
. Ось приклад цього робити паралельно (Spark v. 1.5.1):
from pyspark.mllib.clustering import KMeans
from numpy import array
data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0, 10.0, 9.0]).reshape(5, 2)
data
# array([[ 0., 0.],
# [ 1., 1.],
# [ 9., 8.],
# [ 8., 9.],
# [ 10., 9.]])
k = 2 # no. of clusters
model = KMeans.train(
sc.parallelize(data), k, maxIterations=10, runs=30, initializationMode="random",
seed=50, initializationSteps=5, epsilon=1e-4)
cluster_ind = model.predict(sc.parallelize(data))
cluster_ind.collect()
# [1, 1, 0, 0, 0]
cluster_ind
є RDD тієї самої кардинальності з нашоюпочаткові дані, і це показує, до якого кластеру належить кожна точка даних. Отже, у нас є два кластери: один з 3 точками (кластер 0) і один з 2 точками (кластер 1). Зауважте, що ми використовували метод прогнозування паралельно (тобто на RDD) - collect()
тут використовується лише для наших демонстраційних цілей, і він не потрібен у "реальній" ситуації.
Тепер ми можемо отримати розміри кластерів
cluster_sizes = cluster_ind.countByValue().items()
cluster_sizes
# [(0, 3), (1, 2)]
З цього ми можемо отримати максимальний індекс & розмір кластера як
from operator import itemgetter
max(cluster_sizes, key=itemgetter(1))
# (0, 3)
тобто наш найбільший кластер - це кластер 0, розмір 3 точок даних, який можна легко перевірити, перевіривши cluster_ind.collect()
вище.