Scaling websockets to million connections

Dhadve Yash
6 min read

Understanding how to scale

Scaling WebSockets is harder than scaling HTTP: WebSocket connections are persistent, and a single connection cannot be shared between two different servers.

With HTTP, we can simply run many instances of our server and load-balance them behind a proxy such as nginx or Traefik.

With WebSockets, however, a server keeps the connection to the client open until the client disconnects.

As you can see in the image above, if client5 wants to send a message to client1, it can’t, because ws server2 has no record of client1. If only there were a way to bridge that gap…

Oh yeah, there indeed is one: Redis Pub/Sub!

From the diagram above, you can see that we can now send messages from one server to another via Redis. But which messages do we send, and to which servers? Let’s keep it simple and forward every incoming message to all available servers.

But what if, instead of forwarding every message to a bunch of servers, we kept track of which server holds each connection, looked up that server’s address, and sent the message only to it? That does sound appealing, but the cost of keeping track of every WebSocket connection is higher than simply forwarding the message.

Here’s an example to help visualise it:

A client sends a message to the WebSocket server it is connected to. That server publishes it to Redis, and every server subscribed to that Redis channel receives the message. Each server then checks whether the receiver is in its pool of connections and, if so, sends the message to that client.

So now let’s build it. I will be using Docker to run a Redis instance; if you don’t know how to use Docker, here’s a quick tutorial.

Building the WebSocket server

I will start from the code in my previous blog post, Chat app with golang, and make changes to it.

Let’s install the Redis client library (go-redis):

# Add the go-redis v9 client to the module's dependencies.
go get github.com/redis/go-redis/v9

main.go

package main

import (
    // ... other imports
    "github.com/redis/go-redis/v9"
)


// RedisStore wraps the go-redis client this server uses to publish and
// subscribe to chat messages shared with the other WebSocket servers.
type RedisStore struct{ db *redis.Client }

// NewRedisStore builds a RedisStore whose client targets the Redis server at
// localhost:6379, database 0, with an empty password. In this implementation
// the error result is always nil.
func NewRedisStore() (*RedisStore, error) {
    opts := &redis.Options{
        Addr:             fmt.Sprintf("%s:%s", "localhost", "6379"),
        Password:         "",
        DB:               0,
        DisableIndentity: true, // Disable set-info on connect
    }
    store := &RedisStore{db: redis.NewClient(opts)}
    return store, nil
}

// PubRedis serializes mes to JSON and publishes it on the Redis "msg"
// channel, from which every subscribed server receives it. Serialization and
// publish failures are logged, and the message is then dropped.
func (s *RedisStore) PubRedis(ctx context.Context, mes msg) {
    messageJSON, err := json.Marshal(mes)
    if err != nil {
        log.Printf("Failed to serialize message: %v", err)
        return
    }

    // Publish reports Redis failures (e.g. Redis being unreachable) through
    // the returned command's Err(); ignoring it loses messages silently.
    if err = s.db.Publish(ctx, "msg", messageJSON).Err(); err != nil {
        log.Printf("Failed to publish message: %v", err)
    }
}

// SubRedis subscribes to the Redis "msg" channel and starts a goroutine that
// delivers each received message to its receiver when that receiver is
// connected to this server; messages for other receivers are ignored. The
// subscription is never closed, so the goroutine runs until the process exits.
//
// NOTE(review): UsernameToWebSocket is also modified by the connection
// handlers (net/http runs each on its own goroutine), so this unsynchronized
// read is a data race; guard the map with a mutex shared with those handlers.
func (s *RedisStore) SubRedis() {
    subMsg := s.db.Subscribe(context.Background(), "msg")
    msgch := subMsg.Channel()

    go func() {
        for msgfromcha := range msgch {
            var mes msg
            if err := json.Unmarshal([]byte(msgfromcha.Payload), &mes); err != nil {
                log.Printf("Failed to unmarshal message: %v", err)
                continue
            }

            // A missing (nil) entry means the receiver is not on this server.
            conn := UsernameToWebSocket[mes.Reciever]
            if conn == nil {
                continue
            }
            // A failed write (e.g. the client just disconnected) is logged
            // instead of being silently ignored.
            if err := conn.WriteJSON(mes); err != nil {
                log.Printf("Failed to deliver message to %q: %v", mes.Reciever, err)
            }
        }
    }()
}

