tensorflow で K-means法

元ネタはFIRST CONTACT WITH TENSORFLOWという本で、参照元はK-Means.ipynb

import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from numpy.random import normal

def kmeans(num_points, k):
    vectors_set = [(normal(0.0, 0.9), normal(0.0, 0.9)) for _ in range(int(num_points/2))]
    vectors_set.extend((normal(3.0, 0.5), normal(1.0, 0.5)) for _ in range(int(num_points/2)))

    vectors= tf.constant(vectors_set)
    centroides = tf.Variable(tf.slice(tf.random_shuffle(vectors), [0, 0], [k, -1]))

    expanded_vectors = tf.expand_dims(vectors, 0)
    expanded_centroides = tf.expand_dims(centroides, 1)

    sub = tf.sub(expanded_vectors, expanded_centroides)
    square = tf.square(sub)
    sum = tf.reduce_sum(square, 2)
    assignments = tf.argmin(sum, 0)

    means = tf.concat(0, [tf.reduce_mean(tf.gather(vectors, tf.reshape(tf.where(tf.equal(assignments, c)), [1, -1])),
                                         reduction_indices=[1]) for c in range(k)])
    update_centroides = tf.assign(centroides, means)


    init_op = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init_op)

    epoch = 2000
    for step in range(epoch):
        _, centroid_values, assignment_values = sess.run([update_centroides, centroides, assignments])

    data = {"x":[], "y":[], "cluster":[]}
    for i in range(len(assignment_values)):
        data["x"].append(vectors_set[i][0])
        data["y"].append(vectors_set[i][1])
        data["cluster"].append(assignment_values[i])
    df = pd.DataFrame(data)
    sns.lmplot("x", "y", data=df, fit_reg=False, size=6, hue="cluster", legend=False)

    plt.show()

if __name__ == "__main__":
    kmeans(4000, 20)

k = 4の場合、



k=10の場合、






















【使われている関数群】
tf.expand_dims(input, dim, name=None)
inputのtensorにdimの場所へ次元を追加する。
例えば、スカラーをベクトルへ、ベクトルをマトリックスへ変換する。
この例では差を計算(tf.sub)するためにあえて次元を追加している。

tf.concat(concat_dim, values, name='concat')
テンソルをくっつける。
concat_dimでくっつける次元を選べる。

a = tf.constant([[1,2],[3,4]])
b = tf.constant([[5,6],[7,8]])
c0 = tf.concat(0,[a,b])
c1 = tf.concat(1,[a,b])
sess = tf.Session()
r0, r1 = sess.run([c0,c1])
print(r0)
print(r1)
実行御結果は以下
[[1 2]
 [3 4]
 [5 6]
 [7 8]]

[[1 2 5 6]
 [3 4 7 8]]

コメント

このブログの人気の投稿

slackでgeneralの投稿を全削除する

Python SQLite スレッド間でコネクションの使いまわしは出来ない

slackで投稿内容を自動翻訳する(3/5)slackにおけるメッセージの構造