tensorflow で 自己符号化器

MLPのChapter5の自己符号化器をTensorflowで試す。
誤差関数は二乗誤差、、、の平方根。
何故か平方根取らないと値が発散してしまった。。。

・データ:MNIST
・重み共有:なし
・恒等写像
・中間層:40ユニット
・ミニパッチ(100サンプル毎)


import tensorflow as tf
from PIL import Image
import os


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def output_image(path, num_images, size, iimg, oimg, prefix=""):
    if len(iimg) > size[0]*size[1]*num_images:
        iimg = iimg[: size[0]*size[1]*num_images]
    if len(oimg) > size[0]*size[1]*num_images:
        oimg = oimg[: size[0]*size[1]*num_images]
    pdata = []
    for i in range(size[1]*num_images):
        st = size[1]*i
        ed = size[1]*(i+1)
        pdata.extend(iimg[st:ed])
        pdata.extend(oimg[st:ed])
    im = Image.new('L', (size[0]*2, size[1]*num_images))
    im.putdata(pdata, 256)
    im.save(os.path.join(path, prefix + '.png'))


def main():
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    x = tf.placeholder(tf.float32, [None, 784])
    W1 = weight_variable([784, 40])
    b1 = weight_variable([40])
    y1 = tf.matmul(x, W1) + b1
    W2 = weight_variable([40, 784])
    b2 = weight_variable([784])
    y2 = tf.matmul(y1, W2) + b2

    loss = tf.sqrt(tf.reduce_sum(tf.square(tf.sub(y2, x))))

    iimg = tf.reshape(x, [-1])
    oimg = tf.reshape(y2,[-1])
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)

    for i in range(2001):
        trainbatch, _ = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: trainbatch})
        if i%200 == 0:
            testbatch, _ = mnist.train.next_batch(mnist.test.num_examples)
            result, ii, oi = sess.run([loss, iimg, oimg], feed_dict={x: testbatch})
            print(result)
            output_image("MNIST_data", 10, (28,28), ii, oi, '{0:06d}'.format(i) + '_')

if __name__ == "__main__":
    main()

誤差は以下の通り段々収束する。

1474.56
710.956
497.43
435.52
401.355
375.737
365.105
349.923
348.165
346.122
344.344

200ステップ毎にout_image()にてpng画像を出力しており結果は以下の通り。
各10サンプルについて入力(左)出力(右)を比較。

1ステップ後

201ステップ後

401ステップ後


2001ステップ後



中間層が40ユニットでも結構復元できてる。
次はスパース正規化にトライしたい。が。

コメント

このブログの人気の投稿

slackでgeneralの投稿を全削除する

Python SQLite スレッド間でコネクションの使いまわしは出来ない

slackで投稿内容を自動翻訳する(3/5)slackにおけるメッセージの構造