// Rstore is the process-wide Redis store used to publish and receive chat
// messages.
var Rstore *RedisStore

// main connects to Redis and starts the subscriber before the rest of the
// (previously written) server setup runs.
func main() {
    // A dedicated name avoids colliding with an `err :=` in the previous code.
    var redisErr error
    Rstore, redisErr = NewRedisStore()
    if redisErr != nil {
        log.Fatalf("connecting to redis: %v", redisErr)
    }
    Rstore.SubRedis()

    // ... previous code
}

// handleIncomingMessage decodes a client frame and acts on it. This excerpt
// shows only the changed ChatMsg branch.
//
// Chat messages are no longer written straight to the receiver's socket (the
// commented-out line), because the receiver may be connected to a different
// server. Instead they are published to Redis, and each server's SubRedis
// delivers the message if it holds the receiver's connection.
func handleIncomingMessage(sender *websocket.Conn, data []byte) error {
    // ... previous code
    switch DataRecieved.MessageType {
    case ChatMsg:
        // Old single-server delivery, replaced by the Redis publish below:
        // UsernameToWebSocket[DataRecieved.Reciever].WriteJSON(DataRecieved)
        Rstore.PubRedis(context.Background(), DataRecieved)
        // ... previous code
    }
    // ... previous code
}

Complete code:

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "sync"

    "github.com/gorilla/websocket"
    "github.com/redis/go-redis/v9"
)
var (
    // upgrader upgrades the /ws HTTP request to a WebSocket connection.
    // NOTE(review): CheckOrigin returning true accepts connections from any
    // origin, which allows cross-site WebSocket hijacking; restrict it before
    // exposing the server publicly.
    upgrader = websocket.Upgrader{
        ReadBufferSize:  1024,
        WriteBufferSize: 1024,
        CheckOrigin: func(r *http.Request) bool {
            return true
        },
    }

    // connMu guards both connection maps below. net/http runs every
    // connection's SocketHandler on its own goroutine, and SubRedis reads the
    // maps from another goroutine; unsynchronized concurrent map writes make
    // the Go runtime abort the whole process.
    connMu sync.RWMutex
    // WebSocketToUsername maps each logged-in connection to its username.
    WebSocketToUsername = make(map[*websocket.Conn]string)
    // UsernameToWebSocket maps each username to the connection it logged in on.
    UsernameToWebSocket = make(map[string]*websocket.Conn)

    // writeMu serializes writes to connections: gorilla/websocket supports at
    // most one concurrent writer per connection, and both SubRedis and the
    // connection handlers write. Deliveries already come from the single
    // subscriber goroutine, so this lock adds little contention.
    writeMu sync.Mutex
)

// writeJSON sends v to conn as JSON while holding writeMu, so writes from
// different goroutines never interleave on a connection.
func writeJSON(conn *websocket.Conn, v any) error {
    writeMu.Lock()
    defer writeMu.Unlock()
    return conn.WriteJSON(v)
}

// RedisStore wraps the Redis client used to fan chat messages out to every
// WebSocket server instance.
type RedisStore struct {
    db *redis.Client
}

// NewRedisStore creates a store whose client targets Redis at localhost:6379,
// database 0, with an empty password. The error result is always nil in this
// implementation.
func NewRedisStore() (*RedisStore, error) {
    rdb := redis.NewClient(&redis.Options{
        Addr:             fmt.Sprintf("%s:%s", "localhost", "6379"),
        Password:         "",
        DB:               0,
        DisableIndentity: true, // Disable set-info on connect
    })
    return &RedisStore{
        db: rdb,
    }, nil
}

// PubRedis publishes mes as JSON on the Redis "msg" channel, from which every
// subscribed server (including this one) receives it. Serialization and
// publish failures are logged, and the message is then dropped.
func (s *RedisStore) PubRedis(ctx context.Context, mes msg) {
    messageJSON, err := json.Marshal(mes)
    if err != nil {
        log.Printf("Failed to serialize message: %v", err)
        return
    }

    // Publish reports Redis failures through the returned command's Err();
    // ignoring it loses messages silently.
    if err = s.db.Publish(ctx, "msg", messageJSON).Err(); err != nil {
        log.Printf("Failed to publish message: %v", err)
    }
}

