AlphaZero version 6

First AlphaZero version with a value output. The model was trained from scratch for 32 epochs on 1,000,000 training examples from the StageThree dataset on a 5x4 board.

In [1]:
import sys
sys.path.append('..')

import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug

from keras.callbacks import *
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.initializers import *
from keras.utils.np_utils import to_categorical
from keras.utils import plot_model
import keras.backend as K
from keras.regularizers import l2
from keras.engine.topology import Layer

from PIL import Image
from matplotlib.pyplot import imshow
%matplotlib inline
import random
import gc

from LineFilterLayer import LineFilterLayer

modelPath = 'model/alphaZeroV6.h5'

datasetPath = 'StageThree-1000000-5x4-22:56-19_04_2018.npz'
Using TensorFlow backend.
In [2]:
print(K.image_data_format()) 
# expected output: channels_last
channels_last
In [3]:
def dotsAndBoxesToCategorical(inputData):
    inp = np.copy(inputData)
    inp[inp == 255] = 1 # Line - comes first so that target data only has two categories
    inp[inp == 65] = 2 # Box A
    inp[inp == 150] = 3 # Box B
    inp[inp == 215] = 4 # Dot
    # to_categorical flattens its input, so restore the image shape with an
    # extra one-hot channel axis afterwards
    cat = to_categorical(inp)
    newShape = inp.shape + (cat.shape[-1],)
    return cat.reshape(newShape)
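A quick sanity check of the encoding (an illustrative sketch, not part of the original run; the channel order is the one established by the mappings above):
In [ ]:
# toy 1x1x3 "image": background (0), a line (255) and a dot (215)
toy = np.array([[[0, 255, 215]]], dtype=np.uint8)
cat = dotsAndBoxesToCategorical(toy)
print(cat.shape)        # (1, 1, 3, 5): one one-hot channel per category
print(cat[0, 0, :, 1])  # line channel: [0. 1. 0.]
print(cat[0, 0, :, 4])  # dot channel:  [0. 0. 1.]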
In [4]:
def imgSizeToBoxes(x):
    return (x-3)//2  # integer inverse of imgSize = 2*boxes + 3

def lineFilterMatrixNP(imgWidth,imgHeight):
    boxWidth = imgSizeToBoxes(imgWidth)
    boxHeight = imgSizeToBoxes(imgHeight)
    linesCnt = 2*boxWidth*boxHeight+boxWidth+boxHeight
    mat = np.zeros((imgHeight, imgWidth), dtype=bool)
    for idx in range(linesCnt):
        y1 = idx // ((2*boxWidth) + 1)
        if idx % ((2*boxWidth) + 1) < boxWidth:
            # horizontal line
            x1 = idx % ((2*boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line
            x1 = idx % ((2*boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py,px] = 1
    return mat
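On the 5x4 board used here the board image is 13x11 pixels and the mask marks 2*5*4 + 5 + 4 = 49 line positions, matching the size of the policy output. A quick check (illustrative, not part of the original run):
In [ ]:
mask = lineFilterMatrixNP(13, 11)  # imgWidth=13, imgHeight=11 for 5x4 boxes
print(mask.shape)  # (11, 13)
print(mask.sum())  # 49 line positions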
In [5]:
def loadPVDataset(datasetPath):
    rawDataset = np.load(datasetPath)
    
    x_input = rawDataset['input']
    y_policy = rawDataset['policy']
    y_value = rawDataset['value']
    
    x_input = dotsAndBoxesToCategorical(x_input)
    y_policy = y_policy[:,lineFilterMatrixNP(y_policy.shape[-1], y_policy.shape[-2])]
    y_policy /= 255
    
    return (x_input, y_policy, y_value)

np.set_printoptions(precision=2)
(x_input, y_policy, y_value) = loadPVDataset(datasetPath)
In [6]:
print(x_input.shape)
print(y_policy.shape)
print(y_value.shape)
print("input:")
print(x_input[0,::,::,1])
print("policy:")
print(y_policy[0])
print('value:')
print(y_value[0])
(1000000, 11, 13, 5)
(1000000, 49)
(1000000, 1)
input:
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0.]
 [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
 [0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0.]
 [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
 [0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
policy:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0]
value:
[-0.88]
In [7]:
kernelSize = (5,5)
filterCnt = 64
l2reg = 1e-4
resBlockCnt = 4
imgWidth = x_input.shape[-2]
imgHeight = x_input.shape[-3]

def build_residual_block(x, index):
        in_x = x
        res_name = "res"+str(index)
        x = Conv2D(filters=filterCnt, kernel_size=kernelSize, padding="same",
                   data_format="channels_last", kernel_regularizer=l2(l2reg), 
                   name=res_name+"_conv1_"+str(filterCnt))(x)
        x = BatchNormalization(name=res_name+"_batchnorm1")(x)
        x = Activation("relu",name=res_name+"_relu1")(x)
        x = Conv2D(filters=filterCnt, kernel_size=kernelSize, padding="same",
                   data_format="channels_last", kernel_regularizer=l2(l2reg), 
                   name=res_name+"_conv2-"+str(filterCnt))(x)
        x = BatchNormalization(name="res"+str(index)+"_batchnorm2")(x)
        x = Add(name=res_name+"_add")([in_x, x])
        x = Activation("relu", name=res_name+"_relu2")(x)
        return x


img_input = Input(shape=(imgHeight,imgWidth,5,))
x = Conv2D(filterCnt, kernelSize, padding='same', kernel_regularizer=l2(l2reg), name="input_conv")(img_input)
x = Activation("relu", name="input_relu")(x)
x = BatchNormalization()(x)

for i in range(resBlockCnt):
    x = build_residual_block(x, i+1)

res_out = x

# policy output
x = Conv2D(1, kernelSize, padding='same', kernel_regularizer=l2(l2reg), name="policy_conv")(x)
x = LineFilterLayer(imgWidth, imgHeight)(x)
x = Activation("softmax", name="policy_softmax")(x)
policy_output = x

# value output
x = Conv2D(1, kernelSize, padding='same', kernel_regularizer=l2(l2reg), name="value_conv")(res_out)
#x = LineFilterLayer(imgWidth, imgHeight)(x)
#x = Conv2D(filters=1, kernel_size=1, use_bias=False, kernel_regularizer=l2(l2reg), name="value_conv")(x)
x = Flatten()(x)
x = Dense(1, trainable=False, kernel_initializer=Constant(1.0/(imgWidth*imgHeight)), use_bias=False, name="value_dense")(x)
x = Activation("tanh", name="value_tanh")(x)
value_output = x
    
model = Model(inputs=img_input, outputs=[policy_output, value_output])
model.compile(optimizer='adam', loss=['categorical_crossentropy', 'mean_squared_error'])

for layer in model.layers:
    print("{:30}: {}".format(layer.name, layer.output_shape))
    if layer.name == 'value_dense':
        print(layer.kernel)
    
model.summary()
input_1                       : (None, 11, 13, 5)
input_conv                    : (None, 11, 13, 64)
input_relu                    : (None, 11, 13, 64)
batch_normalization_1         : (None, 11, 13, 64)
res1_conv1_64                 : (None, 11, 13, 64)
res1_batchnorm1               : (None, 11, 13, 64)
res1_relu1                    : (None, 11, 13, 64)
res1_conv2-64                 : (None, 11, 13, 64)
res1_batchnorm2               : (None, 11, 13, 64)
res1_add                      : (None, 11, 13, 64)
res1_relu2                    : (None, 11, 13, 64)
res2_conv1_64                 : (None, 11, 13, 64)
res2_batchnorm1               : (None, 11, 13, 64)
res2_relu1                    : (None, 11, 13, 64)
res2_conv2-64                 : (None, 11, 13, 64)
res2_batchnorm2               : (None, 11, 13, 64)
res2_add                      : (None, 11, 13, 64)
res2_relu2                    : (None, 11, 13, 64)
res3_conv1_64                 : (None, 11, 13, 64)
res3_batchnorm1               : (None, 11, 13, 64)
res3_relu1                    : (None, 11, 13, 64)
res3_conv2-64                 : (None, 11, 13, 64)
res3_batchnorm2               : (None, 11, 13, 64)
res3_add                      : (None, 11, 13, 64)
res3_relu2                    : (None, 11, 13, 64)
res4_conv1_64                 : (None, 11, 13, 64)
res4_batchnorm1               : (None, 11, 13, 64)
res4_relu1                    : (None, 11, 13, 64)
res4_conv2-64                 : (None, 11, 13, 64)
res4_batchnorm2               : (None, 11, 13, 64)
res4_add                      : (None, 11, 13, 64)
res4_relu2                    : (None, 11, 13, 64)
value_conv                    : (None, 11, 13, 1)
policy_conv                   : (None, 11, 13, 1)
flatten_1                     : (None, 143)
line_filter_layer_1           : (None, 49)
value_dense                   : (None, 1)
<tf.Variable 'value_dense/kernel:0' shape=(143, 1) dtype=float32_ref>
policy_softmax                : (None, 49)
value_tanh                    : (None, 1)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 11, 13, 5)    0                                            
__________________________________________________________________________________________________
input_conv (Conv2D)             (None, 11, 13, 64)   8064        input_1[0][0]                    
__________________________________________________________________________________________________
input_relu (Activation)         (None, 11, 13, 64)   0           input_conv[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 11, 13, 64)   256         input_relu[0][0]                 
__________________________________________________________________________________________________
res1_conv1_64 (Conv2D)          (None, 11, 13, 64)   102464      batch_normalization_1[0][0]      
__________________________________________________________________________________________________
res1_batchnorm1 (BatchNormaliza (None, 11, 13, 64)   256         res1_conv1_64[0][0]              
__________________________________________________________________________________________________
res1_relu1 (Activation)         (None, 11, 13, 64)   0           res1_batchnorm1[0][0]            
__________________________________________________________________________________________________
res1_conv2-64 (Conv2D)          (None, 11, 13, 64)   102464      res1_relu1[0][0]                 
__________________________________________________________________________________________________
res1_batchnorm2 (BatchNormaliza (None, 11, 13, 64)   256         res1_conv2-64[0][0]              
__________________________________________________________________________________________________
res1_add (Add)                  (None, 11, 13, 64)   0           batch_normalization_1[0][0]      
                                                                 res1_batchnorm2[0][0]            
__________________________________________________________________________________________________
res1_relu2 (Activation)         (None, 11, 13, 64)   0           res1_add[0][0]                   
__________________________________________________________________________________________________
res2_conv1_64 (Conv2D)          (None, 11, 13, 64)   102464      res1_relu2[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm1 (BatchNormaliza (None, 11, 13, 64)   256         res2_conv1_64[0][0]              
__________________________________________________________________________________________________
res2_relu1 (Activation)         (None, 11, 13, 64)   0           res2_batchnorm1[0][0]            
__________________________________________________________________________________________________
res2_conv2-64 (Conv2D)          (None, 11, 13, 64)   102464      res2_relu1[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm2 (BatchNormaliza (None, 11, 13, 64)   256         res2_conv2-64[0][0]              
__________________________________________________________________________________________________
res2_add (Add)                  (None, 11, 13, 64)   0           res1_relu2[0][0]                 
                                                                 res2_batchnorm2[0][0]            
__________________________________________________________________________________________________
res2_relu2 (Activation)         (None, 11, 13, 64)   0           res2_add[0][0]                   
__________________________________________________________________________________________________
res3_conv1_64 (Conv2D)          (None, 11, 13, 64)   102464      res2_relu2[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm1 (BatchNormaliza (None, 11, 13, 64)   256         res3_conv1_64[0][0]              
__________________________________________________________________________________________________
res3_relu1 (Activation)         (None, 11, 13, 64)   0           res3_batchnorm1[0][0]            
__________________________________________________________________________________________________
res3_conv2-64 (Conv2D)          (None, 11, 13, 64)   102464      res3_relu1[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm2 (BatchNormaliza (None, 11, 13, 64)   256         res3_conv2-64[0][0]              
__________________________________________________________________________________________________
res3_add (Add)                  (None, 11, 13, 64)   0           res2_relu2[0][0]                 
                                                                 res3_batchnorm2[0][0]            
__________________________________________________________________________________________________
res3_relu2 (Activation)         (None, 11, 13, 64)   0           res3_add[0][0]                   
__________________________________________________________________________________________________
res4_conv1_64 (Conv2D)          (None, 11, 13, 64)   102464      res3_relu2[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm1 (BatchNormaliza (None, 11, 13, 64)   256         res4_conv1_64[0][0]              
__________________________________________________________________________________________________
res4_relu1 (Activation)         (None, 11, 13, 64)   0           res4_batchnorm1[0][0]            
__________________________________________________________________________________________________
res4_conv2-64 (Conv2D)          (None, 11, 13, 64)   102464      res4_relu1[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm2 (BatchNormaliza (None, 11, 13, 64)   256         res4_conv2-64[0][0]              
__________________________________________________________________________________________________
res4_add (Add)                  (None, 11, 13, 64)   0           res3_relu2[0][0]                 
                                                                 res4_batchnorm2[0][0]            
__________________________________________________________________________________________________
res4_relu2 (Activation)         (None, 11, 13, 64)   0           res4_add[0][0]                   
__________________________________________________________________________________________________
value_conv (Conv2D)             (None, 11, 13, 1)    1601        res4_relu2[0][0]                 
__________________________________________________________________________________________________
policy_conv (Conv2D)            (None, 11, 13, 1)    1601        res4_relu2[0][0]                 
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 143)          0           value_conv[0][0]                 
__________________________________________________________________________________________________
line_filter_layer_1 (LineFilter (None, 49)           0           policy_conv[0][0]                
__________________________________________________________________________________________________
value_dense (Dense)             (None, 1)            143         flatten_1[0][0]                  
__________________________________________________________________________________________________
policy_softmax (Activation)     (None, 49)           0           line_filter_layer_1[0][0]        
__________________________________________________________________________________________________
value_tanh (Activation)         (None, 1)            0           value_dense[0][0]                
==================================================================================================
Total params: 833,425
Trainable params: 832,130
Non-trainable params: 1,295
__________________________________________________________________________________________________
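As intended, value_dense contributes no trainable parameters: its kernel is frozen at the constant 1/(11*13) = 1/143, so the layer simply averages the 143 value_conv activations before the tanh. This can be verified on the built model (a quick check, not part of the original run):
In [ ]:
w = model.get_layer('value_dense').get_weights()[0]
print(w.shape)                          # (143, 1)
print(np.allclose(w, 1.0 / (11 * 13)))  # True: the layer computes the mean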
In [8]:
#sess = K.get_session()
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
#K.set_session(sess)

# Training
callbacks = []

checkpoint = ModelCheckpoint(filepath=modelPath+".checkpoint", save_weights_only=False)
callbacks.append(checkpoint)

# ProgbarLogger duplicates the progress output that fit() already prints with verbose=1
progbar = ProgbarLogger()
callbacks.append(progbar)

tensorboard = TensorBoard(log_dir='model/log2', write_grads=True, write_graph=True, write_images=True, histogram_freq=1)
#callbacks.append(tensorboard)

model.fit(x_input, [y_policy, y_value], epochs=32, batch_size=64, callbacks=callbacks, validation_split=0.001)

model.save(modelPath)
Train on 999000 samples, validate on 1000 samples
Epoch 1/32
999000/999000 [==============================] - 764s 765us/step - loss: 1.3799 - policy_softmax_loss: 1.0871 - value_tanh_loss: 0.2222 - val_loss: 1.2723 - val_policy_softmax_loss: 0.9926 - val_value_tanh_loss: 0.2155
Epoch 2/32
999000/999000 [==============================] - 766s 767us/step - loss: 1.2500 - policy_softmax_loss: 1.0048 - value_tanh_loss: 0.1872 - val_loss: 1.2092 - val_policy_softmax_loss: 0.9748 - val_value_tanh_loss: 0.1808
Epoch 3/32
999000/999000 [==============================] - 798s 799us/step - loss: 1.2240 - policy_softmax_loss: 0.9933 - value_tanh_loss: 0.1799 - val_loss: 1.1839 - val_policy_softmax_loss: 0.9555 - val_value_tanh_loss: 0.1798
Epoch 4/32
999000/999000 [==============================] - 784s 785us/step - loss: 1.2134 - policy_softmax_loss: 0.9883 - value_tanh_loss: 0.1776 - val_loss: 1.1873 - val_policy_softmax_loss: 0.9611 - val_value_tanh_loss: 0.1796
Epoch 5/32
999000/999000 [==============================] - 785s 785us/step - loss: 1.2076 - policy_softmax_loss: 0.9858 - value_tanh_loss: 0.1759 - val_loss: 1.1736 - val_policy_softmax_loss: 0.9502 - val_value_tanh_loss: 0.1782
Epoch 6/32
999000/999000 [==============================] - 784s 785us/step - loss: 1.2036 - policy_softmax_loss: 0.9838 - value_tanh_loss: 0.1747 - val_loss: 1.2402 - val_policy_softmax_loss: 0.9777 - val_value_tanh_loss: 0.2176
Epoch 7/32
999000/999000 [==============================] - 782s 783us/step - loss: 1.2005 - policy_softmax_loss: 0.9823 - value_tanh_loss: 0.1737 - val_loss: 1.1690 - val_policy_softmax_loss: 0.9487 - val_value_tanh_loss: 0.1764
Epoch 8/32
999000/999000 [==============================] - 781s 782us/step - loss: 1.1982 - policy_softmax_loss: 0.9813 - value_tanh_loss: 0.1729 - val_loss: 1.1855 - val_policy_softmax_loss: 0.9576 - val_value_tanh_loss: 0.1846
Epoch 9/32
999000/999000 [==============================] - 779s 779us/step - loss: 1.1966 - policy_softmax_loss: 0.9804 - value_tanh_loss: 0.1727 - val_loss: 1.1686 - val_policy_softmax_loss: 0.9501 - val_value_tanh_loss: 0.1754
Epoch 10/32
999000/999000 [==============================] - 777s 778us/step - loss: 1.1955 - policy_softmax_loss: 0.9799 - value_tanh_loss: 0.1723 - val_loss: 1.1736 - val_policy_softmax_loss: 0.9559 - val_value_tanh_loss: 0.1751
Epoch 11/32
999000/999000 [==============================] - 777s 777us/step - loss: 1.1940 - policy_softmax_loss: 0.9792 - value_tanh_loss: 0.1720 - val_loss: 1.1826 - val_policy_softmax_loss: 0.9649 - val_value_tanh_loss: 0.1749
Epoch 12/32
999000/999000 [==============================] - 777s 778us/step - loss: 1.1928 - policy_softmax_loss: 0.9786 - value_tanh_loss: 0.1718 - val_loss: 1.1584 - val_policy_softmax_loss: 0.9413 - val_value_tanh_loss: 0.1751
Epoch 13/32
999000/999000 [==============================] - 778s 779us/step - loss: 1.1919 - policy_softmax_loss: 0.9784 - value_tanh_loss: 0.1715 - val_loss: 1.1556 - val_policy_softmax_loss: 0.9424 - val_value_tanh_loss: 0.1716
Epoch 14/32
999000/999000 [==============================] - 776s 777us/step - loss: 1.1904 - policy_softmax_loss: 0.9778 - value_tanh_loss: 0.1711 - val_loss: 1.1605 - val_policy_softmax_loss: 0.9516 - val_value_tanh_loss: 0.1674
Epoch 15/32
999000/999000 [==============================] - 776s 777us/step - loss: 1.1902 - policy_softmax_loss: 0.9779 - value_tanh_loss: 0.1708 - val_loss: 1.1546 - val_policy_softmax_loss: 0.9378 - val_value_tanh_loss: 0.1756
Epoch 16/32
999000/999000 [==============================] - 778s 779us/step - loss: 1.1892 - policy_softmax_loss: 0.9775 - value_tanh_loss: 0.1704 - val_loss: 1.1809 - val_policy_softmax_loss: 0.9654 - val_value_tanh_loss: 0.1744
Epoch 17/32
999000/999000 [==============================] - 760s 761us/step - loss: 1.1880 - policy_softmax_loss: 0.9770 - value_tanh_loss: 0.1700 - val_loss: 1.1669 - val_policy_softmax_loss: 0.9528 - val_value_tanh_loss: 0.1730
Epoch 18/32
999000/999000 [==============================] - 751s 751us/step - loss: 1.1874 - policy_softmax_loss: 0.9768 - value_tanh_loss: 0.1697 - val_loss: 1.1479 - val_policy_softmax_loss: 0.9376 - val_value_tanh_loss: 0.1695
Epoch 19/32
999000/999000 [==============================] - 750s 751us/step - loss: 1.1867 - policy_softmax_loss: 0.9764 - value_tanh_loss: 0.1694 - val_loss: 1.1642 - val_policy_softmax_loss: 0.9536 - val_value_tanh_loss: 0.1700
Epoch 20/32
999000/999000 [==============================] - 751s 752us/step - loss: 1.1870 - policy_softmax_loss: 0.9768 - value_tanh_loss: 0.1696 - val_loss: 1.1704 - val_policy_softmax_loss: 0.9493 - val_value_tanh_loss: 0.1805
Epoch 21/32
999000/999000 [==============================] - 751s 752us/step - loss: 1.1859 - policy_softmax_loss: 0.9761 - value_tanh_loss: 0.1694 - val_loss: 1.2022 - val_policy_softmax_loss: 0.9886 - val_value_tanh_loss: 0.1734
Epoch 22/32
999000/999000 [==============================] - 752s 753us/step - loss: 1.1855 - policy_softmax_loss: 0.9760 - value_tanh_loss: 0.1692 - val_loss: 1.1614 - val_policy_softmax_loss: 0.9486 - val_value_tanh_loss: 0.1726
Epoch 23/32
999000/999000 [==============================] - 752s 753us/step - loss: 1.1851 - policy_softmax_loss: 0.9759 - value_tanh_loss: 0.1690 - val_loss: 1.1598 - val_policy_softmax_loss: 0.9447 - val_value_tanh_loss: 0.1749
Epoch 24/32
999000/999000 [==============================] - 752s 753us/step - loss: 1.1847 - policy_softmax_loss: 0.9757 - value_tanh_loss: 0.1689 - val_loss: 1.1608 - val_policy_softmax_loss: 0.9506 - val_value_tanh_loss: 0.1700
Epoch 25/32
999000/999000 [==============================] - 752s 753us/step - loss: 1.1845 - policy_softmax_loss: 0.9755 - value_tanh_loss: 0.1689 - val_loss: 1.1514 - val_policy_softmax_loss: 0.9390 - val_value_tanh_loss: 0.1723
Epoch 26/32
999000/999000 [==============================] - 752s 753us/step - loss: 1.1839 - policy_softmax_loss: 0.9752 - value_tanh_loss: 0.1688 - val_loss: 1.1511 - val_policy_softmax_loss: 0.9442 - val_value_tanh_loss: 0.1676
Epoch 27/32
999000/999000 [==============================] - 752s 752us/step - loss: 1.1834 - policy_softmax_loss: 0.9750 - value_tanh_loss: 0.1687 - val_loss: 1.1527 - val_policy_softmax_loss: 0.9448 - val_value_tanh_loss: 0.1684
Epoch 28/32
999000/999000 [==============================] - 750s 751us/step - loss: 1.1833 - policy_softmax_loss: 0.9750 - value_tanh_loss: 0.1686 - val_loss: 1.1505 - val_policy_softmax_loss: 0.9393 - val_value_tanh_loss: 0.1712
Epoch 29/32
999000/999000 [==============================] - 749s 750us/step - loss: 1.1828 - policy_softmax_loss: 0.9746 - value_tanh_loss: 0.1685 - val_loss: 1.1857 - val_policy_softmax_loss: 0.9811 - val_value_tanh_loss: 0.1651
Epoch 30/32
999000/999000 [==============================] - 749s 749us/step - loss: 1.1825 - policy_softmax_loss: 0.9744 - value_tanh_loss: 0.1686 - val_loss: 1.1482 - val_policy_softmax_loss: 0.9399 - val_value_tanh_loss: 0.1687
Epoch 31/32
999000/999000 [==============================] - 748s 749us/step - loss: 1.1821 - policy_softmax_loss: 0.9742 - value_tanh_loss: 0.1685 - val_loss: 1.1558 - val_policy_softmax_loss: 0.9449 - val_value_tanh_loss: 0.1715
Epoch 32/32
999000/999000 [==============================] - 748s 749us/step - loss: 1.1821 - policy_softmax_loss: 0.9744 - value_tanh_loss: 0.1683 - val_loss: 1.1431 - val_policy_softmax_loss: 0.9379 - val_value_tanh_loss: 0.1658
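Note that the total loss is slightly larger than the sum of the two head losses because Keras adds the l2(1e-4) weight regularization penalty to it; for the final epoch the difference comes out to roughly 0.04 (a quick check, not part of the original run):
In [ ]:
# final epoch: total = policy crossentropy + value MSE + L2 penalty
print(1.1821 - (0.9744 + 0.1683))  # ~0.039 contributed by the regularizers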
In [9]:
def linesToDotsAndBoxesImage(lines, imgWidth, imgHeight):
    boxWidth = imgSizeToBoxes(imgWidth)
    boxHeight = imgSizeToBoxes(imgHeight)
    linesCnt = 2*boxWidth*boxHeight+boxWidth+boxHeight
    mat = np.zeros((imgHeight, imgWidth), dtype=lines.dtype)
    for idx in range(linesCnt):
        y1 = idx // ((2*boxWidth) + 1)
        if idx % ((2*boxWidth) + 1) < boxWidth:
            # horizontal line
            x1 = idx % ((2*boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line
            x1 = idx % ((2*boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py,px] = lines[idx]
    return mat
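linesToDotsAndBoxesImage is the inverse of the boolean-mask indexing used in loadPVDataset: both enumerate the line positions in the same row-major order, so scattering a line vector into an image and reading it back through lineFilterMatrixNP reproduces the vector. A round-trip sketch (not part of the original run):
In [ ]:
lines = np.random.rand(49).astype(np.float32)  # arbitrary 49-line vector
img = linesToDotsAndBoxesImage(lines, 13, 11)
print(np.array_equal(img[lineFilterMatrixNP(13, 11)], lines))  # True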
In [11]:
example = random.randrange(x_input.shape[0])
print("example: "+str(example))

input_data = x_input[example:example+1]

(prediction_lines, prediction_value) = model.predict(input_data)
prediction_lines_print = prediction_lines * 100
print(prediction_lines_print.astype(np.uint8))
print(np.sum(prediction_lines))
prediction = linesToDotsAndBoxesImage(prediction_lines[0], imgWidth, imgHeight)

# print input data
input_data_print = x_input[example,:,:,1] 
input_data_print = input_data_print.astype(np.uint8)
print("input "+str(input_data_print.shape)+": ")
print(input_data_print)

# generate greyscale image data from input data (sum of all non-background planes)
input_imgdata = np.sum(x_input[example,:,:,1:], axis=-1) * 255
input_imgdata = input_imgdata.astype(np.uint8)

# print prediction
prediction_data_print = prediction * 100 
prediction_data_print = prediction_data_print.astype(np.uint8)
print("prediction policy: ")
print(prediction_data_print)

print("prediction value: ")
print(prediction_value)

print("target value: ")
print(y_value[example])

# generate greyscale image data from prediction data
prediction_imgdata = prediction * 255
prediction_imgdata = prediction_imgdata.astype(np.uint8)

# generate greyscale image of target data
target_imgdata = linesToDotsAndBoxesImage(y_policy[example], imgWidth, imgHeight) * 255

# merge image data in color channels
merged_imgdata = np.stack([input_imgdata, prediction_imgdata, target_imgdata], axis=2)

#create image
img = Image.fromarray(merged_imgdata, 'RGB')
img = img.resize(size=(img.size[0]*10, img.size[1]*10))

img
example: 667178
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 99  0  0  0  0  0  0  0
   0]]
1.0
input (11, 13): 
[[0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 1 0 1 0 1 0 0]
 [0 1 0 1 0 0 0 1 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 1 0 0 0 1 0 1 0]
 [0 0 1 0 1 0 1 0 1 0 0 0 0]
 [0 1 0 1 0 1 0 0 0 1 0 0 0]
 [0 0 1 0 1 0 0 0 0 0 1 0 0]
 [0 1 0 1 0 0 0 1 0 1 0 1 0]
 [0 0 1 0 1 0 1 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0]]
prediction policy: 
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0 99  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]]
prediction value: 
[[-0.75]]
target value: 
[-0.88]
Out[11]:
[image: input (red), predicted policy (green) and target policy (blue) merged into one RGB image, scaled 10x]