AlphaZero version 9

This AlphaZero version starts from the weights of version 7 and is trained on augmented StageFour data. This is mostly about trying out data augmentation.

In [1]:
import sys
sys.path.append('..')

import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug

from keras.callbacks import *
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.initializers import *
from keras.utils.np_utils import to_categorical
from keras.utils import plot_model
import keras.backend as K
from keras.regularizers import l2
from keras.engine.topology import Layer

from PIL import Image
from matplotlib.pyplot import imshow
%matplotlib inline
import random
import gc

from LineFilterLayer import LineFilterLayer
from ValueLayer import ValueLayer
from AugmentationSequence import AugmentationSequence

# Paths for this training run: the new model (v9) is written to modelPath,
# starting from the pretrained weights stored at baseModelPath (v7).
modelPath = 'model/alphaZeroV9.h5'
baseModelPath = 'model/alphaZeroV7.h5'

# StageFour self-play dataset generated by AlphaZeroV7 (128 samples, 7x6 boxes).
# NOTE(review): the ':' in the filename is invalid on Windows filesystems.
datasetPath = 'StageFour-AlphaZeroV7-128-7x6-18:36-24_05_2018.npz'
Using TensorFlow backend.
In [2]:
# Sanity check: the convolutional network below assumes channels_last ordering.
print(K.image_data_format()) 
# expected output: channels_last
channels_last
In [3]:
def dotsAndBoxesToCategorical(inputData):
    """One-hot encode a raw Dots-and-Boxes image.

    Remaps the raw pixel values (0 background, 255 line, 65 box A,
    150 box B, 215 dot) onto category indices 0..4 and returns the
    one-hot encoding with one extra trailing axis for the categories.
    """
    encoded = np.copy(inputData)
    # The line value is remapped first so that policy target images,
    # which only contain 0 and 255, end up with just two categories.
    for pixelValue, category in ((255, 1), (65, 2), (150, 3), (215, 4)):
        encoded[encoded == pixelValue] = category
    oneHot = to_categorical(encoded)
    # Restore the original spatial shape plus the new category axis.
    return oneHot.reshape(encoded.shape + (oneHot.shape[-1],))
In [4]:
def imgSizeToBoxes(x):
    """Convert an image dimension in pixels to the number of boxes along it.

    A board with n boxes along an axis is rendered with 2*n+3 pixels
    (alternating dots and lines plus a 1px border), hence the inverse here.
    Integer division (//) is required: with Python 3 true division this
    returned a float, which broke range()/indexing downstream.
    """
    return (x-3)//2