// SubRedis subscribes to the Redis "msg" channel and starts a goroutine that
// delivers each received message to its receiver when that receiver is
// connected to this server; messages for other receivers are ignored. The
// subscription is never closed, so the goroutine runs until the process exits.
func (s *RedisStore) SubRedis() {
    subMsg := s.db.Subscribe(context.Background(), "msg")
    msgch := subMsg.Channel()

    go func() {
        for msgfromcha := range msgch {
            var mes msg
            if err := json.Unmarshal([]byte(msgfromcha.Payload), &mes); err != nil {
                log.Printf("Failed to unmarshal message: %v", err)
                continue
            }

            connMu.RLock()
            conn := UsernameToWebSocket[mes.Reciever]
            connMu.RUnlock()
            if conn == nil {
                continue // receiver is not connected to this server
            }
            if err := writeJSON(conn, mes); err != nil {
                log.Printf("Failed to deliver message to %q: %v", mes.Reciever, err)
            }
        }
    }()
}

// Rstore is the process-wide Redis store used to publish and receive chat
// messages.
var Rstore *RedisStore

// main connects to Redis, starts the subscriber, and serves WebSocket
// connections on /ws at port 9000.
func main() {
    var err error
    Rstore, err = NewRedisStore()
    if err != nil {
        log.Fatalf("connecting to redis: %v", err)
    }
    Rstore.SubRedis()

    http.HandleFunc("/ws", SocketHandler)
    fmt.Println("web soc running on port 9000")
    // ListenAndServe returns only on failure (e.g. the port is taken). Exit
    // non-zero so the failure is visible to whatever runs the process.
    if err = http.ListenAndServe(":9000", nil); err != nil {
        log.Fatal(err)
    }
}

// SocketHandler upgrades the request to a WebSocket and passes every frame
// the client sends to handleIncomingMessage until reading fails (the client
// disconnected or sent a bad frame). On return the connection is unregistered
// exactly once, then closed.
func SocketHandler(w http.ResponseWriter, r *http.Request) {
    ws, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println("there was a connection error : ", err)
        return
    }
    // Deferred calls run last-in-first-out: unregister first, then close.
    defer ws.Close()
    defer handleDisconnection(ws)

    for {
        _, payload, err := ws.ReadMessage()
        if err != nil {
            return
        }
        if err := handleIncomingMessage(ws, payload); err != nil {
            log.Printf("Error handling message: %v", err)
        }
    }
}

// handleDisconnection removes sender from both connection maps. Connections
// that never logged in are left alone, so no empty-username entry is deleted.
func handleDisconnection(sender *websocket.Conn) {
    connMu.Lock()
    defer connMu.Unlock()

    userID, ok := WebSocketToUsername[sender]
    if !ok {
        return
    }
    delete(WebSocketToUsername, sender)
    // Only drop the username entry if it still points at this connection.
    if UsernameToWebSocket[userID] == sender {
        delete(UsernameToWebSocket, userID)
    }
}

// msg is the JSON frame exchanged with clients and relayed through Redis.
// It has no struct tags, so the JSON keys are the Go field names.
type msg struct {
    MessageType string // ChatMsg or LoginMsg
    Data        string
    Reciever    string
    Sender      string
}

// Message types understood by handleIncomingMessage.
const ChatMsg = "chatmsg"
const LoginMsg = "loginmsg"

// handleIncomingMessage decodes one client frame and dispatches it by
// MessageType:
//   - ChatMsg: published to Redis; the server holding the receiver delivers it.
//   - LoginMsg: registers sender under DataRecieved.Sender, replying
//     "User already exists" if that name is taken.
//
// Other message types are ignored. It returns the JSON decode error or the
// error from writing the rejection reply.
//
// NOTE(review): a ChatMsg's Sender is taken from the client payload as-is, so
// a client can impersonate another user; consider overwriting it with the
// connection's logged-in name.
func handleIncomingMessage(sender *websocket.Conn, data []byte) error {
    var DataRecieved msg
    if err := json.Unmarshal(data, &DataRecieved); err != nil {
        return fmt.Errorf("decoding message: %w", err)
    }
    fmt.Println(DataRecieved)

    switch DataRecieved.MessageType {
    case ChatMsg:
        Rstore.PubRedis(context.Background(), DataRecieved)
    case LoginMsg:
        connMu.Lock()
        if _, taken := UsernameToWebSocket[DataRecieved.Sender]; taken {
            connMu.Unlock()
            return writeJSON(sender, "User already exists")
        }
        // A connection logging in again under a new name must release its old
        // one; otherwise the old name stays bound to this connection forever.
        if oldName, ok := WebSocketToUsername[sender]; ok {
            delete(UsernameToWebSocket, oldName)
        }
        WebSocketToUsername[sender] = DataRecieved.Sender
        UsernameToWebSocket[DataRecieved.Sender] = sender
        connMu.Unlock()
    }
    return nil
}

