Harden explorer AI runtime and API ownership
This commit is contained in:
292
backend/api/rest/ai_runtime.go
Normal file
292
backend/api/rest/ai_runtime.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AIRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string][]time.Time
|
||||
}
|
||||
|
||||
func NewAIRateLimiter() *AIRateLimiter {
|
||||
return &AIRateLimiter{
|
||||
entries: make(map[string][]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AIRateLimiter) Allow(key string, limit int, window time.Duration) (bool, time.Duration) {
|
||||
if limit <= 0 {
|
||||
return true, 0
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-window)
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
timestamps := l.entries[key]
|
||||
kept := timestamps[:0]
|
||||
for _, ts := range timestamps {
|
||||
if ts.After(cutoff) {
|
||||
kept = append(kept, ts)
|
||||
}
|
||||
}
|
||||
|
||||
if len(kept) >= limit {
|
||||
retryAfter := kept[0].Add(window).Sub(now)
|
||||
l.entries[key] = kept
|
||||
if retryAfter < 0 {
|
||||
retryAfter = 0
|
||||
}
|
||||
return false, retryAfter
|
||||
}
|
||||
|
||||
kept = append(kept, now)
|
||||
l.entries[key] = kept
|
||||
return true, 0
|
||||
}
|
||||
|
||||
type AIMetrics struct {
|
||||
mu sync.Mutex
|
||||
ContextRequests int64 `json:"contextRequests"`
|
||||
ChatRequests int64 `json:"chatRequests"`
|
||||
RateLimited int64 `json:"rateLimited"`
|
||||
UpstreamFailures int64 `json:"upstreamFailures"`
|
||||
LastRequestAt string `json:"lastRequestAt,omitempty"`
|
||||
LastErrorCode string `json:"lastErrorCode,omitempty"`
|
||||
StatusCounts map[string]int64 `json:"statusCounts"`
|
||||
ErrorCounts map[string]int64 `json:"errorCounts"`
|
||||
LastDurationsMs map[string]float64 `json:"lastDurationsMs"`
|
||||
LastRequests []map[string]string `json:"lastRequests"`
|
||||
}
|
||||
|
||||
func NewAIMetrics() *AIMetrics {
|
||||
return &AIMetrics{
|
||||
StatusCounts: make(map[string]int64),
|
||||
ErrorCounts: make(map[string]int64),
|
||||
LastDurationsMs: make(map[string]float64),
|
||||
LastRequests: []map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AIMetrics) Record(endpoint string, statusCode int, duration time.Duration, errorCode, clientIP string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if endpoint == "context" {
|
||||
m.ContextRequests++
|
||||
}
|
||||
if endpoint == "chat" {
|
||||
m.ChatRequests++
|
||||
}
|
||||
if errorCode == "rate_limited" {
|
||||
m.RateLimited++
|
||||
}
|
||||
if strings.HasPrefix(errorCode, "upstream_") {
|
||||
m.UpstreamFailures++
|
||||
}
|
||||
|
||||
statusKey := endpoint + ":" + http.StatusText(statusCode)
|
||||
m.StatusCounts[statusKey]++
|
||||
if errorCode != "" {
|
||||
m.ErrorCounts[errorCode]++
|
||||
m.LastErrorCode = errorCode
|
||||
}
|
||||
m.LastRequestAt = time.Now().UTC().Format(time.RFC3339)
|
||||
m.LastDurationsMs[endpoint] = float64(duration.Milliseconds())
|
||||
m.LastRequests = append([]map[string]string{{
|
||||
"endpoint": endpoint,
|
||||
"status": http.StatusText(statusCode),
|
||||
"statusCode": http.StatusText(statusCode),
|
||||
"clientIp": clientIP,
|
||||
"at": m.LastRequestAt,
|
||||
"errorCode": errorCode,
|
||||
}}, m.LastRequests...)
|
||||
if len(m.LastRequests) > 12 {
|
||||
m.LastRequests = m.LastRequests[:12]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AIMetrics) Snapshot() map[string]any {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
statusCounts := make(map[string]int64, len(m.StatusCounts))
|
||||
for key, value := range m.StatusCounts {
|
||||
statusCounts[key] = value
|
||||
}
|
||||
|
||||
errorCounts := make(map[string]int64, len(m.ErrorCounts))
|
||||
for key, value := range m.ErrorCounts {
|
||||
errorCounts[key] = value
|
||||
}
|
||||
|
||||
lastDurations := make(map[string]float64, len(m.LastDurationsMs))
|
||||
for key, value := range m.LastDurationsMs {
|
||||
lastDurations[key] = value
|
||||
}
|
||||
|
||||
lastRequests := make([]map[string]string, len(m.LastRequests))
|
||||
for i := range m.LastRequests {
|
||||
copyMap := make(map[string]string, len(m.LastRequests[i]))
|
||||
for key, value := range m.LastRequests[i] {
|
||||
copyMap[key] = value
|
||||
}
|
||||
lastRequests[i] = copyMap
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"contextRequests": m.ContextRequests,
|
||||
"chatRequests": m.ChatRequests,
|
||||
"rateLimited": m.RateLimited,
|
||||
"upstreamFailures": m.UpstreamFailures,
|
||||
"lastRequestAt": m.LastRequestAt,
|
||||
"lastErrorCode": m.LastErrorCode,
|
||||
"statusCounts": statusCounts,
|
||||
"errorCounts": errorCounts,
|
||||
"lastDurationsMs": lastDurations,
|
||||
"lastRequests": lastRequests,
|
||||
}
|
||||
}
|
||||
|
||||
func clientIPAddress(r *http.Request) string {
|
||||
for _, header := range []string{"X-Forwarded-For", "X-Real-IP"} {
|
||||
if raw := strings.TrimSpace(r.Header.Get(header)); raw != "" {
|
||||
if header == "X-Forwarded-For" {
|
||||
parts := strings.Split(raw, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
return strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
|
||||
func explorerAIContextRateLimit() (int, time.Duration) {
|
||||
return 60, time.Minute
|
||||
}
|
||||
|
||||
func explorerAIChatRateLimit() (int, time.Duration) {
|
||||
return 12, time.Minute
|
||||
}
|
||||
|
||||
func (s *Server) allowAIRequest(r *http.Request, endpoint string) (bool, time.Duration) {
|
||||
limit := 0
|
||||
window := time.Minute
|
||||
switch endpoint {
|
||||
case "context":
|
||||
limit, window = explorerAIContextRateLimit()
|
||||
case "chat":
|
||||
limit, window = explorerAIChatRateLimit()
|
||||
}
|
||||
|
||||
clientIP := clientIPAddress(r)
|
||||
return s.aiLimiter.Allow(endpoint+":"+clientIP, limit, window)
|
||||
}
|
||||
|
||||
func (s *Server) logAIRequest(endpoint string, statusCode int, duration time.Duration, clientIP, model, errorCode string) {
|
||||
statusText := http.StatusText(statusCode)
|
||||
if statusText == "" {
|
||||
statusText = "unknown"
|
||||
}
|
||||
log.Printf("AI endpoint=%s status=%d duration_ms=%d client_ip=%s model=%s error_code=%s", endpoint, statusCode, duration.Milliseconds(), clientIP, model, errorCode)
|
||||
}
|
||||
|
||||
func (s *Server) handleAIMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeMethodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
|
||||
contextLimit, contextWindow := explorerAIContextRateLimit()
|
||||
chatLimit, chatWindow := explorerAIChatRateLimit()
|
||||
|
||||
response := map[string]any{
|
||||
"generatedAt": time.Now().UTC().Format(time.RFC3339),
|
||||
"rateLimits": map[string]any{
|
||||
"context": map[string]any{
|
||||
"requests": contextLimit,
|
||||
"window": contextWindow.String(),
|
||||
},
|
||||
"chat": map[string]any{
|
||||
"requests": chatLimit,
|
||||
"window": chatWindow.String(),
|
||||
},
|
||||
},
|
||||
"metrics": s.aiMetrics.Snapshot(),
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, statusCode int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
type AIUpstreamError struct {
|
||||
StatusCode int
|
||||
Code string
|
||||
Message string
|
||||
Details string
|
||||
}
|
||||
|
||||
func (e *AIUpstreamError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Details != "" {
|
||||
return e.Message + ": " + e.Details
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func mapAIUpstreamError(err error) (int, string, string, string) {
|
||||
if err == nil {
|
||||
return http.StatusOK, "", "", ""
|
||||
}
|
||||
upstreamErr, ok := err.(*AIUpstreamError)
|
||||
if !ok {
|
||||
return http.StatusBadGateway, "bad_gateway", "explorer ai request failed", err.Error()
|
||||
}
|
||||
|
||||
switch upstreamErr.Code {
|
||||
case "upstream_quota_exhausted":
|
||||
return http.StatusServiceUnavailable, upstreamErr.Code, "explorer ai upstream quota exhausted", upstreamErr.Details
|
||||
case "upstream_auth_failed":
|
||||
return http.StatusBadGateway, upstreamErr.Code, "explorer ai upstream authentication failed", upstreamErr.Details
|
||||
case "upstream_timeout":
|
||||
return http.StatusGatewayTimeout, upstreamErr.Code, "explorer ai upstream timed out", upstreamErr.Details
|
||||
case "upstream_bad_response":
|
||||
return http.StatusBadGateway, upstreamErr.Code, "explorer ai upstream returned an invalid response", upstreamErr.Details
|
||||
default:
|
||||
return http.StatusBadGateway, upstreamErr.Code, upstreamErr.Message, upstreamErr.Details
|
||||
}
|
||||
}
|
||||
|
||||
func writeErrorDetailed(w http.ResponseWriter, statusCode int, code, message, details string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
},
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user