def lineFilterMatrixNP(imgWidth,imgHeight):
    """Boolean mask of shape (imgHeight, imgWidth), True exactly at the
    pixel positions that represent lines.

    Lines are enumerated row band by row band; each band holds
    2*boxWidth+1 lines: boxWidth horizontal lines followed by
    boxWidth+1 vertical ones.
    """
    boxWidth = imgSizeToBoxes(imgWidth)
    boxHeight = imgSizeToBoxes(imgHeight)
    linesCnt = 2*boxWidth*boxHeight+boxWidth+boxHeight
    # np.bool was deprecated in NumPy 1.20 and later removed; the builtin
    # bool gives the identical dtype.
    mat = np.zeros((imgHeight, imgWidth), dtype=bool)
    for idx in range(linesCnt):
        # Integer division: y1 must be an int index (true division made it
        # a float and range(linesCnt) raised a TypeError on Python 3).
        y1 = idx // ((2*boxWidth) + 1)
        if idx % ((2*boxWidth) + 1) < boxWidth:
            # horizontal line
            x1 = idx % ((2*boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line
            x1 = idx % ((2*boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        # Map box coordinates to pixel coordinates: dots sit at odd
        # pixel positions, lines in between them.
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py,px] = 1
    return mat
In [5]:
def loadRawPVDataset(datasetPath):
    """Load a raw policy/value training set from an .npz archive.

    The archive must contain the arrays 'input', 'policy' and 'value';
    they are returned as a tuple in that order.
    """
    archive = np.load(datasetPath)
    return (archive['input'], archive['policy'], archive['value'])

def process_input(x_input):
    # Convert raw board images into one-hot tensors for the network input.
    return dotsAndBoxesToCategorical(x_input)

def process_policy(y_policy):
    """Extract and normalize the policy targets.

    Keeps only the pixels at line positions (flattening each policy image
    to a vector of line activations) and rescales the raw byte values
    (0 or 255) to [0, 1].

    NOTE: the raw policy arrays are integer-typed; the original in-place
    `y_policy /= 255` relied on integer-division semantics and raises a
    TypeError on Python 3 / modern NumPy, so the data is converted to
    float32 before dividing.
    """
    lineMask = lineFilterMatrixNP(y_policy.shape[-1], y_policy.shape[-2])
    y_policy = y_policy[:, lineMask].astype(np.float32)
    y_policy /= 255
    return y_policy

# Limit printed float precision so the value targets below stay readable.
np.set_printoptions(precision=2)
(x_input_raw, y_policy_raw, y_value_raw) = loadRawPVDataset(datasetPath)
In [6]:
# Inspect array shapes and one raw sample to verify the dataset layout
# before training (input/policy are images, value is a scalar per sample).
print(x_input_raw.shape)
print(y_policy_raw.shape)
print(y_value_raw.shape)
print("raw input:")
print(x_input_raw[0])
print("raw policy:")
print(y_policy_raw[0])
print('raw value:')
print(y_value_raw[0])
(128, 15, 17)
(128, 15, 17)
(128, 1)
raw input:
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0 215   0 215 255 215 255 215 255 215 255 215   0 215 255 215   0]
 [  0 255   0 255 150 255  65 255  65 255   0   0   0 255  65 255   0]
 [  0 215   0 215 255 215 255 215 255 215   0 215 255 215 255 215   0]
 [  0 255   0 255 150 255  65 255  65 255   0 255   0   0   0 255   0]
 [  0 215   0 215 255 215 255 215 255 215   0 215   0 215   0 215   0]
 [  0 255   0 255 150 255  65 255  65 255   0 255   0 255   0   0   0]
 [  0 215   0 215 255 215 255 215 255 215   0 215   0 215 255 215   0]
 [  0 255   0 255   0   0   0   0   0   0   0 255   0   0   0   0   0]
 [  0 215   0 215   0 215 255 215 255 215 255 215 255 215 255 215   0]
 [  0   0   0 255   0   0   0 255   0   0   0   0   0   0   0   0   0]
 [  0 215 255 215 255 215   0 215   0 215 255 215 255 215 255 215   0]
 [  0 255  65 255  65 255   0 255   0   0   0 255 150 255 150 255   0]
 [  0 215 255 215 255 215   0 215 255 215   0 215 255 215 255 215   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
raw policy:
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 255   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
raw value:
[-0.29]
In [7]:
# Board image dimensions, taken from the dataset (17 x 15 pixels here).
imgWidth = x_input_raw.shape[-1]
imgHeight = x_input_raw.shape[-2]
# LineFilterLayer has to be set before loading the model
LineFilterLayer.imgWidth = imgWidth
LineFilterLayer.imgHeight = imgHeight
# ValueLayer has to be set before loading the model
ValueLayer.imgWidth = imgWidth
ValueLayer.imgHeight = imgHeight

# Load the v7 model as the starting point; the custom layers must be
# passed explicitly so Keras can deserialize them.
model = load_model(baseModelPath,
                   custom_objects={'LineFilterLayer':LineFilterLayer,
                   'ValueLayer':ValueLayer})

model.summary()
LineFilterLayer with image size 17 x 15
ValueLayer with image size 17 x 15
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, None, 5 0                                            
__________________________________________________________________________________________________
input_conv (Conv2D)             (None, None, None, 1 16128       input_1[0][0]                    
__________________________________________________________________________________________________
input_relu (Activation)         (None, None, None, 1 0           input_conv[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 1 512         input_relu[0][0]                 
__________________________________________________________________________________________________
res1_conv1_128 (Conv2D)         (None, None, None, 1 409728      batch_normalization_1[0][0]      
__________________________________________________________________________________________________
res1_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res1_conv1_128[0][0]             
__________________________________________________________________________________________________
res1_relu1 (Activation)         (None, None, None, 1 0           res1_batchnorm1[0][0]            
__________________________________________________________________________________________________
res1_conv2-128 (Conv2D)         (None, None, None, 1 409728      res1_relu1[0][0]                 
__________________________________________________________________________________________________
res1_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res1_conv2-128[0][0]             
__________________________________________________________________________________________________
res1_add (Add)                  (None, None, None, 1 0           batch_normalization_1[0][0]      
                                                                 res1_batchnorm2[0][0]            
__________________________________________________________________________________________________
res1_relu2 (Activation)         (None, None, None, 1 0           res1_add[0][0]                   
__________________________________________________________________________________________________
res2_conv1_128 (Conv2D)         (None, None, None, 1 409728      res1_relu2[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res2_conv1_128[0][0]             
__________________________________________________________________________________________________
res2_relu1 (Activation)         (None, None, None, 1 0           res2_batchnorm1[0][0]            
__________________________________________________________________________________________________
res2_conv2-128 (Conv2D)         (None, None, None, 1 409728      res2_relu1[0][0]                 
__________________________________________________________________________________________________
res2_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res2_conv2-128[0][0]             
__________________________________________________________________________________________________
res2_add (Add)                  (None, None, None, 1 0           res1_relu2[0][0]                 
                                                                 res2_batchnorm2[0][0]            
__________________________________________________________________________________________________
res2_relu2 (Activation)         (None, None, None, 1 0           res2_add[0][0]                   
__________________________________________________________________________________________________
res3_conv1_128 (Conv2D)         (None, None, None, 1 409728      res2_relu2[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res3_conv1_128[0][0]             
__________________________________________________________________________________________________
res3_relu1 (Activation)         (None, None, None, 1 0           res3_batchnorm1[0][0]            
__________________________________________________________________________________________________
res3_conv2-128 (Conv2D)         (None, None, None, 1 409728      res3_relu1[0][0]                 
__________________________________________________________________________________________________
res3_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res3_conv2-128[0][0]             
__________________________________________________________________________________________________
res3_add (Add)                  (None, None, None, 1 0           res2_relu2[0][0]                 
                                                                 res3_batchnorm2[0][0]            
__________________________________________________________________________________________________
res3_relu2 (Activation)         (None, None, None, 1 0           res3_add[0][0]                   
__________________________________________________________________________________________________
res4_conv1_128 (Conv2D)         (None, None, None, 1 409728      res3_relu2[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res4_conv1_128[0][0]             
__________________________________________________________________________________________________
res4_relu1 (Activation)         (None, None, None, 1 0           res4_batchnorm1[0][0]            
__________________________________________________________________________________________________
res4_conv2-128 (Conv2D)         (None, None, None, 1 409728      res4_relu1[0][0]                 
__________________________________________________________________________________________________
res4_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res4_conv2-128[0][0]             
__________________________________________________________________________________________________
res4_add (Add)                  (None, None, None, 1 0           res3_relu2[0][0]                 
                                                                 res4_batchnorm2[0][0]            
__________________________________________________________________________________________________
res4_relu2 (Activation)         (None, None, None, 1 0           res4_add[0][0]                   
__________________________________________________________________________________________________
res5_conv1_128 (Conv2D)         (None, None, None, 1 409728      res4_relu2[0][0]                 
__________________________________________________________________________________________________
res5_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res5_conv1_128[0][0]             
__________________________________________________________________________________________________
res5_relu1 (Activation)         (None, None, None, 1 0           res5_batchnorm1[0][0]            
__________________________________________________________________________________________________
res5_conv2-128 (Conv2D)         (None, None, None, 1 409728      res5_relu1[0][0]                 
__________________________________________________________________________________________________
res5_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res5_conv2-128[0][0]             
__________________________________________________________________________________________________
res5_add (Add)                  (None, None, None, 1 0           res4_relu2[0][0]                 
                                                                 res5_batchnorm2[0][0]            
__________________________________________________________________________________________________
res5_relu2 (Activation)         (None, None, None, 1 0           res5_add[0][0]                   
__________________________________________________________________________________________________
res6_conv1_128 (Conv2D)         (None, None, None, 1 409728      res5_relu2[0][0]                 
__________________________________________________________________________________________________
res6_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res6_conv1_128[0][0]             
__________________________________________________________________________________________________
res6_relu1 (Activation)         (None, None, None, 1 0           res6_batchnorm1[0][0]            
__________________________________________________________________________________________________
res6_conv2-128 (Conv2D)         (None, None, None, 1 409728      res6_relu1[0][0]                 
__________________________________________________________________________________________________
res6_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res6_conv2-128[0][0]             
__________________________________________________________________________________________________
res6_add (Add)                  (None, None, None, 1 0           res5_relu2[0][0]                 
                                                                 res6_batchnorm2[0][0]            
__________________________________________________________________________________________________
res6_relu2 (Activation)         (None, None, None, 1 0           res6_add[0][0]                   
__________________________________________________________________________________________________
res7_conv1_128 (Conv2D)         (None, None, None, 1 409728      res6_relu2[0][0]                 
__________________________________________________________________________________________________
res7_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res7_conv1_128[0][0]             
__________________________________________________________________________________________________
res7_relu1 (Activation)         (None, None, None, 1 0           res7_batchnorm1[0][0]            
__________________________________________________________________________________________________
res7_conv2-128 (Conv2D)         (None, None, None, 1 409728      res7_relu1[0][0]                 
__________________________________________________________________________________________________
res7_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res7_conv2-128[0][0]             
__________________________________________________________________________________________________
res7_add (Add)                  (None, None, None, 1 0           res6_relu2[0][0]                 
                                                                 res7_batchnorm2[0][0]            
__________________________________________________________________________________________________
res7_relu2 (Activation)         (None, None, None, 1 0           res7_add[0][0]                   
__________________________________________________________________________________________________
res8_conv1_128 (Conv2D)         (None, None, None, 1 409728      res7_relu2[0][0]                 
__________________________________________________________________________________________________
res8_batchnorm1 (BatchNormaliza (None, None, None, 1 512         res8_conv1_128[0][0]             
__________________________________________________________________________________________________
res8_relu1 (Activation)         (None, None, None, 1 0           res8_batchnorm1[0][0]            
__________________________________________________________________________________________________
res8_conv2-128 (Conv2D)         (None, None, None, 1 409728      res8_relu1[0][0]                 
__________________________________________________________________________________________________
res8_batchnorm2 (BatchNormaliza (None, None, None, 1 512         res8_conv2-128[0][0]             
__________________________________________________________________________________________________
res8_add (Add)                  (None, None, None, 1 0           res7_relu2[0][0]                 
                                                                 res8_batchnorm2[0][0]            
__________________________________________________________________________________________________
res8_relu2 (Activation)         (None, None, None, 1 0           res8_add[0][0]                   
__________________________________________________________________________________________________
policy_conv (Conv2D)            (None, None, None, 1 3201        res8_relu2[0][0]                 
__________________________________________________________________________________________________
value_conv (Conv2D)             (None, None, None, 1 3201        res8_relu2[0][0]                 
__________________________________________________________________________________________________
line_filter_layer_1 (LineFilter (None, None)         0           policy_conv[0][0]                
__________________________________________________________________________________________________
value_layer_1 (ValueLayer)      (None, 1)            0           value_conv[0][0]                 
__________________________________________________________________________________________________
policy (Activation)             (None, None)         0           line_filter_layer_1[0][0]        
__________________________________________________________________________________________________
value (Activation)              (None, 1)            0           value_layer_1[0][0]              
==================================================================================================
Total params: 6,586,882
Trainable params: 6,582,530
Non-trainable params: 4,352
__________________________________________________________________________________________________
In [8]:
batch_size = 16
# Sequence that applies data augmentation on the fly and runs the
# process_input/process_policy conversions per batch.
data_generator = AugmentationSequence(x_input_raw, y_policy_raw, y_value_raw, batch_size, process_input, process_policy)
In [9]:
#sess = K.get_session()
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
#K.set_session(sess)

# Training
# NOTE(review): the callbacks list below is assembled but never used --
# the callbacks argument in fit_generator is commented out, so no
# checkpoints are written during training; the model is only saved at
# the end. Confirm this is intentional.
callbacks = []

checkpoint = ModelCheckpoint(filepath=modelPath+".checkpoint", save_weights_only=False)
callbacks.append(checkpoint)

progbar = ProgbarLogger()
callbacks.append(progbar)

#tensorboard = TensorBoard(log_dir='model/log2', write_grads=True, write_graph=True, write_images=True, histogram_freq=1)
#callbacks.append(tensorboard)

# Fine-tune on the augmented dataset; one epoch sweeps the whole
# generator (len(data_generator) batches).
model.fit_generator(data_generator, epochs=128, steps_per_epoch=len(data_generator)) #, callbacks=callbacks)

model.save(modelPath)
Epoch 1/128
8/8 [==============================] - 4s 507ms/step - loss: 7.9643 - policy_loss: 7.6627 - value_loss: 0.2626
Epoch 2/128
8/8 [==============================] - 1s 87ms/step - loss: 6.4436 - policy_loss: 6.1928 - value_loss: 0.2118
Epoch 3/128
8/8 [==============================] - 1s 89ms/step - loss: 5.8185 - policy_loss: 5.5787 - value_loss: 0.2007
Epoch 4/128
8/8 [==============================] - 1s 83ms/step - loss: 5.1229 - policy_loss: 4.8999 - value_loss: 0.1838
Epoch 5/128
8/8 [==============================] - 1s 82ms/step - loss: 4.6413 - policy_loss: 4.4164 - value_loss: 0.1856
Epoch 6/128
8/8 [==============================] - 1s 86ms/step - loss: 4.1071 - policy_loss: 3.8894 - value_loss: 0.1782
Epoch 7/128
8/8 [==============================] - 1s 82ms/step - loss: 3.8342 - policy_loss: 3.6165 - value_loss: 0.1781
Epoch 8/128
8/8 [==============================] - 1s 82ms/step - loss: 3.5639 - policy_loss: 3.3382 - value_loss: 0.1859
Epoch 9/128
8/8 [==============================] - 1s 90ms/step - loss: 3.2926 - policy_loss: 3.0675 - value_loss: 0.1851
Epoch 10/128
8/8 [==============================] - 1s 83ms/step - loss: 2.9966 - policy_loss: 2.7752 - value_loss: 0.1812
Epoch 11/128
8/8 [==============================] - 1s 82ms/step - loss: 2.7192 - policy_loss: 2.5022 - value_loss: 0.1767
Epoch 12/128
8/8 [==============================] - 1s 83ms/step - loss: 2.3744 - policy_loss: 2.1641 - value_loss: 0.1698
Epoch 13/128
8/8 [==============================] - 1s 82ms/step - loss: 2.3752 - policy_loss: 2.1638 - value_loss: 0.1709
Epoch 14/128
8/8 [==============================] - 1s 83ms/step - loss: 2.1702 - policy_loss: 1.9634 - value_loss: 0.1662
Epoch 15/128
8/8 [==============================] - 1s 84ms/step - loss: 2.0340 - policy_loss: 1.8305 - value_loss: 0.1628
Epoch 16/128
8/8 [==============================] - 1s 82ms/step - loss: 2.0701 - policy_loss: 1.8613 - value_loss: 0.1680
Epoch 17/128
8/8 [==============================] - 1s 82ms/step - loss: 1.8708 - policy_loss: 1.6650 - value_loss: 0.1650
Epoch 18/128
8/8 [==============================] - 1s 87ms/step - loss: 1.6751 - policy_loss: 1.4710 - value_loss: 0.1632
Epoch 19/128
8/8 [==============================] - 1s 83ms/step - loss: 1.6004 - policy_loss: 1.3950 - value_loss: 0.1645
Epoch 20/128
8/8 [==============================] - 1s 83ms/step - loss: 1.6681 - policy_loss: 1.4636 - value_loss: 0.1634
Epoch 21/128
8/8 [==============================] - 1s 87ms/step - loss: 1.3969 - policy_loss: 1.2010 - value_loss: 0.1549
Epoch 22/128
8/8 [==============================] - 1s 89ms/step - loss: 1.4554 - policy_loss: 1.2618 - value_loss: 0.1525
Epoch 23/128
8/8 [==============================] - 1s 84ms/step - loss: 1.3635 - policy_loss: 1.1677 - value_loss: 0.1545
Epoch 24/128
8/8 [==============================] - 1s 83ms/step - loss: 1.3543 - policy_loss: 1.1591 - value_loss: 0.1539
Epoch 25/128
8/8 [==============================] - 1s 82ms/step - loss: 1.0870 - policy_loss: 0.8910 - value_loss: 0.1546
Epoch 26/128
8/8 [==============================] - 1s 82ms/step - loss: 1.2283 - policy_loss: 1.0325 - value_loss: 0.1543
Epoch 27/128
8/8 [==============================] - 1s 87ms/step - loss: 1.0785 - policy_loss: 0.8933 - value_loss: 0.1437
Epoch 28/128
8/8 [==============================] - 1s 82ms/step - loss: 0.9999 - policy_loss: 0.8204 - value_loss: 0.1378
Epoch 29/128
8/8 [==============================] - 1s 83ms/step - loss: 0.9573 - policy_loss: 0.7729 - value_loss: 0.1426
Epoch 30/128
8/8 [==============================] - 1s 83ms/step - loss: 0.9045 - policy_loss: 0.7209 - value_loss: 0.1417
Epoch 31/128
8/8 [==============================] - 1s 83ms/step - loss: 0.8470 - policy_loss: 0.6746 - value_loss: 0.1304
Epoch 32/128
8/8 [==============================] - 1s 82ms/step - loss: 0.8580 - policy_loss: 0.6912 - value_loss: 0.1248
Epoch 33/128
8/8 [==============================] - 1s 83ms/step - loss: 0.7179 - policy_loss: 0.5486 - value_loss: 0.1272
Epoch 34/128
8/8 [==============================] - 1s 85ms/step - loss: 0.7997 - policy_loss: 0.6363 - value_loss: 0.1211
Epoch 35/128
8/8 [==============================] - 1s 90ms/step - loss: 0.6432 - policy_loss: 0.4713 - value_loss: 0.1295
Epoch 36/128
8/8 [==============================] - 1s 84ms/step - loss: 0.7044 - policy_loss: 0.5400 - value_loss: 0.1220
Epoch 37/128
8/8 [==============================] - 1s 87ms/step - loss: 0.6552 - policy_loss: 0.4894 - value_loss: 0.1233
Epoch 38/128
8/8 [==============================] - 1s 82ms/step - loss: 0.6300 - policy_loss: 0.4716 - value_loss: 0.1159
Epoch 39/128
8/8 [==============================] - 1s 82ms/step - loss: 0.5613 - policy_loss: 0.4039 - value_loss: 0.1147
Epoch 40/128
8/8 [==============================] - 1s 83ms/step - loss: 0.4935 - policy_loss: 0.3372 - value_loss: 0.1136
Epoch 41/128
8/8 [==============================] - 1s 91ms/step - loss: 0.5448 - policy_loss: 0.3929 - value_loss: 0.1091
Epoch 42/128
8/8 [==============================] - 1s 82ms/step - loss: 0.5175 - policy_loss: 0.3680 - value_loss: 0.1064
Epoch 43/128
8/8 [==============================] - 1s 82ms/step - loss: 0.4374 - policy_loss: 0.3003 - value_loss: 0.0941
Epoch 44/128
8/8 [==============================] - 1s 82ms/step - loss: 0.4273 - policy_loss: 0.2845 - value_loss: 0.0997
Epoch 45/128
8/8 [==============================] - 1s 88ms/step - loss: 0.3568 - policy_loss: 0.2201 - value_loss: 0.0935
Epoch 46/128
8/8 [==============================] - 1s 83ms/step - loss: 0.3457 - policy_loss: 0.2121 - value_loss: 0.0905
Epoch 47/128
8/8 [==============================] - 1s 82ms/step - loss: 0.3945 - policy_loss: 0.2601 - value_loss: 0.0912
Epoch 48/128
8/8 [==============================] - 1s 82ms/step - loss: 0.4012 - policy_loss: 0.2635 - value_loss: 0.0944
Epoch 49/128
8/8 [==============================] - 1s 83ms/step - loss: 0.3560 - policy_loss: 0.2256 - value_loss: 0.0870
Epoch 50/128
8/8 [==============================] - 1s 82ms/step - loss: 0.3066 - policy_loss: 0.1836 - value_loss: 0.0795
Epoch 51/128
8/8 [==============================] - 1s 82ms/step - loss: 0.3002 - policy_loss: 0.1461 - value_loss: 0.1107
Epoch 52/128
8/8 [==============================] - 1s 84ms/step - loss: 0.3107 - policy_loss: 0.1857 - value_loss: 0.0814
Epoch 53/128
8/8 [==============================] - 1s 90ms/step - loss: 0.2343 - policy_loss: 0.1151 - value_loss: 0.0755
Epoch 54/128
8/8 [==============================] - 1s 82ms/step - loss: 0.2385 - policy_loss: 0.1265 - value_loss: 0.0683
Epoch 55/128
8/8 [==============================] - 1s 82ms/step - loss: 0.2775 - policy_loss: 0.1655 - value_loss: 0.0683
Epoch 56/128
8/8 [==============================] - 1s 84ms/step - loss: 0.2621 - policy_loss: 0.1493 - value_loss: 0.0691
Epoch 57/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1815 - policy_loss: 0.0798 - value_loss: 0.0579
Epoch 58/128
8/8 [==============================] - 1s 82ms/step - loss: 0.2313 - policy_loss: 0.1259 - value_loss: 0.0616
Epoch 59/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1988 - policy_loss: 0.1048 - value_loss: 0.0502
Epoch 60/128
8/8 [==============================] - 1s 82ms/step - loss: 0.2179 - policy_loss: 0.1204 - value_loss: 0.0537
Epoch 61/128
8/8 [==============================] - 1s 91ms/step - loss: 0.1886 - policy_loss: 0.0956 - value_loss: 0.0491
Epoch 62/128
8/8 [==============================] - 1s 83ms/step - loss: 0.2081 - policy_loss: 0.1026 - value_loss: 0.0616
Epoch 63/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1878 - policy_loss: 0.0941 - value_loss: 0.0498
Epoch 64/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1761 - policy_loss: 0.0817 - value_loss: 0.0503
Epoch 65/128
8/8 [==============================] - 1s 89ms/step - loss: 0.1962 - policy_loss: 0.0988 - value_loss: 0.0533
Epoch 66/128
8/8 [==============================] - 1s 88ms/step - loss: 0.1687 - policy_loss: 0.0791 - value_loss: 0.0455
Epoch 67/128
8/8 [==============================] - 1s 83ms/step - loss: 0.1714 - policy_loss: 0.0772 - value_loss: 0.0501
Epoch 68/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1641 - policy_loss: 0.0721 - value_loss: 0.0478
Epoch 69/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1780 - policy_loss: 0.0896 - value_loss: 0.0442
Epoch 70/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1504 - policy_loss: 0.0652 - value_loss: 0.0410
Epoch 71/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1327 - policy_loss: 0.0559 - value_loss: 0.0325
Epoch 72/128
8/8 [==============================] - 1s 84ms/step - loss: 0.1432 - policy_loss: 0.0504 - value_loss: 0.0486
Epoch 73/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1101 - policy_loss: 0.0358 - value_loss: 0.0300
Epoch 74/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0966 - policy_loss: 0.0267 - value_loss: 0.0256
Epoch 75/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1015 - policy_loss: 0.0249 - value_loss: 0.0322
Epoch 76/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1079 - policy_loss: 0.0286 - value_loss: 0.0349
Epoch 77/128
8/8 [==============================] - 1s 82ms/step - loss: 0.1005 - policy_loss: 0.0232 - value_loss: 0.0329
Epoch 78/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0890 - policy_loss: 0.0182 - value_loss: 0.0264
Epoch 79/128
8/8 [==============================] - 1s 84ms/step - loss: 0.0849 - policy_loss: 0.0176 - value_loss: 0.0229
Epoch 80/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0964 - policy_loss: 0.0267 - value_loss: 0.0253
Epoch 81/128
8/8 [==============================] - 1s 89ms/step - loss: 0.0838 - policy_loss: 0.0187 - value_loss: 0.0208
Epoch 82/128
8/8 [==============================] - 1s 88ms/step - loss: 0.0754 - policy_loss: 0.0142 - value_loss: 0.0169
Epoch 83/128
8/8 [==============================] - 1s 91ms/step - loss: 0.0791 - policy_loss: 0.0192 - value_loss: 0.0156
Epoch 84/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0794 - policy_loss: 0.0155 - value_loss: 0.0195
Epoch 85/128
8/8 [==============================] - 1s 90ms/step - loss: 0.0729 - policy_loss: 0.0131 - value_loss: 0.0155
Epoch 86/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0756 - policy_loss: 0.0119 - value_loss: 0.0194
Epoch 87/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0748 - policy_loss: 0.0142 - value_loss: 0.0163
Epoch 88/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0735 - policy_loss: 0.0136 - value_loss: 0.0156
Epoch 89/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0721 - policy_loss: 0.0108 - value_loss: 0.0171
Epoch 90/128
8/8 [==============================] - 1s 85ms/step - loss: 0.0678 - policy_loss: 0.0088 - value_loss: 0.0148
Epoch 91/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0677 - policy_loss: 0.0082 - value_loss: 0.0154
Epoch 92/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0676 - policy_loss: 0.0112 - value_loss: 0.0122
Epoch 93/128
8/8 [==============================] - 1s 84ms/step - loss: 0.0667 - policy_loss: 0.0097 - value_loss: 0.0129
Epoch 94/128
8/8 [==============================] - 1s 88ms/step - loss: 0.0621 - policy_loss: 0.0083 - value_loss: 0.0097
Epoch 95/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0686 - policy_loss: 0.0097 - value_loss: 0.0148
Epoch 96/128
8/8 [==============================] - 1s 88ms/step - loss: 0.0617 - policy_loss: 0.0083 - value_loss: 0.0093
Epoch 97/128
8/8 [==============================] - 1s 86ms/step - loss: 0.0614 - policy_loss: 0.0082 - value_loss: 0.0092
Epoch 98/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0621 - policy_loss: 0.0093 - value_loss: 0.0087
Epoch 99/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0584 - policy_loss: 0.0080 - value_loss: 0.0064
Epoch 100/128
8/8 [==============================] - 1s 88ms/step - loss: 0.0558 - policy_loss: 0.0064 - value_loss: 0.0054
Epoch 101/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0607 - policy_loss: 0.0084 - value_loss: 0.0084
Epoch 102/128
8/8 [==============================] - 1s 84ms/step - loss: 0.0579 - policy_loss: 0.0062 - value_loss: 0.0077
Epoch 103/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0568 - policy_loss: 0.0061 - value_loss: 0.0067
Epoch 104/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0581 - policy_loss: 0.0066 - value_loss: 0.0076
Epoch 105/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0585 - policy_loss: 0.0057 - value_loss: 0.0089
Epoch 106/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0597 - policy_loss: 0.0066 - value_loss: 0.0092
Epoch 107/128
8/8 [==============================] - 1s 88ms/step - loss: 0.0564 - policy_loss: 0.0061 - value_loss: 0.0065
Epoch 108/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0564 - policy_loss: 0.0059 - value_loss: 0.0067
Epoch 109/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0592 - policy_loss: 0.0065 - value_loss: 0.0089
Epoch 110/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0553 - policy_loss: 0.0049 - value_loss: 0.0066
Epoch 111/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0544 - policy_loss: 0.0045 - value_loss: 0.0062
Epoch 112/128
8/8 [==============================] - 1s 89ms/step - loss: 0.0548 - policy_loss: 0.0054 - value_loss: 0.0057
Epoch 113/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0544 - policy_loss: 0.0039 - value_loss: 0.0067
Epoch 114/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0555 - policy_loss: 0.0049 - value_loss: 0.0069
Epoch 115/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0541 - policy_loss: 0.0053 - value_loss: 0.0052
Epoch 116/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0531 - policy_loss: 0.0041 - value_loss: 0.0054
Epoch 117/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0511 - policy_loss: 0.0037 - value_loss: 0.0038
Epoch 118/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0544 - policy_loss: 0.0060 - value_loss: 0.0049
Epoch 119/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0558 - policy_loss: 0.0038 - value_loss: 0.0084
Epoch 120/128
8/8 [==============================] - 1s 89ms/step - loss: 0.0521 - policy_loss: 0.0034 - value_loss: 0.0052
Epoch 121/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0520 - policy_loss: 0.0051 - value_loss: 0.0035
Epoch 122/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0531 - policy_loss: 0.0049 - value_loss: 0.0048
Epoch 123/128
8/8 [==============================] - 1s 84ms/step - loss: 0.0526 - policy_loss: 0.0043 - value_loss: 0.0049
Epoch 124/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0526 - policy_loss: 0.0042 - value_loss: 0.0050
Epoch 125/128
8/8 [==============================] - 1s 82ms/step - loss: 0.0515 - policy_loss: 0.0041 - value_loss: 0.0040
Epoch 126/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0512 - policy_loss: 0.0034 - value_loss: 0.0044
Epoch 127/128
8/8 [==============================] - 1s 91ms/step - loss: 0.0511 - policy_loss: 0.0033 - value_loss: 0.0045
Epoch 128/128
8/8 [==============================] - 1s 83ms/step - loss: 0.0506 - policy_loss: 0.0036 - value_loss: 0.0037
In [10]:
def linesToDotsAndBoxesImage(lines, imgWidth, imgHeight):
    """Scatter a flat vector of per-line values back onto a 2D board image.

    Inverse of the line-filter extraction: entry ``lines[idx]`` is written to
    the pixel occupied by line ``idx`` in the Dots-and-Boxes board image
    (dots on odd pixel coordinates, lines on the pixels between them).

    Parameters
    ----------
    lines : np.ndarray
        1D array of per-line values (e.g. policy probabilities) of length
        2*boxWidth*boxHeight + boxWidth + boxHeight.
    imgWidth, imgHeight : int
        Board image dimensions in pixels (2*boxes + 3 per axis).

    Returns
    -------
    np.ndarray
        (imgHeight, imgWidth) array, dtype taken from ``lines``, with
        lines[idx] at each line pixel and zeros elsewhere.
    """
    # Use integer floor division throughout: plain '/' returns floats on
    # Python 3, which breaks range() below and the array indexing at the end.
    # (Same formula as imgSizeToBoxes, inlined to keep the arithmetic integral.)
    boxWidth = (imgWidth - 3) // 2
    boxHeight = (imgHeight - 3) // 2
    linesCnt = 2 * boxWidth * boxHeight + boxWidth + boxHeight
    mat = np.zeros((imgHeight, imgWidth), dtype=lines.dtype)
    for idx in range(linesCnt):
        # Each dot row accounts for 2*boxWidth+1 lines: boxWidth horizontal
        # lines first, then boxWidth+1 vertical lines.
        y1 = idx // ((2 * boxWidth) + 1)
        if idx % ((2 * boxWidth) + 1) < boxWidth:
            # horizontal line: connects dots (x1, y1) and (x2, y1)
            x1 = idx % ((2 * boxWidth) + 1)
            x2 = x1 + 1
            y2 = y1
        else:
            # vertical line: connects dots (x1, y1) and (x1, y2)
            x1 = idx % ((2 * boxWidth) + 1) - boxWidth
            x2 = x1
            y2 = y1 + 1
        # Pixel position is midway between the two dots (dots sit at 2*x+1).
        px = x2 * 2 + y2 - y1
        py = y2 * 2 + x2 - x1
        mat[py, px] = lines[idx]
    return mat
In [13]:
# Pick a random example from the raw inputs and inspect the model's
# prediction on it, both numerically and as an RGB overlay image.
sample_idx = random.randrange(x_input_raw.shape[0])
print("example: "+str(sample_idx))

# Run the network on the single selected example (slice keeps the batch dim).
net_input = process_input(x_input_raw[sample_idx:sample_idx+1])
(policy_out, value_out) = model.predict(net_input)

# Show the raw policy vector as integer percentages, and its total mass.
print((policy_out * 100).astype(np.uint8))
print(np.sum(policy_out))

# Fold the flat policy vector back onto the 2D board image.
policy_img = linesToDotsAndBoxesImage(policy_out[0], imgWidth, imgHeight)

# print input data
raw_input_img = x_input_raw[sample_idx]
print("input "+str(raw_input_img.shape)+": ")
print(raw_input_img)

# print prediction (policy as integer percentages, then value vs. target)
print("prediction policy: ")
print((policy_img * 100).astype(np.uint8))

print("prediction value: ")
print(value_out)

print("target value: ")
print(y_value_raw[sample_idx])

# generate greyscale image data from the prediction and the target policy
pred_channel = (policy_img * 255).astype(np.uint8)
target_channel = y_policy_raw[sample_idx]

# merge image data in color channels: input=red, prediction=green, target=blue
merged = np.stack([raw_input_img, pred_channel, target_channel], axis=2)

# create image, upscaled 10x for readability
img = Image.fromarray(merged, 'RGB')
img = img.resize(size=(img.size[0]*10, img.size[1]*10))

img
example: 26
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  99]]
1.0
input (15, 17): 
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0 215 255 215   0 215 255 215   0 215 255 215 255 215   0 215   0]
 [  0   0   0 255   0   0   0 255   0 255   0   0   0   0   0 255   0]
 [  0 215   0 215   0 215   0 215   0 215   0 215   0 215 255 215   0]
 [  0 255   0 255   0 255   0 255   0   0   0   0   0 255   0   0   0]
 [  0 215   0 215   0 215   0 215 255 215   0 215 255 215   0 215   0]
 [  0   0   0 255   0 255   0   0   0 255   0 255   0   0   0 255   0]
 [  0 215 255 215   0 215   0 215   0 215   0 215   0 215 255 215   0]
 [  0 255   0   0   0 255   0   0   0   0   0 255   0   0   0   0   0]
 [  0 215   0 215 255 215 255 215 255 215 255 215   0 215 255 215   0]
 [  0   0   0 255   0   0   0 255   0   0   0 255   0 255   0   0   0]
 [  0 215   0 215   0 215   0 215   0 215   0 215   0 215   0 215   0]
 [  0 255   0 255   0 255   0   0   0   0   0 255   0   0   0   0   0]
 [  0 215   0 215   0 215 255 215 255 215 255 215 255 215   0 215   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
prediction policy: 
[[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0 99  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]]
prediction value: 
[[-0.41]]
target value: 
[-0.38]
Out[13]:
In [ ]: