package gateway

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"os"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9"
)

const (
	// redisRateLimitWindow is the length of the sliding window over which
	// RateLimitConfig.RequestsPerMinute is enforced.
	redisRateLimitWindow = time.Minute

	// redisRateLimitKeyTTL bounds how long an idle key lingers in Redis. It
	// must exceed redisRateLimitWindow so that entries still inside the
	// window are never expired early.
	redisRateLimitKeyTTL = 2 * redisRateLimitWindow

	// redisRateLimitKeyPrefix namespaces limiter keys in the Redis keyspace.
	redisRateLimitKeyPrefix = "ratelimit:"
)

// redisAllowScript performs the whole sliding-window check-and-record as one
// atomic server-side step. Issuing ZREMRANGEBYSCORE, ZCARD and ZADD as
// separate round trips lets concurrent callers observe the same count and all
// admit a request, overshooting the limit; a Lua script runs atomically.
//
//	KEYS[1] sorted-set key
//	ARGV[1] window start score; entries at or below it are pruned
//	ARGV[2] score of the current request
//	ARGV[3] request limit for the window
//	ARGV[4] unique member for the current request
//	ARGV[5] key TTL in milliseconds
//
// It returns 1 when the request is admitted and recorded, 0 when rejected.
var redisAllowScript = redis.NewScript(`
redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1])
if redis.call('ZCARD', KEYS[1]) >= tonumber(ARGV[3]) then
	return 0
end
redis.call('ZADD', KEYS[1], ARGV[2], ARGV[4])
redis.call('PEXPIRE', KEYS[1], ARGV[5])
return 1
`)

// RedisRateLimiter is a Redis-backed sliding-window rate limiter. Its state
// lives in Redis, so every gateway instance using the same server shares the
// same limits; use it in production for distributed rate limiting.
type RedisRateLimiter struct {
	client *redis.Client
	// ctx is used for every Redis call. The limiter's methods take no
	// context parameter, so it is fixed (context.Background()) at
	// construction.
	ctx    context.Context
	config RateLimitConfig
}

// NewRedisRateLimiter connects to the Redis server at redisURL (any URL
// accepted by redis.ParseURL) and verifies the connection with PING before
// returning. If the PING fails, the client is closed and the error returned.
func NewRedisRateLimiter(redisURL string, config RateLimitConfig) (*RedisRateLimiter, error) {
	opts, err := redis.ParseURL(redisURL)
	if err != nil {
		// The URL itself is deliberately not echoed: it may carry a password.
		return nil, fmt.Errorf("parsing redis URL: %w", err)
	}

	client := redis.NewClient(opts)
	if err := client.Ping(context.Background()).Err(); err != nil {
		// Don't leak the client's connection pool when the server is unusable.
		_ = client.Close()
		return nil, fmt.Errorf("pinging redis: %w", err)
	}

	return NewRedisRateLimiterFromClient(client, config), nil
}

// NewRedisRateLimiterFromClient wraps an existing Redis client. The
// connection is not checked; the caller keeps responsibility for it, although
// Close on the returned limiter will close the client.
func NewRedisRateLimiterFromClient(client *redis.Client, config RateLimitConfig) *RedisRateLimiter {
	return &RedisRateLimiter{
		client: client,
		ctx:    context.Background(),
		config: config,
	}
}

// Allow reports whether a request for key fits within RequestsPerMinute over
// the trailing one-minute window, recording it if so. Rejected requests are
// not recorded.
//
// If the Redis call fails the request is allowed (fail open), so a Redis
// outage degrades to "no rate limiting" rather than "no traffic".
func (rl *RedisRateLimiter) Allow(key string) bool {
	now := time.Now()
	windowStart := now.Add(-redisRateLimitWindow)

	admitted, err := redisAllowScript.Run(rl.ctx, rl.client,
		[]string{redisRateLimitKey(key)},
		redisWindowScore(windowStart),
		redisWindowScore(now),
		rl.config.RequestsPerMinute,
		redisRequestMember(now),
		redisRateLimitKeyTTL.Milliseconds(),
	).Int()
	if err != nil {
		return true // fail open
	}
	return admitted == 1
}

// GetRemaining returns how many more requests key may make in the current
// one-minute window, never less than zero. It only reads; stale entries are
// pruned by Allow. If the Redis call fails it reports the full
// RequestsPerMinute budget, consistent with Allow failing open.
func (rl *RedisRateLimiter) GetRemaining(key string) int {
	windowStart := time.Now().Add(-redisRateLimitWindow)

	// "(" makes the lower bound exclusive: entries scored exactly at
	// windowStart are the ones Allow prunes, so they are not counted here.
	minScore := "(" + strconv.FormatInt(redisWindowScore(windowStart), 10)
	count, err := rl.client.ZCount(rl.ctx, redisRateLimitKey(key), minScore, "+inf").Result()
	if err != nil {
		return rl.config.RequestsPerMinute
	}

	remaining := rl.config.RequestsPerMinute - int(count)
	if remaining < 0 {
		return 0
	}
	return remaining
}

// Reset clears all recorded requests for key.
func (rl *RedisRateLimiter) Reset(key string) error {
	return rl.client.Del(rl.ctx, redisRateLimitKey(key)).Err()
}

// Close closes the underlying Redis client.
func (rl *RedisRateLimiter) Close() error {
	return rl.client.Close()
}

// formatTime renders t as RFC 3339 with nanoseconds. It is used to build
// readable sorted-set members, not scores: Redis score ranges must be
// numeric (see redisWindowScore).
func formatTime(t time.Time) string {
	return t.Format(time.RFC3339Nano)
}

// redisRateLimitKey maps a caller-supplied limiter key to its Redis key.
func redisRateLimitKey(key string) string {
	return redisRateLimitKeyPrefix + key
}

// redisWindowScore converts t to the sorted-set score used for window
// bookkeeping. Microseconds give sub-second precision while staying below
// 2^53, so the value is represented exactly by Redis's float64 scores.
func redisWindowScore(t time.Time) int64 {
	return t.UnixMicro()
}

// redisRequestMember returns a sorted-set member for a request made at now.
// Members must be unique: ZADD on an existing member only updates its score,
// so two requests sharing a timestamp (possible across gateway instances or
// with a coarse clock) would be counted once. A random suffix prevents that.
func redisRequestMember(now time.Time) string {
	var suffix [8]byte
	if _, err := rand.Read(suffix[:]); err != nil {
		// crypto/rand failing is extraordinary; a timestamp-only member is
		// still usable, merely collision-prone.
		return formatTime(now)
	}
	return formatTime(now) + "-" + hex.EncodeToString(suffix[:])
}

// NewRateLimiter selects a rate limiter from the environment: a
// RedisRateLimiter when REDIS_URL is set (returning an error if Redis cannot
// be reached), otherwise the in-memory rate limiter.
func NewRateLimiter(config RateLimitConfig) (RateLimiter, error) {
	redisURL := os.Getenv("REDIS_URL")
	if redisURL == "" {
		return NewInMemoryRateLimiter(config), nil
	}

	rl, err := NewRedisRateLimiter(redisURL, config)
	if err != nil {
		// Return a literal nil: forwarding the typed-nil *RedisRateLimiter
		// would hand callers a non-nil RateLimiter interface value.
		return nil, err
	}
	return rl, nil
}