2022-07-14
In this post we show how to train a supervised TensorFlow.js model to play the classic Snake arcade game. The model has eight inputs: four flags for the cells neighbouring the snake's head (free or blocked), two for the coordinates of the snake's head on the board and two for the coordinates of the food piece. There are four outputs, one for each direction the snake might move on the next step of the game. Because the model picks exactly one of four classes, the last layer uses a softmax activation and the loss function is categoricalCrossentropy. A lengthy function in the code simulates the game: at every step a hand-written rule steers the snake towards the food, and the direction it chooses is used as the training label for the model.
As you will see, we use a workaround to store the model in the Firebase Realtime Database. To save, the model is first written to the browser's local storage and those entries are then uploaded. To load, the entries are downloaded into local storage before TensorFlow.js reads the model from there. This works around a shortcoming of TensorFlow.js, which has no built-in way to save a model to Firebase, and it is quite easy to implement. We use the Three.js library to draw the game elements. Training happens in the game loop near the end of the code, and it takes about 20,000 simulation steps to train a model that can play the game on its own.
// tf.LayersModel: restored from Firebase below or created by initialize().
let robot
// Truthy while the model is being trained; the Space key toggles it.
let training
// Number of game steps simulated since training was (re)started.
let stepper = 0
;(async function app() {
// (Re)compiles `robot` as a 4-way classifier: Adam optimizer,
// categorical cross-entropy loss, accuracy reported as a metric.
compile = () => {
  const options = {
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  }
  robot.compile(options)
}
// Replaces `robot` with a fresh, untrained network and compiles it:
// input [8, 1] flattened to 8 values -> 100 ReLU units -> 4-way softmax.
// Layers are created in input-to-output order, as before, so their
// auto-generated names stay the same.
initialize = () => {
  const inputLayer = tf.layers.flatten({ inputShape: [8, 1] })
  const hiddenLayer = tf.layers.dense({ units: 100, activation: 'relu', kernelInitializer: 'varianceScaling' })
  const outputLayer = tf.layers.dense({ units: 4, activation: 'softmax', kernelInitializer: 'varianceScaling' })
  robot = tf.sequential({ layers: [inputLayer, hiddenLayer, outputLayer] })
  compile()
}
// Restores the model previously uploaded to the Firebase Realtime Database root.
// TensorFlow.js cannot read the model from Firebase directly, so the four
// serialized entries are copied into localStorage under the keys that the
// 'localstorage://robot' handler reads, loaded from there, and removed again.
// When nothing has been uploaded yet (val() is null) or loading fails, a fresh
// model is created instead, so startup never blocks on a missing model.
await (async () => {
  const prefix = 'tensorflowjs_models/robot/'
  try {
    const saved = (await firebase.database().ref().once('value')).val()
    if (!saved) {
      initialize()
      return
    }
    for (const key of ['info', 'model_topology', 'weight_data', 'weight_specs']) {
      localStorage.setItem(prefix + key, saved[key])
    }
    try {
      robot = await tf.loadLayersModel('localstorage://robot')
    } finally {
      // Best-effort cleanup of the staged copy; failing here must not discard a loaded model.
      await tf.io.removeModel('localstorage://robot').catch(error => console.warn('Could not remove the staged model:', error))
    }
    compile()
  } catch (error) {
    console.error('Could not restore the saved model, starting from scratch:', error)
    initialize()
  }
})()
// Three.js scene on a white background, a 60° perspective camera and two lights
// (ambient fill plus a directional light that begin() aims along the camera).
scene = new THREE.Scene()
scene.background = new THREE.Color(0xffffff)
camera = new THREE.PerspectiveCamera(60)
sky = new THREE.AmbientLight()
light = new THREE.DirectionalLight()
scene.add(camera, sky, light)
renderer = new THREE.WebGLRenderer()
// Keeps the canvas square: it renders at twice the shorter window side and is
// scaled by CSS to fill that side. Runs once now and on every resize/rotation.
;(onresize = onorientationchange = () => {
  const landscape = innerWidth > innerHeight
  const side = 2 * (landscape ? innerHeight : innerWidth)
  renderer.setSize(side, side)
  renderer.domElement.style.width = landscape ? 'auto' : '100%'
  renderer.domElement.style.height = landscape ? '100%' : 'auto'
})()
// Template meshes; the game clones them for every wall cell, snake segment and food piece.
const ball = (radius, color) => new THREE.Mesh(new THREE.SphereGeometry(radius), new THREE.MeshLambertMaterial({ color: new THREE.Color(color) }))
wall = ball(0.8, 'grey')
segment = ball(0.7, 'green')
apple = ball(0.3, 'green')
// Half-width of the board: cell positions range over -level..level on both axes.
let level
// Food pieces eaten in the current episode.
let counter
// field[x + level][y + level]: a mesh named 'wall' or 'food', or { name: 'open' }.
field = []
// Direction index -> [dx, dy]: 0 left, 1 up, 2 right, 3 down.
drive = [[-1, 0], [0, 1], [1, 0], [0, -1]]
// Current direction of the snake, an index into drive.
dir = 1
// Snake segment meshes, head first.
snake = []
// Food meshes currently on the board.
food = []
// Writes the current score into the page title and returns that title.
score = () => (document.title = `Score ${counter}`)
// Spawns a food piece on a random cell that is 'open' and whose four drive-
// neighbours are also 'open'. Food therefore never appears next to a wall, the
// snake's body or another piece. The piece is added to the board, the scene and `food`.
// Local coordinates replace the former implicit globals randomx/randomy.
// NOTE(review): retries forever if no such cell exists (e.g. a very long snake).
piece = () => {
  const isOpen = (cx, cy) => field[cx][cy].name == 'open'
  let cx
  let cy
  do {
    cx = Math.round(2 * level * Math.random())
    cy = Math.round(2 * level * Math.random())
    // The centre is tested first, so border cells (walls) never index past the field.
  } while (!(isOpen(cx, cy) && drive.every(([dx, dy]) => isOpen(cx + dx, cy + dy))))
  const morsel = apple.clone()
  morsel.name = 'food'
  morsel.position.set(cx - level, cy - level, 0)
  field[cx][cy] = morsel
  scene.add(morsel)
  food.push(morsel)
}
// Handles the head moving onto the food mesh `next`. A replacement piece is
// spawned first; it cannot land on the eaten cell, which is still named 'food'.
// Then every food entry at `next`'s position is removed from the scene, the
// board and `food`.
eat = next => {
  piece()
  // Iterate backwards so splice() does not skip the element after a removal.
  for (let index = food.length - 1; index >= 0; index--) {
    const item = food[index]
    if (item.position.x == next.position.x && item.position.y == next.position.y) {
      scene.remove(field[item.position.x + level][item.position.y + level])
      field[item.position.x + level][item.position.y + level] = { name: 'open' }
      food.splice(index, 1)
    }
  }
}
// Starts a new episode and runs it until the snake crashes (a crash calls begin again).
// Reset: removes the previous food, snake and wall meshes, then rebuilds a
// (2*level+1) x (2*level+1) board with a wall border. It places a `level`-segment
// snake heading up from the centre and spawns the initial food.
// Each step:
//   1. `find` (a rule-based teacher) sets `dir` to a greedy move toward food[0].
//   2. Model inputs: four blocked flags (1 = cannot enter) in drive order,
//      then head x/y and food[0] x/y. Target: `dir`, one-hot.
//   3. Training: one fit() on (inputs, target). Otherwise the model's prediction
//      overrides `dir`, but only on a strict maximum; ties/NaN keep the teacher's choice.
//   4. The snake advances one cell. Food makes it grow; any cell that is not
//      'open'/'food' ends the episode.
// Training steps run back-to-back; play steps are paced at 100 ms.
;(begin = () => {
  // Tear down the previous episode (nothing to do on the very first call).
  food.forEach(item => {
    scene.remove(field[item.position.x + level][item.position.y + level])
    field[item.position.x + level][item.position.y + level] = { name: 'open' }
  })
  food.length = 0
  snake.forEach(part => {
    scene.remove(field[part.position.x + level][part.position.y + level])
    field[part.position.x + level][part.position.y + level] = { name: 'open' }
  })
  snake.length = 0
  // Snake cells were reopened above, so the remaining 'wall' cells are the border.
  field.forEach(line => line.forEach(square => {
    if (square.name == 'wall') scene.remove(field[square.position.x + level][square.position.y + level])
  }))
  field.length = 0
  level = 8
  counter = 0
  camera.position.set(0, 0, 2 * level)
  camera.lookAt(scene.position)
  light.position.copy(camera.position)
  light.lookAt(scene.position)
  // Board: walls on the border, open cells inside.
  for (let x = 0; x < 2 * level + 1; x++) {
    field[x] = []
    for (let y = 0; y < 2 * level + 1; y++) {
      if (x == 0 || y == 0 || x == 2 * level || y == 2 * level) {
        field[x][y] = wall.clone()
        field[x][y].name = 'wall'
        field[x][y].position.set(x - level, y - level, 0)
        scene.add(field[x][y])
      } else {
        field[x][y] = { name: 'open' }
      }
    }
  }
  // Snake: head at the centre, body extending downwards. Segments are named
  // 'wall' so that moves into the body count as collisions.
  for (let i = 0; i < level; i++) {
    const part = segment.clone()
    part.name = 'wall'
    part.position.set(0, -i, 0)
    field[level][level - i] = part
    scene.add(part)
    snake[i] = part
  }
  dir = 1
  for (let n = food.length; n < level / 3 - 2; n++) piece()
  ;(async function step() {
    const head = snake[0].position
    // True when the head's neighbour in drive direction d can be entered.
    const isFree = d => {
      const cell = field[head.x + drive[d][0] + level][head.y + drive[d][1] + level]
      return cell.name == 'open' || cell.name == 'food'
    }
    // Teacher policy. Each case tries a priority list of directions and keeps
    // the first free one; if none is free, `dir` keeps its previous value.
    // Coin flips break symmetric choices, in the same situations as before.
    ;(find = () => {
      const goal = food[0].position
      const coin = () => Math.round(Math.random())
      // Sets `dir` to the first free direction among `dirs`; reports whether one was found.
      const take = (...dirs) => {
        const free = dirs.find(isFree)
        if (free === undefined) return false
        dir = free
        return true
      }
      // [toward, away] direction pairs along each axis.
      const vertical = head.y < goal.y ? [1, 3] : [3, 1]
      const horizontal = head.x < goal.x ? [2, 0] : [0, 2]
      if (head.x == goal.x) {
        // Same column: straight at the food, else sidestep (random side first), else back.
        if (!take(vertical[0])) take(...(coin() ? [2, 0] : [0, 2]), vertical[1])
      } else if (head.y == goal.y) {
        // Same row: straight at the food, else sidestep (random side first), else back.
        if (!take(horizontal[0])) take(...(coin() ? [1, 3] : [3, 1]), horizontal[1])
      } else {
        // Diagonal: close a randomly chosen axis first; then try the other axis'
        // toward-move, this axis' away-move, and the other axis' away-move.
        const [first, second] = coin() ? [horizontal, vertical] : [vertical, horizontal]
        if (!take(first[0])) take(second[0], first[1], second[1])
      }
    })()
    const x = [...drive.map((_, d) => (isFree(d) ? 0 : 1)), head.x, head.y, food[0].position.x, food[0].position.y]
    const y = drive.map((_, d) => (d == dir ? 1 : 0))
    // Tensors are disposed explicitly: tf.tidy cannot track them across an
    // await, and leaking one to three tensors per step exhausts memory over a
    // long training run.
    const input = tf.tidy(() => tf.tensor2d(x, [8, 1]).expandDims())
    try {
      if (training) {
        const target = tf.tensor2d(y, [1, 4])
        try {
          await robot.fit(input, target, {
            callbacks: { onEpochEnd: (epoch, log) => document.title = stepper + ' ' + log.loss }
          })
        } finally {
          target.dispose()
        }
      } else {
        const output = robot.predict(input)
        try {
          const data = await output.data()
          console.log(data[0], data[1], data[2], data[3])
          // Only a strict maximum overrides the teacher's choice.
          for (let d = 0; d < 4; d++) {
            if ([0, 1, 2, 3].every(other => other == d || data[d] > data[other])) dir = d
          }
        } finally {
          output.dispose()
        }
      }
    } finally {
      input.dispose()
    }
    stepper++
    const nx = head.x + drive[dir][0]
    const ny = head.y + drive[dir][1]
    const next = field[nx + level][ny + level]
    if (next.name != 'open' && next.name != 'food') {
      // Crash: start a new episode.
      if (training) { begin() } else { setTimeout(begin, 100) }
      return
    }
    if (next.name == 'food') {
      counter++
      eat(next)
    }
    const body = segment.clone()
    if (next.name == 'food') body.scale.set(1.2, 1.2, 1.2)
    body.name = 'wall'
    body.position.set(nx, ny, 0)
    field[nx + level][ny + level] = body
    scene.add(body)
    snake.unshift(body)
    if (next.name == 'open') {
      // Nothing eaten: drop the tail so the length stays the same.
      const tail = snake[snake.length - 1].position
      scene.remove(field[tail.x + level][tail.y + level])
      field[tail.x + level][tail.y + level] = { name: 'open' }
      snake.pop()
    }
    if (training) { step() } else { setTimeout(step, 100) }
  })()
})()
// Render loop: schedules itself for the next animation frame, then draws the scene.
animate = () => {
  requestAnimationFrame(animate)
  renderer.render(scene, camera)
}
animate()
// The Space key (keyCode 32) toggles training:
//  - when not training: start training a freshly initialized model;
//  - when training: sign in with Google if needed, then upload the model to the
//    Firebase Realtime Database root and stop training. The model is serialized
//    through TensorFlow.js' 'localstorage://' handler, and the resulting
//    localStorage entries are what gets uploaded.
document.onkeyup = e => {
  if (e.keyCode != 32) return
  if (!training) {
    initialize()
    stepper = 0
    training = true
    return
  }
  // React to the auth state only once per key press. An observer left
  // registered would re-run the upload on every later auth state change.
  let handled = false
  let unsubscribe = null
  unsubscribe = firebase.auth().onAuthStateChanged(user => {
    if (handled) return
    handled = true
    if (unsubscribe) unsubscribe()
    if (!user) {
      firebase.auth().signInWithRedirect(new firebase.auth.GoogleAuthProvider())
      return
    }
    const prefix = 'tensorflowjs_models/robot/'
    robot.save('localstorage://robot')
      .then(() => firebase.database().ref().set({
        info: localStorage.getItem(prefix + 'info'),
        model_topology: localStorage.getItem(prefix + 'model_topology'),
        weight_data: localStorage.getItem(prefix + 'weight_data'),
        weight_specs: localStorage.getItem(prefix + 'weight_specs')
      }))
      .then(() => tf.io.removeModel('localstorage://robot'))
      .then(() => { training = false })
      .catch(error => console.error('Uploading the model failed:', error))
  })
  // Covers an observer that fired synchronously, before `unsubscribe` was assigned.
  if (handled && unsubscribe) unsubscribe()
}
// Put the WebGL canvas on the page.
document.body.append(renderer.domElement)
})()
At the end of the code there is a key handler. Pressing Space starts training a fresh model, and pressing it again uploads the trained model to the database, signing in with Google first if needed. The uploaded model can then be loaded on a separate page, as you can see here. If you would like more details about this solution, contact us at the address at the bottom of the page.