136 lines
3.3 KiB
Go
136 lines
3.3 KiB
Go
package gateway
|
|
|
|
import (
	"context"
	"math/rand"
	"os"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9"
)
|
|
|
|
// RedisRateLimiter is a Redis-based rate limiter implementation
|
|
// Use this in production for distributed rate limiting
|
|
type RedisRateLimiter struct {
|
|
client *redis.Client
|
|
ctx context.Context
|
|
config RateLimitConfig
|
|
}
|
|
|
|
// NewRedisRateLimiter connects to the Redis server described by redisURL
// (any URL accepted by redis.ParseURL) and returns a limiter using config.
//
// The connection is verified with PING. If that fails, the client is closed
// before the error is returned so its connection pool is not leaked.
func NewRedisRateLimiter(redisURL string, config RateLimitConfig) (*RedisRateLimiter, error) {
	opts, err := redis.ParseURL(redisURL)
	if err != nil {
		return nil, err
	}

	client := redis.NewClient(opts)
	ctx := context.Background()

	// Test connection
	if err := client.Ping(ctx).Err(); err != nil {
		// Release the pool; the Ping error is the one worth reporting, so a
		// secondary Close error is deliberately ignored.
		_ = client.Close()
		return nil, err
	}

	return &RedisRateLimiter{
		client: client,
		ctx:    ctx,
		config: config,
	}, nil
}
|
|
|
|
// NewRedisRateLimiterFromClient wraps an already-configured Redis client in a
// limiter using config. No connectivity check is performed, and calling Close
// on the returned limiter closes the supplied client.
func NewRedisRateLimiterFromClient(client *redis.Client, config RateLimitConfig) *RedisRateLimiter {
	rl := &RedisRateLimiter{config: config}
	rl.client = client
	rl.ctx = context.Background()
	return rl
}
|
|
|
|
// redisSlidingWindowScript atomically prunes, counts and conditionally
// records a request in a sorted-set sliding window. Running the whole
// check-then-add sequence inside Redis stops concurrent callers (including
// other gateway instances) from all seeing "count < limit" and all adding,
// which would let bursts exceed the limit.
//
//	KEYS[1] sorted-set key
//	ARGV[1] score for this request (Unix seconds, fractional)
//	ARGV[2] window start (Unix seconds); scores <= this are pruned
//	ARGV[3] maximum requests allowed in the window
//	ARGV[4] unique member for this request
//	ARGV[5] key TTL in seconds
//
// It returns 1 when the request was allowed and recorded, 0 otherwise.
var redisSlidingWindowScript = redis.NewScript(`
redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[2])
if redis.call('ZCARD', KEYS[1]) >= tonumber(ARGV[3]) then
	return 0
end
redis.call('ZADD', KEYS[1], ARGV[1], ARGV[4])
redis.call('EXPIRE', KEYS[1], ARGV[5])
return 1
`)

// Allow reports whether a request for key fits within
// rl.config.RequestsPerMinute requests over a sliding one-minute window,
// and records the request when it does.
//
// Scores and window bounds are numeric Unix seconds with a fractional
// (microsecond) part. Redis rejects non-numeric bounds for ZREMRANGEBYSCORE,
// and using seconds keeps the same unit as integer-second scores already
// stored under "ratelimit:<key>".
//
// Any Redis error allows the request (fail open).
func (rl *RedisRateLimiter) Allow(key string) bool {
	const (
		window = time.Minute
		keyTTL = 2 * time.Minute // must outlive the window; lets idle keys expire
	)

	// toScore renders t as Unix seconds, keeping microsecond precision, which
	// still fits within a float64 mantissa.
	toScore := func(t time.Time) string {
		return strconv.FormatFloat(float64(t.UnixMicro())/1e6, 'f', -1, 64)
	}

	now := time.Now()

	// The random suffix keeps members unique when requests share a timestamp;
	// a duplicate member would make ZADD update instead of add (undercount).
	member := formatTime(now) + "-" + strconv.FormatUint(rand.Uint64(), 36)

	allowed, err := redisSlidingWindowScript.Run(rl.ctx, rl.client,
		[]string{"ratelimit:" + key},
		toScore(now),
		toScore(now.Add(-window)),
		rl.config.RequestsPerMinute,
		member,
		int(keyTTL/time.Second),
	).Int()
	if err != nil {
		// Fail open: a Redis outage should not block all traffic.
		return true
	}
	return allowed == 1
}
|
|
|
|
// GetRemaining returns how many more requests key may make in the current
// one-minute sliding window, never less than zero.
//
// It counts only entries scored after the window start using ZCOUNT, so it
// does not modify Redis. Scores are Unix seconds, the unit Allow stores, and
// the bound is numeric because Redis rejects non-numeric score bounds. On a
// Redis error it reports the full RequestsPerMinute allowance.
func (rl *RedisRateLimiter) GetRemaining(key string) int {
	windowStart := time.Now().Add(-time.Minute)

	// "(" makes the lower bound exclusive: an entry scored exactly at the
	// window start is outside the window.
	minScore := "(" + strconv.FormatFloat(float64(windowStart.UnixMicro())/1e6, 'f', -1, 64)

	count, err := rl.client.ZCount(rl.ctx, "ratelimit:"+key, minScore, "+inf").Result()
	if err != nil {
		return rl.config.RequestsPerMinute
	}

	remaining := rl.config.RequestsPerMinute - int(count)
	if remaining < 0 {
		return 0
	}
	return remaining
}
|
|
|
|
// Reset resets the rate limit for a key
|
|
func (rl *RedisRateLimiter) Reset(key string) error {
|
|
zsetKey := "ratelimit:" + key
|
|
return rl.client.Del(rl.ctx, zsetKey).Err()
|
|
}
|
|
|
|
// Close closes the Redis connection
|
|
func (rl *RedisRateLimiter) Close() error {
|
|
return rl.client.Close()
|
|
}
|
|
|
|
// formatTime formats time for Redis sorted set
|
|
func formatTime(t time.Time) string {
|
|
return t.Format(time.RFC3339Nano)
|
|
}
|
|
|
|
// NewRateLimiter picks a rate limiter implementation from the environment.
// When REDIS_URL is set it returns a Redis-backed limiter, or the error from
// connecting to it. Otherwise it returns an in-memory limiter.
func NewRateLimiter(config RateLimitConfig) (RateLimiter, error) {
	redisURL := os.Getenv("REDIS_URL")
	if redisURL == "" {
		return NewInMemoryRateLimiter(config), nil
	}

	rl, err := NewRedisRateLimiter(redisURL, config)
	if err != nil {
		// Return a literal nil: passing through the typed-nil
		// *RedisRateLimiter would yield a non-nil RateLimiter interface value.
		return nil, err
	}
	return rl, nil
}
|
|
|