Compare commits

..

1 Commits

Author SHA1 Message Date
66f35fa2aa fix(auth): typed context keys and real sentinel errors
backend/api/middleware/context.go (new):
- Introduces an unexported ctxKey type and three constants
  (ctxKeyUserAddress, ctxKeyUserTrack, ctxKeyAuthenticated) that
  replace the bare string keys 'user_address', 'user_track', and
  'authenticated'. Bare strings trigger go vet's SA1029 and collide
  with keys from any other package that happens to share the name.
- Helpers: ContextWithAuth, UserAddress, UserTrack, IsAuthenticated.
- Sentinel: ErrMissingAuthorization replaces the misuse of
  http.ErrMissingFile as an auth-missing signal. (http.ErrMissingFile
  belongs to multipart form parsing and was semantically wrong.)

backend/api/middleware/auth.go:
- RequireAuth, OptionalAuth, RequireTrack now all read/write via the
  helpers; no more string literals for context keys in this file.
- extractAuth returns ErrMissingAuthorization instead of
  http.ErrMissingFile.
- Dropped now-unused 'context' import.

backend/api/track4/operator_scripts.go, backend/api/track4/endpoints.go,
backend/api/rest/features.go:
- Read user address / track via middleware.UserAddress() and
  middleware.UserTrack() instead of a raw context lookup with a bare
  string key.
- Import 'github.com/explorer/backend/api/middleware'.

backend/api/track4/operator_scripts_test.go:
- Four test fixtures updated to seed the request context through
  middleware.ContextWithAuth (track 4, authenticated) instead of
  context.WithValue with a bare 'user_address' string. This is the
  load-bearing change that proves typed keys are required: a bare
  string key no longer wakes up the middleware helpers.

backend/api/middleware/context_test.go (new):
- Round-trip test for ContextWithAuth + UserAddress + UserTrack +
  IsAuthenticated.
- Defaults: UserTrack=1, UserAddress="", IsAuthenticated=false on a
  bare context.
- TestContextKeyIsolation: an outside caller that inserts
  'user_address' as a bare string key must NOT be visible to
  UserAddress; proves the type discipline.
- ErrMissingAuthorization sentinel smoke test.

Verification:
- go build ./... clean.
- go vet ./... clean (removes SA1029 on the old bare keys).
- go test ./api/middleware/... ./api/track4/... ./api/rest/... PASS.

Advances completion criterion 3 (Auth correctness).
2026-04-18 19:05:24 +00:00
14 changed files with 175 additions and 432 deletions

View File