Setting up Redis Pub/Sub using docker-compose.yaml

# Single Redis instance used as the Pub/Sub bridge between WebSocket servers.
version: "3.9"
services:
  redis-ws:
    image: redis:6.2-alpine
    container_name: redis-ws
    ports:
      - "6379:6379" # reachable as localhost:6379 by servers started with `go run`

Running our application

Now open three terminals and run one of these commands in each:

go run main.go # first server, listens on :9000
go run main.go # in this server change the port (e.g. to :9001) first; both cannot listen on :9000
docker-compose up # starts Redis from docker-compose.yaml

And we are done 👍 You can test it using Postman, as we did in the previous blog post.

Load balancing

That’s great! You can now scale your application towards a million users by horizontally scaling your WebSocket servers with Redis. But there’s still a major problem: we are running multiple servers on different ports and possibly different IPs, so how do clients reach them? For that, we will use a proxy server exposed to the outside world (the clients). It accepts the initial HTTP request (the WebSocket handshake) from each client and distributes the connections across our internal WebSocket servers.

I will be using nginx with Docker for load balancing. If you don’t know how to use nginx and Docker, here’s an nginx and Docker tutorial.

We just have to add the following docker-compose.yaml and nginx.conf files to our project, and we are done:

version: "3.9"
services:

  # Public entry point: nginx load-balances WebSocket connections across
  # serv1 and serv2 using the mounted nginx.conf.
  nginx:
    image: nginx:latest
    container_name: load_bala
    ports:
      - "8080:80"
    environment:
      # NOTE(review): the mounted nginx.conf hard-codes `listen 80`; confirm
      # whether this variable is used at all.
      - NGINX_PORT=80
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - serv1
      - serv2

  # Redis Pub/Sub bridge shared by both WebSocket servers.
  redis-ws:
    image: redis:6.2-alpine
    container_name: redis-ws
    ports:
      - "6379:6379"

  # Two containers built from the same Go server; each listens on 9000
  # inside its container and is published on a different host port.
  serv1:
    build: ./
    container_name: serv1
    ports:
      - "9000:9000"
    depends_on:
      - redis-ws

  serv2:
    build: ./
    container_name: serv2
    ports:
      - "9001:9000"
    depends_on:
      - redis-ws
events {

}
http {

    # Both Go servers; Redis Pub/Sub lets a client on either one reach
    # clients on the other, so no sticky sessions are needed.
    upstream backend {
        server serv1:9000;
        server serv2:9000;
    }
    server {
        listen 80;
        server_name _;

        # location / {
        #     add_header Cache-Control no-store;
        #     proxy_pass http://backend;
        # }

        location / {
            proxy_pass http://backend;
            # The WebSocket handshake needs HTTP/1.1 and the hop-by-hop
            # Upgrade/Connection headers forwarded explicitly.
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection "upgrade";
            proxy_set_header Host $host;
            # nginx closes a proxied connection when the upstream sends nothing
            # for proxy_read_timeout (60s by default). The Go server sends no
            # pings, so without this, idle chat clients are dropped after a minute.
            proxy_read_timeout 3600s;
        }
    }
}
// Also, in our server we have to change the Redis configuration: inside
// docker-compose the Redis host is the container name, redis-ws, not localhost.
// This replaces the client creation in NewRedisStore, which declares rdb with :=.
    rdb := redis.NewClient(&redis.Options{
        Addr:             fmt.Sprintf("%s:%s", "redis-ws" /* "localhost" */, "6379"),
        Password:         "",
        DB:               0,
        DisableIndentity: true, // Disable set-info on connect
    })

Now just run `docker compose up` in your terminal and you are good to go!

0
Subscribe to my newsletter

Read articles from Dhadve Yash directly inside your inbox. Subscribe to the newsletter, and don't miss out.

Written by

Dhadve Yash
Dhadve Yash