BAKO Blog

Training Supervised Model to Play Classic Snake Arcade

2022-07-14

In this post we will show you how to train a supervised model based on TensorFlow.js to play the classic Snake arcade game. The model has eight inputs: four for the neighbours of the snake's head, two for the head's coordinates on the board and two for the coordinates of the food piece. There are four outputs, one for each direction the snake might turn on the next step of the game. Therefore the activation of the last layer of the model is softmax and the loss function is categoricalCrossentropy. There is a lengthy function in the code that defines the logic of simulating the game so that the model can be trained.

As you will see, we use a hack to store the model in Firebase Realtime Database: the model is temporarily placed in the browser's local storage before it is uploaded. This works around a shortcoming of TensorFlow.js that is quite easily overcome. As you will also see from the code, we use the Three.js library to draw the actual game elements. There is a game loop at the end of the code where the actual training happens, and it takes about 20,000 steps of simulation to train a model that is capable of playing the game by itself.

// State shared by the whole script:
//   robot    - the TensorFlow.js model (assigned once it is loaded or created)
//   training - training-mode flag; left unset (falsy) until training is switched on
//   stepper  - number of simulation steps taken so far
let robot
let training
let stepper = 0

;(async function app() {
    // Compiles the current `robot` model for training: Adam optimizer,
    // categorical cross-entropy loss and accuracy as the reported metric.
    compile = () => {
        const settings = {
            optimizer: tf.train.adam(),
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy']
        }
        robot.compile(settings)
    }

    // Replaces `robot` with a freshly initialised network and compiles it.
    // Topology: an [8, 1] input flattened to 8 values, one hidden dense layer
    // of 100 ReLU units, and a 4-unit softmax output.
    initialize = () => {
        const input = tf.layers.flatten({ inputShape: [8, 1] })
        const hidden = tf.layers.dense({ units: 100, activation: 'relu', kernelInitializer: 'varianceScaling' })
        const output = tf.layers.dense({ units: 4, activation: 'softmax', kernelInitializer: 'varianceScaling' })

        robot = tf.sequential({ layers: [input, hidden, output] })
        compile()
    }

    // Restore the stored model before the game starts. The database root holds
    // the four entries that the `localstorage://robot` handler reads, so they
    // are copied into localStorage, loaded from there, and the temporary copy
    // is removed again.
    //
    // Any failure (network error, empty database, corrupt entry) is logged and
    // answered with a freshly initialised model, so the page still starts
    // instead of waiting forever on a promise that can never settle.
    try {
        const snapshot = await firebase.database().ref().once('value')
        const stored = snapshot.val()
        if (!stored) throw new Error('the database holds no stored model')

        for (const entry of ['info', 'model_topology', 'weight_data', 'weight_specs']) {
            localStorage.setItem('tensorflowjs_models/robot/' + entry, stored[entry])
        }

        const loaded = await tf.loadLayersModel('localstorage://robot')
        await tf.io.removeModel('localstorage://robot')
        robot = loaded
        compile()
    } catch (error) {
        console.error('Could not restore the stored model, starting with an untrained one', error)
        initialize()
    }

    // Three.js scaffolding: a white scene holding the camera, an ambient light
    // and a directional light, plus the WebGL renderer that draws it.
    scene = new THREE.Scene()
    scene.background = new THREE.Color(0xffffff)

    camera = new THREE.PerspectiveCamera(60)
    sky = new THREE.AmbientLight()
    light = new THREE.DirectionalLight()
    for (const member of [camera, sky, light]) scene.add(member)

    renderer = new THREE.WebGLRenderer()

    // Keeps the canvas square: the drawing buffer is twice the shorter window
    // side, and CSS stretches the canvas to fill that side. Installed as the
    // resize and orientation-change handler and run once right away.
    ;(onresize = onorientationchange = () => {
        const landscape = innerWidth > innerHeight
        const side = 2 * (landscape ? innerHeight : innerWidth)
        renderer.setSize(side, side)

        // Applied after setSize() so these values are the ones that stick.
        const css = renderer.domElement.style
        css.width = landscape ? 'auto' : '100%'
        css.height = landscape ? '100%' : 'auto'
    })()

    // Prototype meshes; the game clones them for every wall brick, snake
    // segment and food piece it places.
    const ball = (radius, colour) => new THREE.Mesh(
        new THREE.SphereGeometry(radius),
        new THREE.MeshLambertMaterial({ color: new THREE.Color(colour) })
    )
    wall = ball(0.8, 'grey')
    segment = ball(0.7, 'green')
    apple = ball(0.3, 'green')

    // Half-width of the board and the number of food pieces eaten this round.
    let level
    let counter

    // field - grid of cells, each either a mesh or a plain { name: 'open' } marker
    // drive - [dx, dy] offset for each direction index: 0 = -x, 1 = +y, 2 = +x, 3 = -y
    // dir   - index into drive of the current heading
    // snake - segment meshes, head first
    // food  - food meshes currently on the board
    field = []
    drive = [[-1, 0], [0, 1], [1, 0], [0, -1]]
    dir = 1
    snake = []
    food = []

    // Shows the current score in the page title and returns the caption.
    score = () => {
        const caption = 'Score ' + counter
        document.title = caption
        return caption
    }

    // Drops one new food piece on the board and appends it to `food`.
    //
    // A cell qualifies when it and its four neighbours are all open, which
    // keeps food away from walls and from the snake. One qualifying cell is
    // picked uniformly at random. The cells are enumerated rather than guessed
    // in a retry loop, so the function always terminates:
    //   - if no cell has four open neighbours, any open cell is used instead;
    //   - if the board has no open cell at all, nothing is placed.
    piece = () => {
        const isOpen = (x, y) => field[x][y].name == 'open'

        const open = []
        const roomy = []
        // Rows and columns 0 and 2 * level are the border walls, so only the
        // interior is scanned; every neighbour lookup then stays inside the grid.
        for (let x = 1; x < 2 * level; x++) {
            for (let y = 1; y < 2 * level; y++) {
                if (!isOpen(x, y)) continue
                open.push([x, y])
                if (drive.every(([dx, dy]) => isOpen(x + dx, y + dy))) roomy.push([x, y])
            }
        }

        const candidates = roomy.length ? roomy : open
        if (!candidates.length) {
            console.warn('No open cell left on the board, no food placed')
            return
        }

        const [x, y] = candidates[Math.floor(Math.random() * candidates.length)]
        const morsel = apple.clone()
        morsel.name = 'food'
        morsel.position.set(x - level, y - level, 0)
        field[x][y] = morsel
        scene.add(morsel)
        food.push(morsel)
    }

    // Handles the snake eating the food mesh `next`: a replacement piece is
    // placed first, then the eaten piece is taken out of the scene, its cell
    // is marked open again and it is dropped from `food`.
    eat = next => {
        piece()

        // A cell holds at most one food mesh, so the first match is the only one.
        const eaten = food.findIndex(item =>
            item.position.x == next.position.x && item.position.y == next.position.y
        )
        if (eaten == -1) return

        const column = food[eaten].position.x + level
        const row = food[eaten].position.y + level
        scene.remove(field[column][row])
        field[column][row] = { name: 'open' }
        food.splice(eaten, 1)
    }

    ;(begin = () => {
        // Tear down the previous round. Every entry of `food` and `snake` is
        // the very mesh stored in its grid cell, so removing the meshes from
        // the scene and emptying the arrays is enough; the grid itself is
        // rebuilt from scratch below. On the first call all arrays are empty.
        food.forEach(item => scene.remove(item))
        food.length = 0

        snake.forEach(link => scene.remove(link))
        snake.length = 0

        field.forEach(line => {
            line.forEach(square => {
                if (square.name == 'wall') scene.remove(square)
            })
        })
        field.length = 0

        level = 8
        counter = 0

        // Camera on the z axis looking at the origin, with the directional
        // light travelling along with it.
        camera.position.set(0, 0, 2 * level)
        camera.lookAt(scene.position)
        light.position.copy(camera.position)
        light.lookAt(scene.position)

        // Board of (2 * level + 1) squared cells: wall meshes on the border,
        // open markers inside. Grid index = world coordinate + level.
        const size = 2 * level + 1
        for (let x = 0; x < size; x++) {
            field[x] = []
            for (let y = 0; y < size; y++) {
                const border = x == 0 || y == 0 || x == size - 1 || y == size - 1
                if (!border) {
                    field[x][y] = { name: 'open' }
                    continue
                }
                const brick = wall.clone()
                brick.name = 'wall'
                brick.position.set(x - level, y - level, 0)
                field[x][y] = brick
                scene.add(brick)
            }
        }

        // Initial snake: `level` segments, head at the origin and the body
        // trailing towards -y. Segments are named 'wall' so that running into
        // the body counts as a collision.
        for (let i = 0; i < level; i++) {
            const link = segment.clone()
            link.name = 'wall'
            link.position.set(0, -i, 0)
            field[level][level - i] = link
            scene.add(link)
            snake[i] = link
        }
        dir = 1
        for (let n = food.length; n < level / 3 - 2; n++) piece()

        ;(async function step() {
            // Scripted pathfinder: decides the heading `dir` for the next move.
            //
            // Directions index `drive`: 0 = -x, 1 = +y, 2 = +x, 3 = -y. The
            // snake heads for food[0]; from a preference list of four
            // directions the first one whose neighbouring cell is open or
            // holds food wins. If all four are blocked, `dir` is left as is.
            //
            // Preference lists:
            //   same column as the food: towards the food along y, then the
            //       two x directions in random order, then away along y;
            //   same row as the food: the mirror image of the above;
            //   otherwise: a coin toss picks the leading axis; towards the
            //       food on that axis, towards it on the other axis, then
            //       away on the leading axis, then away on the other axis.
            ;(find = () => {
                const head = snake[0].position
                const goal = food[0].position

                const passable = way => {
                    const cell = field[head.x + drive[way][0] + level][head.y + drive[way][1] + level]
                    return cell.name == 'open' || cell.name == 'food'
                }
                const opposite = way => (way + 2) % 4
                const coin = () => Math.round(Math.random()) == 1

                // Direction that closes the gap on each axis. When the
                // coordinates are equal this is the "away" side, which only
                // matters as a fallback.
                const alongX = head.x < goal.x ? 2 : 0
                const alongY = head.y < goal.y ? 1 : 3

                // `others` is a thunk so that a coin is tossed for the side
                // steps only when the first choice turns out to be blocked.
                let first, others
                if (head.x == goal.x) {
                    first = alongY
                    others = () => [...(coin() ? [2, 0] : [0, 2]), opposite(alongY)]
                } else if (head.y == goal.y) {
                    first = alongX
                    others = () => [...(coin() ? [1, 3] : [3, 1]), opposite(alongX)]
                } else if (coin()) {
                    first = alongX
                    others = () => [alongY, opposite(alongX), opposite(alongY)]
                } else {
                    first = alongY
                    others = () => [alongX, opposite(alongY), opposite(alongX)]
                }

                if (passable(first)) {
                    dir = first
                    return
                }
                const detour = others().find(passable)
                // 0 is a valid direction, hence the explicit undefined check.
                if (detour !== undefined) dir = detour
            })()

            let x = [], y = []

            drive.forEach(neighbor => {
                if (field[snake[0].position.x + neighbor[0] + level][snake[0].position.y + neighbor[1] + level].name == 'open' || field[snake[0].position.x + neighbor[0] + level][snake[0].position.y + neighbor[1] + level].name == 'food') {
                    x.push(0)
                } else {
                    x.push(1)
                }
            })
            x.push(snake[0].position.x)
            x.push(snake[0].position.y)
            x.push(food[0].position.x)
            x.push(food[0].position.y)

            for (i = 0; i < 4; i++) {
                if (i == dir) {
                    y.push(1)
                } else {
                    y.push(0)
                }
            }
            
            if (training) {
                await (() => {
                    return new Promise(resolve => {
                        robot.fit(
                            tf.tidy(() => { return tf.tensor2d(x, [8, 1]).expandDims() }),
                            tf.tidy(() => { return tf.tensor2d(y, [1, 4]) }),
                            { callbacks: { onEpochEnd: (epoch, log) => document.title = stepper + ' ' + log.loss }}
                        ).then(() => resolve())
                    })
                })()
            } else {
                await (() => {
                    return new Promise(resolve => {
                        robot.predict(
                            tf.tidy(() => { return tf.tensor2d(x, [8, 1]).expandDims() }),
                        ).data().then(data => {
                            console.log(data[0],data[1],data[2],data[3])
                            if (data[0] > data[1] && data[0] > data[2] && data[0] > data[3]) dir = 0
                            if (data[1] > data[0] && data[1] > data[2] && data[1] > data[3]) dir = 1
                            if (data[2] > data[0] && data[2] > data[1] && data[2] > data[3]) dir = 2
                            if (data[3] > data[0] && data[3] > data[1] && data[3] > data[2]) dir = 3
                            resolve()
                        })
                    })
                })()
            }

            stepper++

            next = field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level]
            if (next.name == 'open' || next.name == 'food') {
                if (next.name == 'food') {
                    counter++
                    eat(next)
                }
                field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level] = segment.clone()
                if (next.name == 'food') field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level].scale.set(1.2, 1.2, 1.2)
                field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level].name = 'wall'
                field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level].position.set(snake[0].position.x + drive[dir][0], snake[0].position.y + drive[dir][1], 0)
                scene.add(field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level])
                snake.unshift(field[snake[0].position.x + drive[dir][0] + level][snake[0].position.y + drive[dir][1] + level])
                if (next.name == 'open') {
                    scene.remove(field[snake[snake.length - 1].position.x + level][snake[snake.length - 1].position.y + level])
                    field[snake[snake.length - 1].position.x + level][snake[snake.length - 1].position.y + level] = { name: 'open' }
                    snake.pop()
                }
            } else {
                if (training) { begin() } else { setTimeout(begin, 100) }
                return
            }
            if (training) { step() } else { setTimeout(step, 100) }
        })()
    })()

    // Render loop: books the next animation frame, then draws the scene.
    animate = () => {
        requestAnimationFrame(animate)
        renderer.render(scene, camera)
    }
    animate()

    // Space bar toggles training.
    //   Not training: start over with a freshly initialised model and switch
    //       training on.
    //   Training: save the model to the database and, once that has
    //       succeeded, switch training off. The model goes through a temporary
    //       `localstorage://robot` copy whose four entries are uploaded. If no
    //       user is signed in, a Google sign-in redirect is started instead.
    document.onkeyup = e => {
        if (e.keyCode != 32) return // 32 = space bar

        if (!training) {
            initialize()
            stepper = 0
            training = true
            return
        }

        // The observer is needed for one answer only. Dropping it straight
        // away keeps repeated key presses from stacking up observers that
        // would each trigger another save.
        let unsubscribe
        unsubscribe = firebase.auth().onAuthStateChanged(async user => {
            if (unsubscribe) unsubscribe()

            if (!user) {
                firebase.auth().signInWithRedirect(new firebase.auth.GoogleAuthProvider())
                    .catch(error => console.error('Sign-in failed', error))
                return
            }

            try {
                await robot.save('localstorage://robot')
                try {
                    await firebase.database().ref().set({
                        info: localStorage.getItem('tensorflowjs_models/robot/info'),
                        model_topology: localStorage.getItem('tensorflowjs_models/robot/model_topology'),
                        weight_data: localStorage.getItem('tensorflowjs_models/robot/weight_data'),
                        weight_specs: localStorage.getItem('tensorflowjs_models/robot/weight_specs')
                    })
                } finally {
                    // Remove the temporary copy whether or not the upload worked.
                    await tf.io.removeModel('localstorage://robot')
                }
                training = false
            } catch (error) {
                // Training stays on, so the save can be retried with another key press.
                console.error('Saving the trained model failed', error)
            }
        })
    }

    // Put the renderer's canvas on the page.
    const view = renderer.domElement
    document.body.appendChild(view)
})()
                    

At the end of the code there is a handler that uploads the trained model to the database on a key press. The model can then be loaded on a separate page, as you can see here. If you would like to know more details about this solution, make sure to contact us at the address at the bottom of the page.

Get in touch

blog@bako.co

BAKO