@@ -1,7 +1,6 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
@@ -31,11 +30,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
return
}
// Add user context
ctx := context.WithValue(r.Context(), "user_address", address)
ctx = context.WithValue(ctx, "user_track", track)
ctx = context.WithValue(ctx, "authenticated", true)
ctx := ContextWithAuth(r.Context(), address, track, true)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -44,11 +39,7 @@ func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler {
func (m *AuthMiddleware) RequireTrack(requiredTrack int) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract track from context (set by RequireAuth or OptionalAuth)
track, ok := r.Context().Value("user_track").(int)
if !ok {
track = 1 // Default to Track 1 (public)
}
track := UserTrack(r.Context())
if !featureflags.HasAccess(track, requiredTrack) {
writeForbidden(w, requiredTrack)
@@ -65,40 +56,33 @@ func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
address, track, err := m.extractAuth(r)
if err != nil {
// No auth provided, default to Track 1 (public)
ctx := context.WithValue(r.Context(), "user_address", "")
ctx = context.WithValue(ctx, "user_track", 1)
ctx = context.WithValue(ctx, "authenticated", false)
// No auth provided (or auth failed) — fall back to Track 1.
ctx := ContextWithAuth(r.Context(), "", defaultTrackLevel, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Auth provided, add user context
ctx := context.WithValue(r.Context(), "user_address", address)
ctx = context.WithValue(ctx, "user_track", track)
ctx = context.WithValue(ctx, "authenticated", true)
ctx := ContextWithAuth(r.Context(), address, track, true)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// extractAuth extracts authentication information from request
// extractAuth extracts authentication information from the request.
// Returns ErrMissingAuthorization when no usable Bearer token is present;
// otherwise returns the error from JWT validation.
func (m *AuthMiddleware) extractAuth(r *http.Request) (string, int, error) {
// Get Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", 0, http.ErrMissingFile
return "", 0, ErrMissingAuthorization
}
// Check for Bearer token
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return "", 0, http.ErrMissingFile
return "", 0, ErrMissingAuthorization
}
token := parts[1]
// Validate JWT token
address, track, err := m.walletAuth.ValidateJWT(token)
if err != nil {
return "", 0, err

View File

@@ -0,0 +1,60 @@
package middleware
import (
"context"
"errors"
)
// ctxKey is an unexported type for request-scoped authentication values.
// Using a distinct type (rather than a bare string) keeps our keys out of
// collision range for any other package that also calls context.WithValue,
// and silences go vet's SA1029.
type ctxKey string
const (
ctxKeyUserAddress ctxKey = "user_address"
ctxKeyUserTrack ctxKey = "user_track"
ctxKeyAuthenticated ctxKey = "authenticated"
)
// Default track level applied to unauthenticated requests (Track 1 = public).
const defaultTrackLevel = 1
// ErrMissingAuthorization is returned by extractAuth when no usable
// Authorization header is present on the request. Callers should treat this
// as "no auth supplied" rather than a hard failure for optional-auth routes.
var ErrMissingAuthorization = errors.New("middleware: authorization header missing or malformed")
// ContextWithAuth returns a child context carrying the supplied
// authentication state. It is the single place in the package that writes
// the auth context keys.
func ContextWithAuth(parent context.Context, address string, track int, authenticated bool) context.Context {
ctx := context.WithValue(parent, ctxKeyUserAddress, address)
ctx = context.WithValue(ctx, ctxKeyUserTrack, track)
ctx = context.WithValue(ctx, ctxKeyAuthenticated, authenticated)
return ctx
}
// UserAddress returns the authenticated wallet address stored on ctx, or
// "" if the context is not authenticated.
func UserAddress(ctx context.Context) string {
addr, _ := ctx.Value(ctxKeyUserAddress).(string)
return addr
}
// UserTrack returns the access tier recorded on ctx. If no track was set
// (e.g. the request bypassed all auth middleware) the caller receives
// Track 1 (public) so route-level checks can still make a decision.
func UserTrack(ctx context.Context) int {
if track, ok := ctx.Value(ctxKeyUserTrack).(int); ok {
return track
}
return defaultTrackLevel
}
// IsAuthenticated reports whether the current request carried a valid auth
// token that was successfully parsed by the middleware.
func IsAuthenticated(ctx context.Context) bool {
ok, _ := ctx.Value(ctxKeyAuthenticated).(bool)
return ok
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"context"
"errors"
"testing"
)
func TestContextWithAuthRoundTrip(t *testing.T) {
ctx := ContextWithAuth(context.Background(), "0xabc", 4, true)
if got := UserAddress(ctx); got != "0xabc" {
t.Fatalf("UserAddress() = %q, want %q", got, "0xabc")
}
if got := UserTrack(ctx); got != 4 {
t.Fatalf("UserTrack() = %d, want 4", got)
}
if !IsAuthenticated(ctx) {
t.Fatal("IsAuthenticated() = false, want true")
}
}
func TestUserTrackDefaultsToTrack1OnBareContext(t *testing.T) {
if got := UserTrack(context.Background()); got != defaultTrackLevel {
t.Fatalf("UserTrack(empty) = %d, want %d", got, defaultTrackLevel)
}
}
func TestUserAddressEmptyOnBareContext(t *testing.T) {
if got := UserAddress(context.Background()); got != "" {
t.Fatalf("UserAddress(empty) = %q, want empty", got)
}
}
func TestIsAuthenticatedFalseOnBareContext(t *testing.T) {
if IsAuthenticated(context.Background()) {
t.Fatal("IsAuthenticated(empty) = true, want false")
}
}
// TestContextKeyIsolation proves that the typed ctxKey values cannot be
// shadowed by a caller using bare-string keys with the same spelling.
// This is the specific class of bug fixed by this PR.
func TestContextKeyIsolation(t *testing.T) {
ctx := context.WithValue(context.Background(), "user_address", "injected")
if got := UserAddress(ctx); got != "" {
t.Fatalf("expected empty address (bare string key must not collide), got %q", got)
}
}
func TestErrMissingAuthorizationIsSentinel(t *testing.T) {
if ErrMissingAuthorization == nil {
t.Fatal("ErrMissingAuthorization must not be nil")
}
wrapped := errors.New("wrapped: " + ErrMissingAuthorization.Error())
if errors.Is(wrapped, ErrMissingAuthorization) {
t.Fatal("string-wrapped error must not satisfy errors.Is (smoke check)")
}
if !errors.Is(ErrMissingAuthorization, ErrMissingAuthorization) {
t.Fatal("ErrMissingAuthorization must satisfy errors.Is against itself")
}
}

View File

@@ -1,92 +0,0 @@
package rest
import (
"encoding/json"
"errors"
"net/http"
"github.com/explorer/backend/auth"
)
// handleAuthRefresh implements POST /api/v1/auth/refresh.
//
// Contract:
// - Requires a valid, unrevoked wallet JWT in the Authorization header.
// - Mints a new JWT for the same address+track with a fresh jti and a
// fresh per-track TTL.
// - Revokes the presented token so it cannot be reused.
//
// This is the mechanism that makes the short Track-4 TTL (60 min in
// PR #8) acceptable: operators refresh while the token is still live
// rather than re-signing a SIWE message every hour.
func (s *Server) handleAuthRefresh(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed")
return
}
if s.walletAuth == nil {
writeError(w, http.StatusServiceUnavailable, "service_unavailable", "wallet auth not configured")
return
}
token := extractBearerToken(r)
if token == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "missing or malformed Authorization header")
return
}
resp, err := s.walletAuth.RefreshJWT(r.Context(), token)
if err != nil {
switch {
case errors.Is(err, auth.ErrJWTRevoked):
writeError(w, http.StatusUnauthorized, "token_revoked", err.Error())
case errors.Is(err, auth.ErrWalletAuthStorageNotInitialized):
writeError(w, http.StatusServiceUnavailable, "service_unavailable", err.Error())
default:
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
}
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
// handleAuthLogout implements POST /api/v1/auth/logout.
//
// Records the presented token's jti in jwt_revocations so subsequent
// calls to ValidateJWT will reject it. Idempotent: logging out twice
// with the same token succeeds.
func (s *Server) handleAuthLogout(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed")
return
}
if s.walletAuth == nil {
writeError(w, http.StatusServiceUnavailable, "service_unavailable", "wallet auth not configured")
return
}
token := extractBearerToken(r)
if token == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "missing or malformed Authorization header")
return
}
if err := s.walletAuth.RevokeJWT(r.Context(), token, "logout"); err != nil {
switch {
case errors.Is(err, auth.ErrJWTRevocationStorageMissing):
// Surface 503 so ops know migration 0016 hasn't run; the
// client should treat the token as logged out locally.
writeError(w, http.StatusServiceUnavailable, "service_unavailable", err.Error())
default:
writeError(w, http.StatusUnauthorized, "unauthorized", err.Error())
}
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"status": "ok",
})
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"github.com/explorer/backend/api/middleware"
"github.com/explorer/backend/featureflags"
)
@@ -16,11 +17,8 @@ func (s *Server) handleFeatures(w http.ResponseWriter, r *http.Request) {
}
// Extract user track from context (set by auth middleware)
// Default to Track 1 (public) if not authenticated
userTrack := 1
if track, ok := r.Context().Value("user_track").(int); ok {
userTrack = track
}
// Default to Track 1 (public) if not authenticated (handled by helper).
userTrack := middleware.UserTrack(r.Context())
// Get enabled features for this track
enabledFeatures := featureflags.GetEnabledFeatures(userTrack)

