Residual model from scratch with Tensorflow.js [Part 1]

Yury Kozyrev
3 min read · Sep 20, 2020


All over the internet, you can find a lot of examples of how to write sequential models using tf.sequential(). Really, it’s pretty straightforward.

But what if you want to build something custom and not linear, say a residual model like ResNet-18 or its simplified version? You have to use tf.model() then, which is much more flexible and… it has way fewer examples and tutorials.

I suffered for about a day with that and I hope this article might save you some hours.

What is a residual network?

First, I had a sequential convolutional network with a simple structure “convolution, convolution, max pooling, repeat, flatten it” like this:

[Figure: a sequential convolutional network, VGG19]

But there is a problem: when deeper networks start converging, a degradation problem is exposed: as the network depth increases, accuracy gets saturated and then degrades rapidly.

The solution to this problem is a “residual connection”: you bypass the convolutions by adding your input to their output.

[Figure: a residual connection, called an “identity” shortcut]

You can find more theory and math in this detailed article.

Also, you can check this research paper Deep Residual Learning for Image Recognition.

In practice, it means that you need to sum your input with the result of the convolutions. They had better have the same shape :)

It all looks beautiful in theory, but how to implement it? Moreover, how to do it with Tensorflow.js?

Let’s write some code!

First, I started with the article 18 Tips for Training your own Tensorflow.js Models in the Browser. There is plenty of good advice in it, but there are two problems I was eager to eliminate:

  • There are code snippets, but no big picture at the end; it’s unclear how to put it all together
  • Two years later, it’s a bit outdated

Moving from tf.sequential()

Let’s create a simple convolutional network and then convert it into one built with tf.model().

And this is how it looks with tf.model():

Nice and easy! Now you explicitly perform every operation.

Simple residual connection

You can probably already guess that there is tf.layers.add(), which we can use the same way. And you are right!

We’ve got our first residual neural network! Of course, it’s a bit useless in its current shape.

But as you might guess, it’s not a production-ready network and it needs some adjustments. In Part 2, I’m going to explain and describe a simplified version of ResNet-18 built in tfjs.

As a sneak peek, here is one of three pictures (and some papers, of course) that inspired me while building my first residual network in tfjs. Why so many, and why you can’t just use one, I will try to explain.

Part 2 is finished, check it out!


Yury Kozyrev

Former Yandex Software Engineer, passionate Engineering Manager in Berlin