2022-07-14
In this post we show how to train a supervised model with TensorFlow.js to play the classic Snake arcade game. The model has eight inputs. Four are flags for the cells neighbouring the snake's head (1 if the cell is blocked, 0 if it is free). Two are the head's coordinates on the board, and two are the coordinates of the food piece. There are four outputs, one for each direction the snake might turn on the next step of the game. Because exactly one direction is chosen per step, the last layer uses a softmax activation and the loss function is categoricalCrossentropy. The training data comes from the simulated game itself. A hand-written heuristic function in the code picks, at every step, a direction that heads toward the food while avoiding obstacles. That choice is used as the training label for the model.
As you will see, we use a workaround to store the model in Firebase Realtime Database. When saving, the model is first written to the browser's local storage and then uploaded from there. When loading, the stored data is written back to local storage so that TensorFlow.js can read it. TensorFlow.js has no built-in way to save to Firebase, but this limitation is easily overcome. We use the Three.js library to draw the game elements. The game loop at the end of the code is where the training happens. It takes about 20,000 simulation steps to train a model that can play the game on its own.
/**
 * Snake played on a Three.js board by a small TensorFlow.js classifier.
 *
 * On start-up the model stored at the root of the Firebase Realtime Database is loaded into
 * `robot`. If nothing usable is stored, a fresh untrained model is built instead.
 *
 * Every step, a hand-written heuristic (`chooseDirection`) picks a direction. In training mode
 * that choice is the label for one `fit` call. In play mode the model's prediction replaces it.
 *
 * Space bar: while training, it uploads the model to Firebase (after Google sign-in) and switches
 * to play mode. In play mode, it starts training a freshly initialised model.
 *
 * Model input (8 values): for each of the four `drive` directions, 1 if the neighbouring cell is
 * blocked and 0 if it is open or holds food. These are followed by the head's x, y and the first
 * food piece's x, y. Model output: a softmax over the four `drive` directions.
 */