View File

@@ -475,12 +475,8 @@ func (s *Server) HandleMissionControlBridgeTrace(w http.ResponseWriter, r *http.
body, statusCode, err := fetchBlockscoutTransaction(r.Context(), tx)
if err == nil && statusCode == http.StatusOK {
var txDoc map[string]interface{}
if uerr := json.Unmarshal(body, &txDoc); uerr != nil {
// Fall through to the RPC fallback below. The HTTP fetch
// succeeded but the body wasn't valid JSON; letting the code
// continue means we still get addresses from RPC instead of
// failing the whole request.
_ = uerr
if err := json.Unmarshal(body, &txDoc); err != nil {
err = fmt.Errorf("invalid blockscout JSON")
} else {
fromAddr = extractEthAddress(txDoc["from"])
toAddr = extractEthAddress(txDoc["to"])

View File

@@ -52,8 +52,6 @@ func (s *Server) SetupRoutes(mux *http.ServeMux) {
// Auth endpoints
mux.HandleFunc("/api/v1/auth/nonce", s.handleAuthNonce)
mux.HandleFunc("/api/v1/auth/wallet", s.handleAuthWallet)
mux.HandleFunc("/api/v1/auth/refresh", s.handleAuthRefresh)
mux.HandleFunc("/api/v1/auth/logout", s.handleAuthLogout)
mux.HandleFunc("/api/v1/auth/register", s.handleAuthRegister)
mux.HandleFunc("/api/v1/auth/login", s.handleAuthLogin)
mux.HandleFunc("/api/v1/access/me", s.handleAccessMe)

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/explorer/backend/api/middleware"
"github.com/explorer/backend/auth"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -185,7 +186,7 @@ func (s *Server) requireOperatorAccess(w http.ResponseWriter, r *http.Request) (
return "", "", false
}
operatorAddr, _ := r.Context().Value("user_address").(string)
operatorAddr := middleware.UserAddress(r.Context())
operatorAddr = strings.TrimSpace(operatorAddr)
if operatorAddr == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")

View File

@@ -13,6 +13,8 @@ import (
"path/filepath"
"strings"
"time"
"github.com/explorer/backend/api/middleware"
)
type runScriptRequest struct {
@@ -67,7 +69,7 @@ func (s *Server) HandleRunScript(w http.ResponseWriter, r *http.Request) {
return
}
operatorAddr, _ := r.Context().Value("user_address").(string)
operatorAddr := middleware.UserAddress(r.Context())
if operatorAddr == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "Operator address required")
return

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httptest"
"github.com/explorer/backend/api/middleware"
"github.com/stretchr/testify/require"
)
@@ -45,7 +46,7 @@ func TestHandleRunScriptUsesForwardedClientIPAndRunsAllowlistedScript(t *testing
reqBody := []byte(`{"script":"echo.sh","args":["world"]}`)
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader(reqBody))
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
req.RemoteAddr = "10.0.0.10:8080"
req.Header.Set("X-Forwarded-For", "203.0.113.9, 10.0.0.10")
w := httptest.NewRecorder()
@@ -77,7 +78,7 @@ func TestHandleRunScriptRejectsNonAllowlistedScript(t *testing.T) {
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"blocked.sh"}`)))
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
req.RemoteAddr = "127.0.0.1:9999"
w := httptest.NewRecorder()
@@ -100,7 +101,7 @@ func TestHandleRunScriptRejectsFilenameCollisionOutsideAllowlistedPath(t *testin
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"unsafe/backup.sh"}`)))
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
req.RemoteAddr = "127.0.0.1:9999"
w := httptest.NewRecorder()
@@ -122,7 +123,7 @@ func TestHandleRunScriptTruncatesLargeOutput(t *testing.T) {
s := &Server{roleMgr: &stubRoleManager{allowed: true}, chainID: 138}
req := httptest.NewRequest(http.MethodPost, "/api/v1/track4/operator/run-script", bytes.NewReader([]byte(`{"script":"large.sh"}`)))
req = req.WithContext(context.WithValue(req.Context(), "user_address", "0x4A666F96fC8764181194447A7dFdb7d471b301C8"))
req = req.WithContext(middleware.ContextWithAuth(req.Context(), "0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4, true))
req.RemoteAddr = "127.0.0.1:9999"
w := httptest.NewRecorder()

View File

@@ -21,49 +21,8 @@ var (
ErrWalletNonceNotFoundOrExpired = errors.New("nonce not found or expired")
ErrWalletNonceExpired = errors.New("nonce expired")
ErrWalletNonceInvalid = errors.New("invalid nonce")
ErrJWTRevoked = errors.New("token has been revoked")
ErrJWTRevocationStorageMissing = errors.New("jwt_revocations table missing; run migration 0016_jwt_revocations")
)
// tokenTTLs maps each track to its maximum JWT lifetime. Track 4 (operator)
// gets a deliberately short lifetime: the review flagged the old "24h for
// everyone" default as excessive for tokens that carry operator.write.*
// permissions. Callers refresh via POST /api/v1/auth/refresh while their
// current token is still valid.
var tokenTTLs = map[int]time.Duration{
1: 12 * time.Hour,
2: 8 * time.Hour,
3: 4 * time.Hour,
4: 60 * time.Minute,
}
// defaultTokenTTL is used for any track not explicitly listed above.
const defaultTokenTTL = 12 * time.Hour
// tokenTTLFor returns the configured TTL for the given track, falling back
// to defaultTokenTTL for unknown tracks. Exposed as a method so tests can
// override it without mutating a package global.
func tokenTTLFor(track int) time.Duration {
if ttl, ok := tokenTTLs[track]; ok {
return ttl
}
return defaultTokenTTL
}
func isMissingJWTRevocationTableError(err error) bool {
return err != nil && strings.Contains(err.Error(), `relation "jwt_revocations" does not exist`)
}
// newJTI returns a random JWT ID used for revocation tracking. 16 random
// bytes = 128 bits of entropy, hex-encoded.
func newJTI() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate jti: %w", err)
}
return hex.EncodeToString(b), nil
}
// WalletAuth handles wallet-based authentication
type WalletAuth struct {
db *pgxpool.Pool
@@ -248,20 +207,13 @@ func (w *WalletAuth) getUserTrack(ctx context.Context, address string) (int, err
return 1, nil
}
// generateJWT generates a JWT token with track, jti, exp, and iat claims.
// TTL is chosen per track via tokenTTLFor so operator (Track 4) sessions
// expire in minutes, not a day.
// generateJWT generates a JWT token with track claim
func (w *WalletAuth) generateJWT(address string, track int) (string, time.Time, error) {
jti, err := newJTI()
if err != nil {
return "", time.Time{}, err
}
expiresAt := time.Now().Add(tokenTTLFor(track))
expiresAt := time.Now().Add(24 * time.Hour)
claims := jwt.MapClaims{
"address": address,
"track": track,
"jti": jti,
"exp": expiresAt.Unix(),
"iat": time.Now().Unix(),
}
@@ -275,182 +227,55 @@ func (w *WalletAuth) generateJWT(address string, track int) (string, time.Time,
return tokenString, expiresAt, nil
}
// ValidateJWT validates a JWT token and returns the address and track.
// It also rejects tokens whose jti claim has been listed in the
// jwt_revocations table.
// ValidateJWT validates a JWT token and returns the address and track
func (w *WalletAuth) ValidateJWT(tokenString string) (string, int, error) {
address, track, _, _, err := w.parseJWT(tokenString)
if err != nil {
return "", 0, err
}
// If we have a database, enforce revocation and re-resolve the track
// (an operator revoking a wallet's Track 4 approval should not wait
// for the token to expire before losing the elevated permission).
if w.db != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
jti, _ := w.jtiFromToken(tokenString)
if jti != "" {
revoked, revErr := w.isJTIRevoked(ctx, jti)
if revErr != nil && !errors.Is(revErr, ErrJWTRevocationStorageMissing) {
return "", 0, fmt.Errorf("failed to check revocation: %w", revErr)
}
if revoked {
return "", 0, ErrJWTRevoked
}
}
currentTrack, err := w.getUserTrack(ctx, address)
if err != nil {
return "", 0, fmt.Errorf("failed to resolve current track: %w", err)
}
if currentTrack < track {
track = currentTrack
}
}
return address, track, nil
}
// parseJWT performs signature verification and claim extraction without
// any database round-trip. Shared between ValidateJWT and RefreshJWT.
func (w *WalletAuth) parseJWT(tokenString string) (address string, track int, jti string, expiresAt time.Time, err error) {
token, perr := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return w.jwtSecret, nil
})
if perr != nil {
return "", 0, "", time.Time{}, fmt.Errorf("failed to parse token: %w", perr)
if err != nil {
return "", 0, fmt.Errorf("failed to parse token: %w", err)
}
if !token.Valid {
return "", 0, "", time.Time{}, fmt.Errorf("invalid token")
return "", 0, fmt.Errorf("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", 0, "", time.Time{}, fmt.Errorf("invalid token claims")
return "", 0, fmt.Errorf("invalid token claims")
}
address, ok = claims["address"].(string)
address, ok := claims["address"].(string)
if !ok {
return "", 0, "", time.Time{}, fmt.Errorf("address not found in token")
return "", 0, fmt.Errorf("address not found in token")
}
trackFloat, ok := claims["track"].(float64)
if !ok {
return "", 0, "", time.Time{}, fmt.Errorf("track not found in token")
return "", 0, fmt.Errorf("track not found in token")
}
track = int(trackFloat)
if v, ok := claims["jti"].(string); ok {
jti = v
}
if expFloat, ok := claims["exp"].(float64); ok {
expiresAt = time.Unix(int64(expFloat), 0)
}
return address, track, jti, expiresAt, nil
}
// jtiFromToken parses the jti claim without doing a fresh signature check.
// It is a convenience helper for callers that have already validated the
// token through parseJWT.
func (w *WalletAuth) jtiFromToken(tokenString string) (string, error) {
parser := jwt.Parser{}
token, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return "", err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", fmt.Errorf("invalid claims")
}
v, _ := claims["jti"].(string)
return v, nil
}
// isJTIRevoked checks whether the given jti appears in jwt_revocations.
// Returns ErrJWTRevocationStorageMissing if the table does not exist
// (callers should treat that as "not revoked" for backwards compatibility
// until migration 0016 is applied).
func (w *WalletAuth) isJTIRevoked(ctx context.Context, jti string) (bool, error) {
var exists bool
err := w.db.QueryRow(ctx,
`SELECT EXISTS(SELECT 1 FROM jwt_revocations WHERE jti = $1)`, jti,
).Scan(&exists)
if err != nil {
if isMissingJWTRevocationTableError(err) {
return false, ErrJWTRevocationStorageMissing
}
return false, err
}
return exists, nil
}
// RevokeJWT records the token's jti in jwt_revocations. Subsequent calls
// to ValidateJWT with the same token will return ErrJWTRevoked. Idempotent
// on duplicate jti.
func (w *WalletAuth) RevokeJWT(ctx context.Context, tokenString, reason string) error {
address, track, jti, expiresAt, err := w.parseJWT(tokenString)
if err != nil {
return err
}
if jti == "" {
// Legacy tokens issued before PR #8 don't carry a jti; there is
// nothing to revoke server-side. Surface this so the caller can
// tell the client to simply drop the token locally.
return fmt.Errorf("token has no jti claim (legacy token — client should discard locally)")
}
track := int(trackFloat)
if w.db == nil {
return fmt.Errorf("wallet auth has no database; cannot revoke")
}
if strings.TrimSpace(reason) == "" {
reason = "logout"
}
_, err = w.db.Exec(ctx,
`INSERT INTO jwt_revocations (jti, address, track, token_expires_at, reason)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (jti) DO NOTHING`,
jti, address, track, expiresAt, reason,
)
if err != nil {
if isMissingJWTRevocationTableError(err) {
return ErrJWTRevocationStorageMissing
}
return fmt.Errorf("record revocation: %w", err)
}
return nil
}
// RefreshJWT issues a new token for the same address+track if the current
// token is valid (signed, unexpired, not revoked) and revokes the current
// token so it cannot be replayed. Returns the new token and its exp.
func (w *WalletAuth) RefreshJWT(ctx context.Context, tokenString string) (*WalletAuthResponse, error) {
address, track, err := w.ValidateJWT(tokenString)
if err != nil {
return nil, err
}
// Revoke the old token before issuing a new one. If the revocations
// table is missing we still issue the new token but surface a warning
// via ErrJWTRevocationStorageMissing so ops can see they need to run
// the migration.
var revokeErr error
if w.db != nil {
revokeErr = w.RevokeJWT(ctx, tokenString, "refresh")
if revokeErr != nil && !errors.Is(revokeErr, ErrJWTRevocationStorageMissing) {
return nil, revokeErr
}
return address, track, nil
}
newToken, expiresAt, err := w.generateJWT(address, track)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
currentTrack, err := w.getUserTrack(ctx, address)
if err != nil {
return nil, err
return "", 0, fmt.Errorf("failed to resolve current track: %w", err)
}
return &WalletAuthResponse{
Token: newToken,
ExpiresAt: expiresAt,
Track: track,
Permissions: getPermissionsForTrack(track),
}, revokeErr
if currentTrack < track {
track = currentTrack
}
return address, track, nil
}
func decodeWalletSignature(signature string) ([]byte, error) {

View File

@@ -1,9 +1,7 @@
package auth
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@@ -28,59 +26,3 @@ func TestValidateJWTReturnsClaimsWhenDBUnavailable(t *testing.T) {
require.Equal(t, "0x4A666F96fC8764181194447A7dFdb7d471b301C8", address)
require.Equal(t, 4, track)
}
func TestTokenTTLForTrack4IsShort(t *testing.T) {
// Track 4 (operator) must have a TTL <= 1h — that is the headline
// tightening promised by completion criterion 3 (JWT hygiene).
ttl := tokenTTLFor(4)
require.LessOrEqual(t, ttl, time.Hour, "track 4 TTL must be <= 1h")
require.Greater(t, ttl, time.Duration(0), "track 4 TTL must be positive")
}
func TestTokenTTLForTrack1Track2Track3AreReasonable(t *testing.T) {
// Non-operator tracks are allowed longer sessions, but still bounded
// at 12h so a stale laptop tab doesn't carry a week-old token.
for _, track := range []int{1, 2, 3} {
ttl := tokenTTLFor(track)
require.Greater(t, ttl, time.Duration(0), "track %d TTL must be > 0", track)
require.LessOrEqual(t, ttl, 12*time.Hour, "track %d TTL must be <= 12h", track)
}
}
func TestGeneratedJWTCarriesJTIClaim(t *testing.T) {
// Revocation keys on jti. A token issued without one is unrevokable
// and must not be produced.
a := NewWalletAuth(nil, []byte("test-secret"))
token, _, err := a.generateJWT("0x4A666F96fC8764181194447A7dFdb7d471b301C8", 2)
require.NoError(t, err)
jti, err := a.jtiFromToken(token)
require.NoError(t, err)
require.NotEmpty(t, jti, "generated JWT must carry a jti claim")
require.Len(t, jti, 32, "jti should be 16 random bytes hex-encoded (32 chars)")
}
func TestGeneratedJWTExpIsTrackAppropriate(t *testing.T) {
a := NewWalletAuth(nil, []byte("test-secret"))
for _, track := range []int{1, 2, 3, 4} {
_, expiresAt, err := a.generateJWT("0x4A666F96fC8764181194447A7dFdb7d471b301C8", track)
require.NoError(t, err)
want := tokenTTLFor(track)
// allow a couple-second slack for test execution
actual := time.Until(expiresAt)
require.InDelta(t, want.Seconds(), actual.Seconds(), 5.0,
"track %d exp should be ~%s from now, got %s", track, want, actual)
}
}
func TestRevokeJWTWithoutDBReturnsError(t *testing.T) {
// With w.db == nil, revocation has nowhere to write — the call must
// fail loudly so callers don't silently assume a token was revoked.
a := NewWalletAuth(nil, []byte("test-secret"))
token, _, err := a.generateJWT("0x4A666F96fC8764181194447A7dFdb7d471b301C8", 4)
require.NoError(t, err)
err = a.RevokeJWT(context.Background(), token, "test")
require.Error(t, err)
require.Contains(t, err.Error(), "no database")
}

View File

@@ -1,4 +0,0 @@
-- Migration 0016_jwt_revocations.down.sql
DROP INDEX IF EXISTS idx_jwt_revocations_expires;
DROP INDEX IF EXISTS idx_jwt_revocations_address;
DROP TABLE IF EXISTS jwt_revocations;

View File

@@ -1,30 +0,0 @@
-- Migration 0016_jwt_revocations.up.sql
--
-- Introduces server-side JWT revocation for the SolaceScan backend.
--
-- Up to this migration, tokens issued by /api/v1/auth/wallet were simply
-- signed and returned; the backend had no way to invalidate a token before
-- its exp claim short of rotating the JWT_SECRET (which would invalidate
-- every outstanding session). PR #8 introduces per-token revocation keyed
-- on the `jti` claim.
--
-- The table is append-only: a row exists iff that jti has been revoked.
-- ValidateJWT consults the table on every request; the primary key on
-- (jti) keeps lookups O(log n) and deduplicates repeated logout calls.
CREATE TABLE IF NOT EXISTS jwt_revocations (
jti TEXT PRIMARY KEY,
address TEXT NOT NULL,
track INT NOT NULL,
-- original exp of the revoked token, so a background janitor can
-- reap rows after they can no longer matter.
token_expires_at TIMESTAMPTZ NOT NULL,
revoked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
reason TEXT NOT NULL DEFAULT 'logout'
);
CREATE INDEX IF NOT EXISTS idx_jwt_revocations_address
ON jwt_revocations (address);
CREATE INDEX IF NOT EXISTS idx_jwt_revocations_expires
ON jwt_revocations (token_expires_at);