Stormin' The Castle


Unleash the GPU in your Colab Notebook

by John Robinson @johnrobinsn

Colab Notebooks
I'm lucky enough to have my own machine learning rig with a couple of Nvidia Titan RTX GPUs for my own dev work. But the GPU-centric nature of most ML projects these days makes it pretty hard to share the results of ML experiments. Most people just don't have CUDA-capable hardware readily available. Google Colab to the rescue. But Colab notebooks have their limitations, especially in terms of interactivity. In this article, I'm going to show you at least one way to solve this.

Colabs are based on Jupyter notebooks. Notebooks are used for rapidly running experiments and data visualizations using python and data. Notebooks have become an indispensable way for folks to prototype, experiment and share machine learning and data science work. Colabs are a hosted product that is "free" in the Google sense. But the big killer feature of Colabs is that Google provides "free" GPU runtime instances. It's as easy as creating a Colab notebook and selecting the "GPU" option for the runtime type of the notebook. They even support TPUs as well.
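Before doing anything GPU-specific, it's worth confirming that the runtime actually got an accelerator. Here's a quick check, assuming PyTorch is available (it comes preinstalled on Colab GPU runtimes); the `describe_runtime` helper is just for illustration and isn't part of the notebook:

```python
# Report which accelerator (if any) the runtime has.
# Assumes PyTorch is preinstalled, as it is on Colab GPU runtimes;
# falls back gracefully when there is no GPU or no torch at all.
def describe_runtime():
    try:
        import torch
        if torch.cuda.is_available():
            return 'GPU: ' + torch.cuda.get_device_name(0)
        return 'CPU only'
    except ImportError:
        return 'PyTorch not installed'

print(describe_runtime())
```

On a Colab GPU runtime this prints the attached device name (e.g. a Tesla T4); on a CPU runtime it prints "CPU only".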

Colabs are a great tool when you want to share machine learning notebooks backed with "free" GPUs. But while I love notebooks for prototyping and experimenting with datasets and machine learning models, the notebook metaphor can get in the way, especially when you want to demonstrate something with more interactivity. After all, these notebooks are running in a browser. Isn't there a way to unlock the full power of the browser platform and bridge that to the GPU-powered runtime that is powering the notebook?

Sometimes you just want a damn web app.

So how can we unlock this? As it turns out, the python platform as supported by Colabs is quite powerful and extensible. Below I'll describe how to run a Flask application server embedded within your Colab notebook and also show how you can host a web application within your notebook that can take full advantage of the browser platform.

Here is a link to a Google Colab Notebook that demonstrates this working.

The following code block shows how we can use python threads to run a Flask server on another thread within the notebook itself quite easily. Refer to the notebook to try it out for yourself. One trick is how to cleanly bounce or restart this server once the thread has been spawned. We can do this by making our own threading subclass and by leveraging the werkzeug.serving module directly. In this way, you can just reload the cell to stop and restart the Flask server to pick up any changes that we've made to routes etc.

I won't go too deeply into Flask specifics here. But there are plenty of resources available.

The Flask App Server

# Run a python(flask)-based web service in your notebook
# You can reload this cell to restart the server if you make changes

default_port = 6060

from werkzeug.serving import make_server
from flask import Flask
import threading

class ServerThread(threading.Thread):

    def __init__(self, app, port):
        threading.Thread.__init__(self)
        self.port = port
        # create the WSGI server directly so we can shut it down cleanly later
        self.srv = make_server('', port, app)
        self.ctx = app.app_context()
        self.ctx.push()

    def run(self):
        print('starting server on port:', self.port)
        self.srv.serve_forever()

    def shutdown(self):
        self.srv.shutdown()

def start_server(port=default_port):
    global server
    # if a server is already running, bounce it so route changes take effect
    if 'server' in globals() and server:
        print('stopping server')
        stop_server()

    app = Flask('myapp')

    # you can add your own routes here as needed
    @app.route('/')
    def hello():
        # A wee bit o'html
        return '<h1 style="color:red;">Hello From Flask!</h1>'

    server = ServerThread(app, port)
    server.start()

def stop_server():
    global server
    if server:
        server.shutdown()
        server = None

# Start the server here
start_server()

Just to demonstrate that our Flask app server is up and running, we can make an HTTP request to the listening port of the app server.

import requests

r = requests.get('http://localhost:6060')
# should print the "Hello From Flask!" markup returned by our route
print(r.text)

A Web Application

From the embedded Flask app server, we can serve up our web application and access the python context (with our GPU) through web service calls to the Flask server. We can host our web application inside of an iframe that we create within our notebook dynamically with the following bit of python.

import IPython.display

def display(port, height):
    shell = """
        (async () => {
            const url = await google.colab.kernel.proxyPort(%PORT%, {"cache": true});
            const iframe = document.createElement('iframe');
            iframe.src = url;
            iframe.setAttribute('width', '100%');
            iframe.setAttribute('height', '%HEIGHT%');
            iframe.setAttribute('frameborder', 0);
            document.body.appendChild(iframe);
        })();
    """
    replacements = [
        ("%PORT%", "%d" % port),
        ("%HEIGHT%", "%d" % height),
    ]
    # substitute the port and height placeholders into the javascript snippet
    for (k, v) in replacements:
        shell = shell.replace(k, v)

    script = IPython.display.Javascript(shell)
    IPython.display.display(script)

display(6060, 400)
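The iframe above only loads the page; for the web app to reach back into the python context, you register JSON routes on the Flask app before starting it, and the page calls them through the proxied port. A minimal sketch, shown standalone here — the `/api/square` route and its parameter are made up for illustration and aren't part of the notebook:

```python
from flask import Flask, jsonify, request

app = Flask('myapp')

# A hypothetical JSON endpoint: the browser-side app can call this
# (via the proxied port) to run python work server-side and get JSON back.
@app.route('/api/square')
def square():
    n = float(request.args.get('n', 0))
    return jsonify(result=n * n)

# From the page inside the iframe you'd call it with something like:
#   fetch('/api/square?n=3').then(r => r.json())
```

In a real notebook the route body would do the GPU-backed work (model inference, data prep) and serialize the result as JSON for the browser side to render.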


In a future article, I'll show you how to use the facility that I've described in this article to provide a dynamic 3D visualization of a machine learning model.

Stay Tuned...


John Robinson © 2022-2023