let robot, training, stepper = 0;
(async function app() {
  // TensorFlow.js has no Firebase IO handler, so the model is staged through the
  // 'localstorage://' handler. These are the localStorage entries that handler reads/writes.
  const MODEL_URL = 'localstorage://robot';
  const MODEL_KEYS = ['info', 'model_topology', 'weight_data', 'weight_specs'];
  const storageKey = (key) => `tensorflowjs_models/robot/${key}`;

  /** Compile `robot` for 4-way classification. */
  const compile = () => {
    robot.compile({
      optimizer: tf.train.adam(),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });
  };

  /** Replace `robot` with a freshly initialised (untrained) network and compile it. */
  const initialize = () => {
    robot = tf.sequential({
      layers: [
        tf.layers.flatten({ inputShape: [8, 1] }),
        tf.layers.dense({ units: 100, activation: 'relu', kernelInitializer: 'varianceScaling' }),
        tf.layers.dense({ units: 4, activation: 'softmax', kernelInitializer: 'varianceScaling' }),
      ],
    });
    compile();
  };

  /**
   * Load the model stored at the database root into `robot`.
   * @returns {Promise<boolean>} false when the database holds nothing.
   */
  const loadRobot = async () => {
    const stored = (await firebase.database().ref().once('value')).val();
    if (!stored) {
      return false;
    }
    MODEL_KEYS.forEach((key) => localStorage.setItem(storageKey(key), stored[key]));
    const loaded = await tf.loadLayersModel(MODEL_URL);
    await tf.io.removeModel(MODEL_URL); // drop the temporary localStorage copy
    robot = loaded;
    compile();
    return true;
  };

  /** Save `robot` to the database root, using localStorage as a staging area. */
  const uploadRobot = async () => {
    await robot.save(MODEL_URL);
    const record = {};
    MODEL_KEYS.forEach((key) => {
      record[key] = localStorage.getItem(storageKey(key));
    });
    await firebase.database().ref().set(record);
    await tf.io.removeModel(MODEL_URL);
  };

  /**
   * Resolve with the Firebase Auth user once the auth state is reported, then detach the
   * observer. Detaching means repeated key presses do not pile up listeners that would each
   * upload again on every later auth change.
   */
  const currentUser = () => new Promise((resolve) => {
    const unsubscribe = firebase.auth().onAuthStateChanged((user) => {
      resolve(user);
      // Deferred in case the SDK invokes the observer before `unsubscribe` is assigned.
      Promise.resolve().then(() => unsubscribe());
    });
  });

  try {
    if (!(await loadRobot())) {
      console.warn('No stored model found; starting from an untrained one.');
      initialize();
    }
  } catch (error) {
    console.error('Could not load the stored model; starting from an untrained one.', error);
    initialize();
  }

  const scene = new THREE.Scene();
  scene.background = new THREE.Color(0xffffff);
  const camera = new THREE.PerspectiveCamera(60);
  scene.add(camera);
  const sky = new THREE.AmbientLight();
  scene.add(sky);
  const light = new THREE.DirectionalLight();
  scene.add(light);
  const renderer = new THREE.WebGLRenderer();

  // Square canvas rendered at twice the shorter viewport side, then fitted to the viewport by CSS.
  const fitCanvas = () => {
    if (innerWidth > innerHeight) {
      renderer.setSize(2 * innerHeight, 2 * innerHeight);
      renderer.domElement.style.width = 'auto';
      renderer.domElement.style.height = '100%';
    } else {
      renderer.setSize(2 * innerWidth, 2 * innerWidth);
      renderer.domElement.style.width = '100%';
      renderer.domElement.style.height = 'auto';
    }
  };
  window.onresize = window.onorientationchange = fitCanvas;
  fitCanvas();

  // Template meshes, cloned for border walls, snake segments and food pieces.
  const wall = new THREE.Mesh(
    new THREE.SphereGeometry(0.8),
    new THREE.MeshLambertMaterial({ color: new THREE.Color('grey') }),
  );
  const segment = new THREE.Mesh(
    new THREE.SphereGeometry(0.7),
    new THREE.MeshLambertMaterial({ color: new THREE.Color('green') }),
  );
  const apple = new THREE.Mesh(
    new THREE.SphereGeometry(0.3),
    new THREE.MeshLambertMaterial({ color: new THREE.Color('green') }),
  );

  // field[x + level][y + level] holds each board cell. A cell is one of: a mesh named 'wall'
  // (border or snake body), a mesh named 'food', or a plain { name: 'open' } placeholder.
  const field = [];
  // Step offsets per direction index: 0 = -x, 1 = +y, 2 = +x, 3 = -y.
  const drive = [[-1, 0], [0, 1], [1, 0], [0, -1]];
  const snake = []; // snake[0] is the head
  const food = [];
  let level; // board coordinates run from -level to +level
  let counter; // food eaten in the current game
  let dir = 1; // direction index used for the next move

  const score = () => {
    document.title = 'Score ' + counter;
  };

  const cellAt = (x, y) => field[x + level][y + level];
  const neighbour = (position, direction) =>
    cellAt(position.x + drive[direction][0], position.y + drive[direction][1]);
  const isPassable = (cell) => cell.name === 'open' || cell.name === 'food';

  /** Remove the mesh at `position` from the scene and mark its cell open. */
  const clearCell = (position) => {
    scene.remove(cellAt(position.x, position.y));
    field[position.x + level][position.y + level] = { name: 'open' };
  };

  /** Place a food piece on a random open cell whose four neighbours are also open. */
  const piece = () => {
    let randomx;
    let randomy;
    // The centre is checked first, so border cells never index outside the field.
    const isClear = () => field[randomx][randomy].name === 'open'
      && drive.every(([dx, dy]) => field[randomx + dx][randomy + dy].name === 'open');
    do {
      randomx = Math.round(2 * level * Math.random());
      randomy = Math.round(2 * level * Math.random());
    } while (!isClear());
    const morsel = apple.clone();
    morsel.name = 'food';
    morsel.position.set(randomx - level, randomy - level, 0);
    field[randomx][randomy] = morsel;
    scene.add(morsel);
    food.push(morsel);
  };

  /** Spawn a replacement food piece, then remove the eaten piece at `next` from the board. */
  const eat = (next) => {
    piece();
    const index = food.findIndex(
      (item) => item.position.x === next.position.x && item.position.y === next.position.y,
    );
    if (index !== -1) {
      clearCell(food[index].position);
      food.splice(index, 1);
    }
  };

  /** Index whose value is strictly greater than every other value, or `fallback` if none is. */
  const strictArgMax = (values, fallback) => {
    for (let i = 0; i < values.length; i++) {
      if (values.every((value, j) => j === i || values[i] > value)) {
        return i;
      }
    }
    return fallback;
  };

  /**
   * Heuristic teacher: move toward the first food piece while avoiding blocked cells.
   *
   * Aligned case (head shares a column or row with the food): go straight toward the food. If
   * that is blocked, try the two perpendicular directions in random order, then back away.
   *
   * Otherwise: pick the x or y axis at random. Try the toward-food direction on that axis, then
   * on the other axis. Then try the two away-from-food directions in the same axis order.
   *
   * Returns the current `dir` when every neighbour is blocked.
   */
  const chooseDirection = () => {
    const head = snake[0].position;
    const target = food[0].position;
    const isFree = (direction) => isPassable(neighbour(head, direction));
    const opposite = (direction) => (direction + 2) % 4;
    const towardX = head.x < target.x ? 2 : 0;
    const towardY = head.y < target.y ? 1 : 3;
    let order;
    if (head.x === target.x || head.y === target.y) {
      const sameColumn = head.x === target.x;
      const toward = sameColumn ? towardY : towardX;
      if (isFree(toward)) {
        return toward;
      }
      const sides = sameColumn ? [2, 0] : [1, 3];
      if (!Math.round(Math.random())) {
        sides.reverse();
      }
      order = [...sides, opposite(toward)];
    } else if (Math.round(Math.random())) {
      order = [towardX, towardY, opposite(towardX), opposite(towardY)];
    } else {
      order = [towardY, towardX, opposite(towardY), opposite(towardX)];
    }
    const choice = order.find(isFree);
    return choice === undefined ? dir : choice;
  };

  /** Advance the game by one move: train on (or predict) the move, then apply it. */
  const step = async () => {
    dir = chooseDirection();
    const head = snake[0].position;
    const features = drive.map((_, direction) => (isPassable(neighbour(head, direction)) ? 0 : 1));
    features.push(head.x, head.y, food[0].position.x, food[0].position.y);
    const label = drive.map((_, direction) => (direction === dir ? 1 : 0));

    // Shape [1, 8, 1]: a batch of one sample, matching the flatten layer's inputShape [8, 1].
    // TensorFlow.js tensors must be disposed explicitly; otherwise every step leaks memory.
    const input = tf.tidy(() => tf.tensor2d(features, [8, 1]).expandDims());
    try {
      if (training) {
        const target = tf.tensor2d(label, [1, 4]);
        try {
          await robot.fit(input, target, {
            callbacks: {
              onEpochEnd: (epoch, log) => {
                document.title = stepper + ' ' + log.loss;
              },
            },
          });
        } finally {
          target.dispose();
        }
      } else {
        const output = robot.predict(input);
        try {
          const scores = await output.data();
          console.log(scores[0], scores[1], scores[2], scores[3]);
          dir = strictArgMax(scores, dir);
        } finally {
          output.dispose();
        }
      }
    } finally {
      input.dispose();
    }
    stepper++;

    const { x: headX, y: headY } = snake[0].position;
    const nextX = headX + drive[dir][0];
    const nextY = headY + drive[dir][1];
    const next = cellAt(nextX, nextY);
    if (!isPassable(next)) {
      // Hit a wall or the snake's own body: start a new game.
      if (training) {
        begin();
      } else {
        setTimeout(begin, 100);
      }
      return;
    }

    const ate = next.name === 'food';
    if (ate) {
      counter++;
      eat(next);
      if (!training) {
        score();
      }
    }
    const newHead = segment.clone();
    if (ate) {
      newHead.scale.set(1.2, 1.2, 1.2); // the segment that ate is drawn 20% larger
    }
    newHead.name = 'wall';
    newHead.position.set(nextX, nextY, 0);
    field[nextX + level][nextY + level] = newHead;
    scene.add(newHead);
    snake.unshift(newHead);
    if (!ate) {
      // A move without eating vacates the tail cell, keeping the length constant.
      clearCell(snake[snake.length - 1].position);
      snake.pop();
    }

    if (training) {
      step();
    } else {
      setTimeout(step, 100);
    }
  };

  /**
   * Clear the board and start a new game.
   * Builds the border walls, a vertical snake and the food, then starts the step loop.
   */
  const begin = () => {
    food.forEach((item) => clearCell(item.position));
    food.length = 0;
    snake.forEach((part) => clearCell(part.position));
    snake.length = 0;
    field.forEach((column) => column.forEach((cell) => {
      if (cell.name === 'wall') {
        scene.remove(cell);
      }
    }));
    field.length = 0;

    level = 8;
    counter = 0;
    camera.position.set(0, 0, 2 * level);
    camera.lookAt(scene.position);
    light.position.copy(camera.position);
    light.lookAt(scene.position);

    for (let x = 0; x <= 2 * level; x++) {
      field[x] = [];
      for (let y = 0; y <= 2 * level; y++) {
        if (x === 0 || y === 0 || x === 2 * level || y === 2 * level) {
          const brick = wall.clone();
          brick.name = 'wall';
          brick.position.set(x - level, y - level, 0);
          field[x][y] = brick;
          scene.add(brick);
        } else {
          field[x][y] = { name: 'open' };
        }
      }
    }

    // The snake starts as `level` segments in column x = 0: head at the origin, body below it.
    for (let i = 0; i < level; i++) {
      const part = segment.clone();
      part.name = 'wall';
      part.position.set(0, -i, 0);
      field[level][level - i] = part;
      scene.add(part);
      snake[i] = part;
    }
    dir = 1;
    for (let n = food.length; n < level / 3 - 2; n++) {
      piece();
    }
    step();
  };

  begin();

  const animate = () => {
    requestAnimationFrame(animate);
    renderer.render(scene, camera);
  };
  animate();

  document.onkeyup = async (event) => {
    if (event.keyCode !== 32) {
      return; // only the space bar toggles modes
    }
    if (!training) {
      initialize();
      stepper = 0;
      training = true;
      return;
    }
    const user = await currentUser();
    if (!user) {
      firebase.auth().signInWithRedirect(new firebase.auth.GoogleAuthProvider());
      return;
    }
    try {
      await uploadRobot();
      training = false;
    } catch (error) {
      console.error('Could not upload the model.', error);
    }
  };

  document.body.appendChild(renderer.domElement);
})();
At the end of the code there is a key handler. Pressing the space bar while training uploads the trained model to the database (after signing in with Google) and switches to play mode. Pressing it in play mode starts training a fresh model. The uploaded model can then be loaded on a separate page, as you can see here. If you would like more details about this solution, contact us at the address at the bottom of the page.