Brijesh's Git Server — watchman @ 242871c6f0dd3c1670d49334a7f551c68b7a518e

Observability tool; needs to be rewritten once identity is stable.

middleware/rate_limiting.go (view raw)

 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
package middleware

import (
	"encoding/json"
	"log"
	"net"
	"net/http"
	"sync"
	"time"
	"watchman/schema"

	"golang.org/x/time/rate"
)

// visitor is the per-client entry stored in the visitors map: the token
// bucket that throttles the client plus the time the client was last looked
// up, which cleanupVisitors uses to evict idle entries.
type visitor struct {
	// limiter is the rate limiter returned by getVisitor for this client.
	limiter *rate.Limiter

	// lastSeen is refreshed by getVisitor on every lookup.
	lastSeen time.Time
}

// visitors holds the rate-limiting state for each client, keyed by the
// client's IP address. Any access to it must hold mu.
var visitors = map[string]*visitor{}

// mu serializes access to visitors between request handlers (getVisitor)
// and the background sweeper (cleanupVisitors).
var mu sync.Mutex

// init starts cleanupVisitors on its own goroutine. cleanupVisitors never
// returns, so the sweeper runs for the lifetime of the process.
func init() {
	go cleanupVisitors()
}

// getVisitor returns the rate limiter for ip. The first time an ip is seen, a
// new limiter allowing 1 event per second with a burst of 3 is created and
// stored. Every call refreshes the visitor's lastSeen time so that
// cleanupVisitors does not evict an active client.
func getVisitor(ip string) *rate.Limiter {
	mu.Lock()
	defer mu.Unlock()

	now := time.Now()

	if existing, found := visitors[ip]; found {
		existing.lastSeen = now
		return existing.limiter
	}

	fresh := &visitor{
		limiter:  rate.NewLimiter(1, 3),
		lastSeen: now,
	}
	visitors[ip] = fresh
	return fresh.limiter
}

// Delete the visitor if it was last seen over 3 minutes ago
func cleanupVisitors() {
	for {
		time.Sleep(time.Minute)

		mu.Lock()
		for ip, v := range visitors {
			if time.Since(v.lastSeen) > 3*time.Minute {
				delete(visitors, ip)
			}
		}
		mu.Unlock()
	}
}

// Ratelimit is HTTP middleware that throttles requests per client IP address
// using the limiter returned by getVisitor.
//
// The client is identified by the host part of r.RemoteAddr. If RemoteAddr
// cannot be split into host and port, the request is answered with
// 500 Internal Server Error. A request that exceeds its client's limit is
// answered with 429 Too Many Requests and a JSON schema.Response_Type body,
// and is not passed on to next. All other requests are forwarded to next.
func Ratelimit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			log.Print(err.Error())
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		if getVisitor(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		// NOTE(review): the request ID is presumably stored by an earlier
		// middleware. Use the comma-ok form so a missing or non-string value
		// yields "" instead of panicking the handler.
		requestID, _ := r.Context().Value(schema.RequestIDKey{}).(string)

		response := schema.Response_Type{
			Status:    "ERROR",
			Message:   "You made too many requests",
			RequestID: requestID,
		}

		w.Header().Set("Content-Type", "application/json")
		// Without an explicit status the rejection would be sent as 200 OK.
		w.WriteHeader(http.StatusTooManyRequests)
		if err := json.NewEncoder(w).Encode(response); err != nil {
			// The status line has already been written, so http.Error could
			// no longer change it; log the failure rather than leaking the
			// error text to the client.
			log.Printf("ratelimit: encoding 429 response: %v", err)
		}
	})
}