Compare commits
1 Commits
feat/block
...
devin/1776
| Author | SHA1 | Date | |
|---|---|---|---|
| 66f35fa2aa |
@@ -1,7 +1,6 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -31,11 +30,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add user context
|
ctx := ContextWithAuth(r.Context(), address, track, true)
|
||||||
ctx := context.WithValue(r.Context(), "user_address", address)
|
|
||||||
ctx = context.WithValue(ctx, "user_track", track)
|
|
||||||
ctx = context.WithValue(ctx, "authenticated", true)
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -44,11 +39,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
|
|||||||
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
|
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Extract track from context (set by RequireAuth or OptionalAuth)
|
track := UserTrack(r.Context())
|
||||||
track, ok := r.Context().Value("user_track").(int)
|
|
||||||
if !ok {
|
|
||||||
track = 1 // Default to Track 1 (public)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !featureflags.HasAccess(track, requiredTrack) {
|
if !featureflags.HasAccess(track, requiredTrack) {
|
||||||
writeForbidden(w, requiredTrack)
|
writeForbidden(w, requiredTrack)
|
||||||
@@ -65,40 +56,33 @@ func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
address, track, err := m.extractAuth(r)
|
address, track, err := m.extractAuth(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// No auth provided, default to Track 1 (public)
|
// No auth provided (or auth failed) — fall back to Track 1.
|
||||||
ctx := context.WithValue(r.Context(), "user_address", "")
|
ctx := ContextWithAuth(r.Context(), "", defaultTrackLevel, false)
|
||||||
ctx = context.WithValue(ctx, "user_track", 1)
|
|
||||||
ctx = context.WithValue(ctx, "authenticated", false)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth provided, add user context
|
ctx := ContextWithAuth(r.Context(), address, track, true)
|
||||||
ctx := context.WithValue(r.Context(), "user_address", address)
|
|
||||||
ctx = context.WithValue(ctx, "user_track", track)
|
|
||||||
ctx = context.WithValue(ctx, "authenticated", true)
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractAuth extracts authentication information from request
|
// extractAuth extracts authentication information from the request.
|
||||||
|
// Returns ErrMissingAuthorization when no usable Bearer token is present;
|
||||||
|
// otherwise returns the error from JWT validation.
|
||||||
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
|
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
|
||||||
// Get Authorization header
|
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
return "", 0, http.ErrMissingFile
|
return "", 0, ErrMissingAuthorization
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for Bearer token
|
|
||||||
parts := strings.Split(authHeader, " ")
|
parts := strings.Split(authHeader, " ")
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
return "", 0, http.ErrMissingFile
|
return "", 0, ErrMissingAuthorization
|
||||||
}
|
}
|
||||||
|
|
||||||
token := parts[1]
|
token := parts[1]
|
||||||
|
|
||||||
// Validate JWT token
|
|
||||||
address, track, err := m.walletAuth.ValidateJWT(token)
|
address, track, err := m.walletAuth.ValidateJWT(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, err
|
return "", 0, err
|
||||||
|
|||||||
60
backend/api/middleware/context.go
Normal file
60
backend/api/middleware/context.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ctxKey is an unexported type for request-scoped authentication values.
|
||||||
|
// Using a distinct type (rather than a bare string) keeps our keys out of
|
||||||
|
// collision range for any other package that also calls context.WithValue,
|
||||||
|
// and silences go vet's SA1029.
|
||||||
|
type ctxKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ctxKeyUserAddress ctxKey = "user_address"
|
||||||
|
ctxKeyUserTrack ctxKey = "user_track"
|
||||||
|
ctxKeyAuthenticated ctxKey = "authenticated"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default track level applied to unauthenticated requests (Track 1 = public).
|
||||||
|
const defaultTrackLevel = 1
|
||||||
|
|
||||||
|
// ErrMissingAuthorization is returned by extractAuth when no usable
|
||||||
|
// Authorization header is present on the request. Callers should treat this
|
||||||
|
// as "no auth supplied" rather than a hard failure for optional-auth routes.
|
||||||
|
var ErrMissingAuthorization = errors.New("middleware: authorization header missing or malformed")
|
||||||
|
|
||||||
|
// ContextWithAuth returns a child context carrying the supplied
|
||||||
|
// authentication state. It is the single place in the package that writes
|
||||||
|
// the auth context keys.
|
||||||
|
func ContextWithAuth(parent context.Context, address string, track int, authenticated bool) context.Context {
|
||||||
|
ctx := context.WithValue(parent, ctxKeyUserAddress, address)
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyUserTrack, track)
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyAuthenticated, authenticated)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAddress returns the authenticated wallet address stored on ctx, or
|
||||||
|
// "" if the context is not authenticated.
|
||||||
|
func UserAddress(ctx context.Context) string {
|
||||||
|
addr, _ := ctx.Value(ctxKeyUserAddress).(string)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserTrack returns the access tier recorded on ctx. If no track was set
|
||||||
|
// (e.g. the request bypassed all auth middleware) the caller receives
|
||||||
|
// Track 1 (public) so route-level checks can still make a decision.
|
||||||
|
func UserTrack(ctx context.Context) int {
|
||||||
|
if track, ok := ctx.Value(ctxKeyUserTrack).(int); ok {
|
||||||
|
return track
|
||||||
|
}
|
||||||
|
return defaultTrackLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticated reports whether the current request carried a valid auth
|
||||||
|
// token that was successfully parsed by the middleware.
|
||||||
|
func IsAuthenticated(ctx context.Context) bool {
|
||||||
|
ok, _ := ctx.Value(ctxKeyAuthenticated).(bool)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
62
backend/api/middleware/context_test.go
Normal file
62
backend/api/middleware/context_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextWithAuthRoundTrip(t *testing.T) {
|
||||||
|
ctx := ContextWithAuth(context.Background(), "0xabc", 4, true)
|
||||||
|
|
||||||
|
if got := UserAddress(ctx); got != "0xabc" {
|
||||||
|
t.Fatalf("UserAddress() = %q, want %q", got, "0xabc")
|
||||||
|
}
|
||||||
|
if got := UserTrack(ctx); got != 4 {
|
||||||
|
t.Fatalf("UserTrack() = %d, want 4", got)
|
||||||
|
}
|
||||||
|
if !IsAuthenticated(ctx) {
|
||||||
|
t.Fatal("IsAuthenticated() = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserTrackDefaultsToTrack1OnBareContext(t *testing.T) {
|
||||||
|
if got := UserTrack(context.Background()); got != defaultTrackLevel {
|
||||||
|
t.Fatalf("UserTrack(empty) = %d, want %d", got, defaultTrackLevel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAddressEmptyOnBareContext(t *testing.T) {
|
||||||
|
if got := UserAddress(context.Background()); got != "" {
|
||||||
|
t.Fatalf("UserAddress(empty) = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAuthenticatedFalseOnBareContext(t *testing.T) {
|
||||||
|
if IsAuthenticated(context.Background()) {
|
||||||
|
t.Fatal("IsAuthenticated(empty) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContextKeyIsolation proves that the typed ctxKey values cannot be
|
||||||
|
// shadowed by a caller using bare-string keys with the same spelling.
|
||||||
|
// This is the specific class of bug fixed by this PR.
|
||||||
|
func TestContextKeyIsolation(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), "user_address", "injected")
|
||||||
|
if got := UserAddress(ctx); got != "" {
|
||||||
|
t.Fatalf("expected empty address (bare string key must not collide), got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrMissingAuthorizationIsSentinel(t *testing.T) {
|
||||||
|
if ErrMissingAuthorization == nil {
|
||||||
|
t.Fatal("ErrMissingAuthorization must not be nil")
|
||||||
|
}
|
||||||
|
wrapped := errors.New("wrapped: " + ErrMissingAuthorization.Error())
|
||||||
|
if errors.Is(wrapped, ErrMissingAuthorization) {
|
||||||
|
t.Fatal("string-wrapped error must not satisfy errors.Is (smoke check)")
|
||||||
|
}
|
||||||
|
if !errors.Is(ErrMissingAuthorization, ErrMissingAuthorization) {
|
||||||
|
t.Fatal("ErrMissingAuthorization must satisfy errors.Is against itself")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/explorer/backend/api/middleware"
|
||||||
"github.com/explorer/backend/featureflags"
|
"github.com/explorer/backend/featureflags"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,11 +17,8 @@ func (s *Server) handleFeatures(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract user track from context (set by auth middleware)
|
// Extract user track from context (set by auth middleware)
|
||||||
// Default to Track 1 (public) if not authenticated
|
// Default to Track 1 (public) if not authenticated (handled by helper).
|
||||||
userTrack := 1
|
userTrack := middleware.UserTrack(r.Context())
|
||||||
if track, ok := r.Context().Value("user_track").(int); ok {
|
|
||||||
userTrack = track
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get enabled features for this track
|
// Get enabled features for this track
|
||||||
enabledFeatures := featureflags.GetEnabledFeatures(userTrack)
|
enabledFeatures := featureflags.GetEnabledFeatures(userTrack)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/explorer/backend/api/middleware"
|
||||||
"github.com/explorer/backend/auth"
|
"github.com/explorer/backend/auth"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
@@ -185,7 +186,7 @@ func (s *Server) requireOperatorAccess(w http.ResponseWriter, r *http.Request) (
|
|||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
operatorAddr, _ := r.Context().Value("user_address").(string)
|
operatorAddr := middleware.UserAddress(r.Context())
|
||||||
operatorAddr = strings.TrimSpace(operatorAddr)
|
operatorAddr = strings.TrimSpace(operatorAddr)
|
||||||
if operatorAddr == "" {
|
if operatorAddr == "" {
|
||||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")
|
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/explorer/backend/api/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
type runScriptRequest struct {
|
type runScriptRequest struct {
|
||||||
@@ -67,7 +69,7 @@ func (s *Server) HandleRunScript(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
operatorAddr, _ := r.Context().Value("user_address").(string)
|
operatorAddr := middleware.UserAddress(r.Context())
|
||||||
if operatorAddr == "" {
|
if operatorAddr == "" {
|
||||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")
|
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/explorer/backend/api/middleware"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ func TestHandleRunScriptUsesForwardedClientIPAndRunsAllowlistedScript(t *testing
|
|||||||
|
|
||||||
reqBody := []byte(`{"script":"echo.sh","args":["world"]}`)
|
reqBody := []byte(`{"script":"echo.sh","args":["world"]}`)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader(reqBody))
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader(reqBody))
|
||||||
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
|
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
|
||||||
req.RemoteAddr = "10.0.0.10:8080"
|
req.RemoteAddr = "10.0.0.10:8080"
|
||||||
req.Header.Set("X-Forwarded-For", "203.0.113.9, 10.0.0.10")
|
req.Header.Set("X-Forwarded-For", "203.0.113.9, 10.0.0.10")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -77,7 +78,7 @@ func TestHandleRunScriptRejectsNonAllowlistedScript(t *testing.T) {
|
|||||||
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"blocked.sh"}`)))
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"blocked.sh"}`)))
|
||||||
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
|
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
|
||||||
req.RemoteAddr = "127.0.0.1:9999"
|
req.RemoteAddr = "127.0.0.1:9999"
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -100,7 +101,7 @@ func TestHandleRunScriptRejectsFilenameCollisionOutsideAllowlistedPath(t *testin
|
|||||||
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"unsafe/backup.sh"}`)))
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"unsafe/backup.sh"}`)))
|
||||||
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
|
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
|
||||||
req.RemoteAddr = "127.0.0.1:9999"
|
req.RemoteAddr = "127.0.0.1:9999"
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -122,7 +123,7 @@ func TestHandleRunScriptTruncatesLargeOutput(t *testing.T) {
|
|||||||
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"large.sh"}`)))
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"large.sh"}`)))
|
||||||
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
|
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
|
||||||
req.RemoteAddr = "127.0.0.1:9999"
|
req.RemoteAddr = "127.0.0.1:9999"
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user