Rate limiter with Redis and Golang
A rate limiter is a mechanism that controls how many requests or tasks can be performed within a given time period. It helps prevent excessive resource usage and mitigates denial-of-service (DoS/DDoS) attacks. When the number of requests exceeds the limit, subsequent requests are rejected or delayed until the limit resets. This keeps the system stable and fair for all users.
Some common algorithms for implementing a rate limiter:
Leaky Bucket
Fixed Window Counter
Sliding Window Log
Sliding Window Counter
Many services today let you configure a rate limiter for your website with little setup.
In this post, I will walk through implementing a rate limiter with Redis, Golang, and the Fixed Window Counter algorithm:
Concept: The Fixed Window Counter algorithm counts the requests made in a fixed period of time, called a "window". For example, you can define a window of 1 minute.
Requests: Every time a request arrives, the system checks how many requests have already been made in the current window. If the new request would push the count over the specified limit, it is rejected. The counter is reset when the window ends and a new one begins.
Now, let's implement it in Golang.
Initialize the Redis connection:
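```go
// Imports for the whole server file (main.go); the Redis client is assumed to be go-redis v9.
package main

import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/sirupsen/logrus"
)

// initRedis parses the connection URL, configures the pool and checks that Redis is reachable.
func initRedis(redisUrl string) (*redis.Client, error) {
	opts, err := redis.ParseURL(redisUrl)
	if err != nil {
		return nil, fmt.Errorf("failed to parse redis url: %w", err)
	}
	opts.PoolSize = 30
	opts.ReadTimeout = 5 * time.Second
	opts.WriteTimeout = 5 * time.Second
	opts.Username = ""

	redisClient := redis.NewClient(opts)
	if err := redisClient.Ping(context.Background()).Err(); err != nil {
		_ = redisClient.Close()
		return nil, fmt.Errorf("failed to ping redis: %w", err)
	}
	return redisClient, nil
}
```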
The function below extracts the client's IP from the request. Requests are limited per client IP, at 60 requests per minute.
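```go
// getIPFromRequest returns the client IP, preferring proxy headers over the TCP peer address.
// Note: X-Forwarded-For and X-Real-IP are sent by the client unless a trusted proxy overwrites
// them, so a client that talks to this server directly can spoof them to dodge the limit.
func getIPFromRequest(r *http.Request) string {
	// X-Forwarded-For may hold a comma-separated chain: "client, proxy1, proxy2".
	ips := r.Header.Get("X-Forwarded-For")
	for _, ip := range strings.Split(ips, ",") {
		// Take the first non-empty, non-loopback entry.
		if ip = strings.TrimSpace(ip); ip != "" && ip != "::1" && ip != "127.0.0.1" {
			return ip
		}
	}
	// Fall back to X-Real-IP.
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}
	// Finally use the TCP peer address without its port.
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return ip
}
```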
Next is the middleware that enforces the limit. It runs a Lua script in Redis that increments the counter for the client's IP and returns "not pass" if the counter is greater than the maximum number of requests per minute, or "pass" otherwise. Redis executes a Lua script atomically, so concurrent requests cannot interleave between reading and updating the counter. If the request passes, the middleware forwards it to the handler, which responds with status code 200; otherwise the middleware responds with status code 429 (Too Many Requests).
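```go
// MaxRequestOneMinute is the per-IP limit for one 60-second window.
const MaxRequestOneMinute = 60

// HandlerAPI holds the dependencies shared by the HTTP handlers.
type HandlerAPI struct {
	RedisClient *redis.Client
}

// rateLimitScript increments the counter in KEYS[1], starts the window's TTL (ARGV[1] seconds)
// on the first request, and compares the new count with the limit in ARGV[2].
// Checking the incremented value (not the old one) avoids letting limit+1 requests through.
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
if current == 1 then
  redis.call('EXPIRE', KEYS[1], ARGV[1])
end
if current > tonumber(ARGV[2]) then
  return "not pass"
end
return "pass"
`)

func (h *HandlerAPI) RateLimiter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// CORS headers. Browsers reject Access-Control-Allow-Credentials: true combined
		// with a wildcard origin, so that header is not sent here.
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Origin, Authorization")
		// Handle preflight requests
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		// Rate limit per client IP: a 60-second window holding MaxRequestOneMinute requests.
		key := "rate_limit:" + getIPFromRequest(r)
		// Run sends EVALSHA and falls back to EVAL if the script is not cached yet.
		result, err := rateLimitScript.Run(r.Context(), h.RedisClient, []string{key}, 60, MaxRequestOneMinute).Text()
		if err != nil {
			logrus.Warnf("Running Lua Script is failed with err: %v", err)
			// Fail closed; calling next.ServeHTTP(w, r) here instead would fail open.
			http.Error(w, "internal server error", http.StatusInternalServerError)
			return
		}
		if result == "not pass" {
			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
```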
The test handler returns "Hello Viet Nam" with status code 200:
```go
// testHandler answers GET /test with "Hello Viet Nam".
func (h *HandlerAPI) testHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.WriteHeader(http.StatusOK)
	// A write error means the client has gone away; there is nothing useful to do with it.
	_, _ = fmt.Fprint(w, "Hello Viet Nam")
}
```
The main function:
```go
func main() {
	redisClient, err := initRedis("redis://default:@localhost:6379")
	if err != nil {
		log.Fatalf("failed to init redis: %v", err)
	}
	handler := HandlerAPI{
		RedisClient: redisClient,
	}
	mux := http.NewServeMux()
	// Every request to /test passes through the rate limiter before reaching testHandler.
	mux.Handle("/test", handler.RateLimiter(http.HandlerFunc(handler.testHandler)))
	// Start the server
	log.Fatal(http.ListenAndServe(":3000", mux))
}
```
Next, I write a test that sends 100 concurrent requests to the endpoint. It is an integration test, so the server must already be running on localhost:3000.
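```go
// main_test.go: integration test against the server running on localhost:3000.
package main

import (
	"net/http"
	"sync"
	"testing"
)

func TestHandlerAPI_testHandler(t *testing.T) {
	var (
		wg     sync.WaitGroup
		mu     sync.Mutex
		counts = map[int]int{} // status code -> number of responses
	)
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			code := callAPI(t)
			mu.Lock()
			counts[code]++
			mu.Unlock()
		}()
	}
	wg.Wait()
	t.Logf("200: %d, 429: %d, other: %d", counts[http.StatusOK], counts[http.StatusTooManyRequests],
		100-counts[http.StatusOK]-counts[http.StatusTooManyRequests])
}

// callAPI sends one GET request and returns its status code (0 on a transport error).
func callAPI(t *testing.T) int {
	res, err := http.Get("http://localhost:3000/test")
	if err != nil {
		t.Log(err)
		return 0
	}
	defer res.Body.Close()
	return res.StatusCode
}
```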
The result
Written by
Nguyen Van Tuan
I'm Tuan. I graduated from Hanoi University of Science and Technology in 2019, majoring in Information Technology. Leetcode: nguyenvantuan2391996. My blog: https://tuannguyenhust.hashnode.dev/ LinkedIn: Tuan Nguyen Van