diff --git a/kmeans.py b/kmeans.py index db26e69..94d48c4 100644 --- a/kmeans.py +++ b/kmeans.py @@ -19,6 +19,7 @@ sns.lmplot("x", "y", data=df, fit_reg=False, size=7) plt.show() vectors = tf.constant(vector_values) +# choose the first num_clusters points as centroids centroids = tf.Variable(tf.slice(tf.random_shuffle(vectors), [0,0],[num_clusters,-1])) expanded_vectors = tf.expand_dims(vectors, 0)