AlphaZero version 4

This model was trained from scratch for 32 epochs on 1,000,000 training examples from the new StageOne dataset for a 5x4 board.

In [1]:
import sys
sys.path.append('..')

import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug

from keras.callbacks import *
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.utils.np_utils import to_categorical
from keras.utils import plot_model
import keras.backend as K
from keras.regularizers import l2
from keras.engine.topology import Layer

from PIL import Image
from matplotlib.pyplot import imshow
%matplotlib inline
import random

from LineFilterLayer import LineFilterLayer

modelPath = 'model/alphaZeroV4.h5'
Using TensorFlow backend.
In [2]:
print(K.image_data_format()) 
# expected output: channels_last
channels_last
In [3]:
def dotsAndBoxesToCategorical(inputData):
    inp = np.copy(inputData)
    inp[inp == 255] = 1 # Line - comes first so that target data only has two categories
    inp[inp == 65] = 2 # Box A
    inp[inp == 150] = 3 # Box B
    inp[inp == 215] = 4 # Dot
    cat = to_categorical(inp)
    newShape = inp.shape + (cat.shape[-1],)
    return cat.reshape(newShape)
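As a quick illustration (a sketch, not part of the original run), a toy row containing each pixel value maps to category ids 0-4 before one-hot encoding:

toy = np.array([[0, 255, 65, 150, 215]])  # empty, line, box A, box B, dot
toy_cat = dotsAndBoxesToCategorical(toy)
print(toy_cat.shape)  # (1, 5, 5): one length-5 one-hot vector per pixel
print(toy_cat[0, 1])  # [0. 1. 0. 0. 0.] -> category 1, a line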
In [4]:
rawDataset = np.load('StageOne-5x4.npz')
x_train = rawDataset['x_train']
y_train = rawDataset['y_train']
x_train_cat = dotsAndBoxesToCategorical(x_train)
y_train_cat = dotsAndBoxesToCategorical(y_train)

np.set_printoptions(precision=2)

print("original data:")
print(x_train[0])
print(y_train[0])
print(x_train.shape)
print(y_train.shape)

print("\nnormalized data:")
print(np.transpose(x_train_cat[0]))
print(np.transpose(y_train_cat[0]))
print(x_train_cat.shape)
print(y_train_cat.shape)
original data:
[[  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0 215 255 215 255 215 255 215 255 215 255 215   0]
 [  0 255 150 255 150 255  65 255  65 255 150 255   0]
 [  0 215 255 215 255 215 255 215 255 215 255 215   0]
 [  0 255 150 255  65 255   0   0   0 255  65 255   0]
 [  0 215 255 215 255 215   0 215 255 215 255 215   0]
 [  0   0   0 255  65 255   0   0   0 255  65 255   0]
 [  0 215   0 215 255 215 255 215 255 215 255 215   0]
 [  0 255   0   0   0   0   0   0   0   0   0 255   0]
 [  0 215 255 215 255 215 255 215 255 215 255 215   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]]
[[  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0 255   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]]
(1000000, 11, 13)
(1000000, 11, 13)

normalized data:
[[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1.]
  [1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1.]
  [1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 1.]
  [1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1.]
  [1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1.]
  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 1. 0. 0. 0. 1. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 0.]
  [0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 0. 0. 1. 0. 1. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]
[[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]
(1000000, 11, 13, 5)
(1000000, 11, 13, 2)
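The size of the new trailing axis follows from the categories present in each array: the inputs use all five encodings (empty, line, box A, box B, dot), while the targets only contain empty and line. A quick shape check (a sketch):

assert x_train_cat.shape == x_train.shape + (5,)  # 5 input categories
assert y_train_cat.shape == y_train.shape + (2,)  # targets: empty or line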
In [5]:
#LineFilterLayer.LineFilterLayer.imgWidth = 13
#LineFilterLayer.LineFilterLayer.imgHeight = 11
#model = load_model('model/alphaZeroV1.h5', custom_objects={'LineFilterLayer':LineFilterLayer.LineFilterLayer})
In [6]:
def imgSizeToBoxes(x):
    return (x - 3) // 2

def lineFilterMatrixNP(imgWidth, imgHeight):
    boxWidth = imgSizeToBoxes(imgWidth)
    boxHeight = imgSizeToBoxes(imgHeight)
    linesCnt = 2*boxWidth*boxHeight+boxWidth+boxHeight
    mat = np.zeros((imgHeight, imgWidth), dtype=bool)
    for idx in range(linesCnt):
        y1 = idx // ((2*boxWidth) + 1)
        if idx % ((2*boxWidth) + 1) < boxWidth:
            # horizontal line
            x1 = idx % ((2*boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line
            x1 = idx % ((2*boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py,px] = True
    return mat

lineFilterMatrixNP(13,11)
Out[6]:
array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False],
       [False, False,  True, False,  True, False,  True, False,  True,
        False,  True, False, False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True, False,  True, False],
       [False, False,  True, False,  True, False,  True, False,  True,
        False,  True, False, False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True, False,  True, False],
       [False, False,  True, False,  True, False,  True, False,  True,
        False,  True, False, False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True, False,  True, False],
       [False, False,  True, False,  True, False,  True, False,  True,
        False,  True, False, False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True, False,  True, False],
       [False, False,  True, False,  True, False,  True, False,  True,
        False,  True, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False]])
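The mask marks exactly one pixel per line, so its population count must equal the line count 2wh+w+h of a w-by-h board; for the 5x4 board that is 49. A sanity check (a sketch):

mask = lineFilterMatrixNP(13, 11)
assert mask.sum() == 2*5*4 + 5 + 4  # 49 lines on a 5x4 board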
In [7]:
y_train_lines = y_train[:,lineFilterMatrixNP(y_train.shape[-1], y_train.shape[-2])]
print(y_train_lines.shape)
print(y_train_lines[0])
(1000000, 49)
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0 255   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0]
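Assuming every StageOne target marks a single move, each masked target row should contain exactly one line pixel, still at its raw value of 255 (a sketch):

assert np.all((y_train_lines == 255).sum(axis=1) == 1)  # one move per example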
In [8]:
kernelSize = (5,5)
filterCnt = 64
l2reg = 1e-4
resBlockCnt = 4

def build_residual_block(x, index):
        in_x = x
        res_name = "res"+str(index)
        x = Conv2D(filters=filterCnt, kernel_size=kernelSize, padding="same",
                   data_format="channels_last", kernel_regularizer=l2(l2reg), 
                   name=res_name+"_conv1_"+str(filterCnt))(x)
        x = BatchNormalization(name=res_name+"_batchnorm1")(x)
        x = Activation("relu",name=res_name+"_relu1")(x)
        x = Conv2D(filters=filterCnt, kernel_size=kernelSize, padding="same",
                   data_format="channels_last", kernel_regularizer=l2(l2reg), 
                   name=res_name+"_conv2-"+str(filterCnt))(x)
        x = BatchNormalization(name="res"+str(index)+"_batchnorm2")(x)
        x = Add(name=res_name+"_add")([in_x, x])
        x = Activation("relu", name=res_name+"_relu2")(x)
        return x


img_input = Input(shape=(None,None,5,))
x = Conv2D(filterCnt, kernelSize, padding='same', kernel_regularizer=l2(l2reg), name="input_conv")(img_input)
x = Activation("relu", name="input_relu")(x)
x = BatchNormalization()(x)

for i in range(resBlockCnt):
    x = build_residual_block(x, i+1)

res_out = x

x = Conv2D(1, kernelSize, padding='same', kernel_regularizer=l2(l2reg), name="output_conv")(x)
x = LineFilterLayer(y_train.shape[-1], y_train.shape[-2])(x)
x = Activation("softmax", name="output_softmax")(x)
    
model = Model(inputs=img_input, outputs=x)
model.compile(optimizer='adam', loss='categorical_crossentropy')

model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, None, 5 0                                            
__________________________________________________________________________________________________
input_conv (Conv2D)             (None, None, None, 6 8064        input_1[0][0]                    
__________________________________________________________________________________________________
input_relu (Activation)         (None, None, None, 6 0           input_conv[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256         input_relu[0][0]                 
__________________________________________________________________________________________________
res1_conv1_64 (Conv2D)          (None, None, None, 6 102464      batch_normalization_1[0][0]      
__________________________________________________________________________________________________
res1_batchnorm1 (BatchNormaliza (None, None, None, 6 256         res1_conv1_64[0][0]              
__________________________________________________________________________________________________
res1_relu1 (Activation)         (None, None, None, 6 0           res1_batchnorm1[0][0]            
__________________________________________________________________________________________________
res1_conv2-64 (Conv2D)          (None, None, None, 6 102464      res1_relu1[0][0]                 
__________________________________________________________________________________________________
res1_batchnorm2 (BatchNormaliza (None, None, None, 6 256         res1_conv2-64[0][0]              
__________________________________________________________________________________________________
res1_add (Add)                  (None, None, None, 6 0           batch_normalization_1[0][0]      
                                                                 res1_batchnorm2[0][0]            
__________________________________________________________________________________________________
res1_relu2 (Activation)         (None, None, None, 6 0           res1_add[0][0]                   
__________________________________________________________________________________________________
res2_conv1_64 (Conv2D)          (None, None, None, 6 102464      res1_relu2[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm1 (BatchNormaliza (None, None, None, 6 256         res2_conv1_64[0][0]              
__________________________________________________________________________________________________
res2_relu1 (Activation)         (None, None, None, 6 0           res2_batchnorm1[0][0]            
__________________________________________________________________________________________________
res2_conv2-64 (Conv2D)          (None, None, None, 6 102464      res2_relu1[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm2 (BatchNormaliza (None, None, None, 6 256         res2_conv2-64[0][0]              
__________________________________________________________________________________________________
res2_add (Add)                  (None, None, None, 6 0           res1_relu2[0][0]                 
                                                                 res2_batchnorm2[0][0]            
__________________________________________________________________________________________________
res2_relu2 (Activation)         (None, None, None, 6 0           res2_add[0][0]                   
__________________________________________________________________________________________________
res3_conv1_64 (Conv2D)          (None, None, None, 6 102464      res2_relu2[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm1 (BatchNormaliza (None, None, None, 6 256         res3_conv1_64[0][0]              
__________________________________________________________________________________________________
res3_relu1 (Activation)         (None, None, None, 6 0           res3_batchnorm1[0][0]            
__________________________________________________________________________________________________
res3_conv2-64 (Conv2D)          (None, None, None, 6 102464      res3_relu1[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm2 (BatchNormaliza (None, None, None, 6 256         res3_conv2-64[0][0]              
__________________________________________________________________________________________________
res3_add (Add)                  (None, None, None, 6 0           res2_relu2[0][0]                 
                                                                 res3_batchnorm2[0][0]            
__________________________________________________________________________________________________
res3_relu2 (Activation)         (None, None, None, 6 0           res3_add[0][0]                   
__________________________________________________________________________________________________
res4_conv1_64 (Conv2D)          (None, None, None, 6 102464      res3_relu2[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm1 (BatchNormaliza (None, None, None, 6 256         res4_conv1_64[0][0]              
__________________________________________________________________________________________________
res4_relu1 (Activation)         (None, None, None, 6 0           res4_batchnorm1[0][0]            
__________________________________________________________________________________________________
res4_conv2-64 (Conv2D)          (None, None, None, 6 102464      res4_relu1[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm2 (BatchNormaliza (None, None, None, 6 256         res4_conv2-64[0][0]              
__________________________________________________________________________________________________
res4_add (Add)                  (None, None, None, 6 0           res3_relu2[0][0]                 
                                                                 res4_batchnorm2[0][0]            
__________________________________________________________________________________________________
res4_relu2 (Activation)         (None, None, None, 6 0           res4_add[0][0]                   
__________________________________________________________________________________________________
output_conv (Conv2D)            (None, None, None, 1 1601        res4_relu2[0][0]                 
__________________________________________________________________________________________________
line_filter_layer_1 (LineFilter (None, None)         0           output_conv[0][0]                
__________________________________________________________________________________________________
output_softmax (Activation)     (None, None)         0           line_filter_layer_1[0][0]        
==================================================================================================
Total params: 831,681
Trainable params: 830,529
Non-trainable params: 1,152
__________________________________________________________________________________________________
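The parameter counts in the summary can be checked by hand: a Conv2D layer has kh*kw*c_in*c_out weights plus c_out biases (a sketch):

assert 5*5*5*64 + 64 == 8064      # input_conv: 5 channels in, 64 filters out
assert 5*5*64*64 + 64 == 102464   # each residual conv: 64 in, 64 out
assert 5*5*64*1 + 1 == 1601       # output_conv: 64 in, 1 out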
In [9]:
#sess = K.get_session()
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
#K.set_session(sess)

# Training
callbacks = []

checkpoint = ModelCheckpoint(filepath=modelPath+".checkpoint", save_weights_only=False)
callbacks.append(checkpoint)

progbar = ProgbarLogger()  # note: duplicates Keras' built-in per-epoch progress output
callbacks.append(progbar)

tensorboard = TensorBoard(log_dir='model/log2', write_grads=True, write_graph=True, write_images=True, histogram_freq=1)
#callbacks.append(tensorboard)

model.fit(x_train_cat, y_train_lines, epochs=32, batch_size=64, callbacks=callbacks, validation_split=0.001)

model.save(modelPath)
Train on 999000 samples, validate on 1000 samples
Epoch 1/32
999000/999000 [==============================] - 716s 717us/step - loss: 278.7167 - val_loss: 264.4487
Epoch 2/32
999000/999000 [==============================] - 718s 719us/step - loss: 254.2629 - val_loss: 299.8625
Epoch 3/32
999000/999000 [==============================] - 724s 724us/step - loss: 250.7524 - val_loss: 256.2867
Epoch 4/32
999000/999000 [==============================] - 725s 725us/step - loss: 249.1422 - val_loss: 255.1527
Epoch 5/32
999000/999000 [==============================] - 723s 724us/step - loss: 248.1794 - val_loss: 255.6740
Epoch 6/32
999000/999000 [==============================] - 725s 726us/step - loss: 247.3944 - val_loss: 270.5581
Epoch 7/32
999000/999000 [==============================] - 725s 726us/step - loss: 246.8300 - val_loss: 255.6762
Epoch 8/32
999000/999000 [==============================] - 718s 718us/step - loss: 246.4205 - val_loss: 256.3166
Epoch 9/32
999000/999000 [==============================] - 719s 720us/step - loss: 246.0360 - val_loss: 258.1618
Epoch 10/32
999000/999000 [==============================] - 719s 720us/step - loss: 245.7584 - val_loss: 254.0135
Epoch 11/32
999000/999000 [==============================] - 716s 717us/step - loss: 245.4592 - val_loss: 255.0869
Epoch 12/32
999000/999000 [==============================] - 716s 716us/step - loss: 245.2128 - val_loss: 255.0709
Epoch 13/32
999000/999000 [==============================] - 718s 719us/step - loss: 245.0520 - val_loss: 254.1625
Epoch 14/32
999000/999000 [==============================] - 718s 719us/step - loss: 244.7663 - val_loss: 256.9329
Epoch 15/32
999000/999000 [==============================] - 717s 718us/step - loss: 244.6284 - val_loss: 255.5091
Epoch 16/32
999000/999000 [==============================] - 717s 717us/step - loss: 244.3039 - val_loss: 259.1290
Epoch 17/32
999000/999000 [==============================] - 718s 719us/step - loss: 244.1131 - val_loss: 256.8970
Epoch 18/32
999000/999000 [==============================] - 716s 716us/step - loss: 243.8361 - val_loss: 256.6168
Epoch 19/32
999000/999000 [==============================] - 715s 716us/step - loss: 243.5561 - val_loss: 258.8813
Epoch 20/32
999000/999000 [==============================] - 717s 718us/step - loss: 243.3269 - val_loss: 260.5998
Epoch 21/32
999000/999000 [==============================] - 716s 716us/step - loss: 243.0043 - val_loss: 260.8126
Epoch 22/32
999000/999000 [==============================] - 716s 716us/step - loss: 242.7286 - val_loss: 258.3561
Epoch 23/32
999000/999000 [==============================] - 715s 716us/step - loss: 242.3785 - val_loss: 258.4142
Epoch 24/32
999000/999000 [==============================] - 716s 717us/step - loss: 242.0396 - val_loss: 259.2178
Epoch 25/32
999000/999000 [==============================] - 714s 715us/step - loss: 241.6783 - val_loss: 264.9506
Epoch 26/32
999000/999000 [==============================] - 713s 713us/step - loss: 241.2356 - val_loss: 262.6501
Epoch 27/32
999000/999000 [==============================] - 715s 716us/step - loss: 240.8816 - val_loss: 262.8784
Epoch 28/32
999000/999000 [==============================] - 716s 716us/step - loss: 240.4988 - val_loss: 264.7552
Epoch 29/32
999000/999000 [==============================] - 714s 715us/step - loss: 240.0224 - val_loss: 267.4462
Epoch 30/32
999000/999000 [==============================] - 715s 716us/step - loss: 239.5577 - val_loss: 266.9343
Epoch 31/32
999000/999000 [==============================] - 717s 718us/step - loss: 239.1245 - val_loss: 271.5216
Epoch 32/32
999000/999000 [==============================] - 718s 719us/step - loss: 238.5485 - val_loss: 271.5884
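To reuse the saved model later, the custom layer has to be registered at load time, mirroring the commented-out cell above (a sketch, assuming LineFilterLayer keeps the image size as class attributes as shown there):

LineFilterLayer.imgWidth = 13
LineFilterLayer.imgHeight = 11
model = load_model(modelPath, custom_objects={'LineFilterLayer': LineFilterLayer})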
In [10]:
def linesToDotsAndBoxesImage(lines, imgWidth, imgHeight):
    boxWidth = imgSizeToBoxes(imgWidth)
    boxHeight = imgSizeToBoxes(imgHeight)
    linesCnt = 2*boxWidth*boxHeight+boxWidth+boxHeight
    mat = np.zeros((imgHeight, imgWidth), dtype=lines.dtype)
    for idx in range(linesCnt):
        y1 = idx // ((2*boxWidth) + 1)
        if idx % ((2*boxWidth) + 1) < boxWidth:
            # horizontal line
            x1 = idx % ((2*boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line
            x1 = idx % ((2*boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py,px] = lines[idx]
    return mat
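This is the inverse of the masking in In [7]: writing 49 line values back into an image and re-applying the mask should return them unchanged (a sketch):

v = np.arange(1, 50)  # 49 distinct dummy line values
img = linesToDotsAndBoxesImage(v, 13, 11)
assert np.array_equal(img[lineFilterMatrixNP(13, 11)], v)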
In [11]:
example = random.randrange(x_train.shape[0])
print("example: "+str(example))

input_data = x_train[example:example+1]
input_data_cat = x_train_cat[example:example+1]

prediction_lines = model.predict(input_data_cat)
prediction_lines_print = prediction_lines * 100
print(prediction_lines_print.astype(np.uint8))
print(np.sum(prediction_lines))
prediction = linesToDotsAndBoxesImage(prediction_lines[0], x_train.shape[2], x_train.shape[1])

# print input data
input_data_print = x_train[example,:,:] 
input_data_print = input_data_print.astype(np.uint8)
print("input "+str(input_data_print.shape)+": ")
print(input_data_print)

# generate greyscale image data from input data
target_imgdata = x_train[example,:,:] 
target_imgdata = target_imgdata.astype(np.uint8)

# print prediction
prediction_data_print = prediction * 100 
prediction_data_print = prediction_data_print.astype(np.uint8)
print("prediction: ")
print(prediction_data_print)

# generate greyscale image data from prediction data
prediction_imgdata = prediction * 255
prediction_imgdata = prediction_imgdata.astype(np.uint8)

# merge image data in color channels
tmp = np.zeros((prediction.shape[0], prediction.shape[1]), dtype=np.uint8)
merged_imgdata = np.stack([target_imgdata, prediction_imgdata, tmp], axis=2)

#create image
img = Image.fromarray(merged_imgdata, 'RGB')
img = img.resize(size=(img.size[0]*10, img.size[1]*10))

img
example: 806312
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 99  0  0  0  0  0
   0]]
1.0
input (11, 13): 
[[  0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0 215 255 215   0 215 255 215 255 215 255 215   0]
 [  0 255  65 255   0   0   0 255   0   0   0 255   0]
 [  0 215 255 215 255 215   0 215   0 215   0 215   0]
 [  0   0   0 255   0   0   0 255   0 255   0 255   0]
 [  0 215   0 215   0 215 255 215   0 215   0 215   0]
 [  0 255   0   0   0 255   0   0   0 255   0   0   0]
 [  0 215 255 215 255 215   0 215 255 215 255 215   0]
 [  0 255 150 255   0   0   0 255   0   0   0   0   0]
 [  0 215 255 215   0 215 255 215   0 215 255 215   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0]]
prediction: 
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0 99  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0]]
Out[11]: (RGB board image: input state in the red channel, predicted line in green)