Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ require (
github.com/go-redis/redis_rate/v10 v10.0.1
github.com/go-redis/redismock/v9 v9.2.0
github.com/go-sql-driver/mysql v1.9.2
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/gomodule/redigo v1.9.3
github.com/google/go-github v17.0.0+incompatible
github.com/google/uuid v1.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
Expand Down
25 changes: 11 additions & 14 deletions manager/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,20 +503,17 @@ func (cfg *Config) Validate() error {
}
}

if cfg.Auth.JWT.Realm == "" {
return errors.New("jwt requires parameter realm")
}

if cfg.Auth.JWT.Key == "" {
return errors.New("jwt requires parameter key")
}

if cfg.Auth.JWT.Timeout == 0 {
return errors.New("jwt requires parameter timeout")
}

if cfg.Auth.JWT.MaxRefresh == 0 {
return errors.New("jwt requires parameter maxRefresh")
// Auth validation: only validate JWT fields if a key is configured (JWT is optional for backward compatibility)
if cfg.Auth.JWT.Key != "" {
if cfg.Auth.JWT.Realm == "" {
return errors.New("jwt requires parameter realm when key is set")
}
if cfg.Auth.JWT.Timeout == 0 {
return errors.New("jwt requires parameter timeout when key is set")
}
if cfg.Auth.JWT.MaxRefresh == 0 {
return errors.New("jwt requires parameter maxRefresh when key is set")
}
}

if cfg.Database.Type == "" {
Expand Down
3 changes: 3 additions & 0 deletions manager/rpcserver/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"d7y.io/dragonfly/v2/manager/database"
"d7y.io/dragonfly/v2/manager/models"
"d7y.io/dragonfly/v2/manager/searcher"
"d7y.io/dragonfly/v2/pkg/rpc/auth"
managerserver "d7y.io/dragonfly/v2/pkg/rpc/manager/server"
)

Expand Down Expand Up @@ -59,6 +60,8 @@ func New(
searcher: searcher,
}

// Provide JWT key from config to manager server via auth package.
auth.SetServerKey("manager", cfg.Auth.JWT.Key)
return s, managerserver.New(
newManagerServerV1(s.config, database, s.cache, s.searcher),
newManagerServerV2(s.config, database, s.cache, s.searcher),
Expand Down
28 changes: 28 additions & 0 deletions pkg/rpc/auth/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package auth

import (
"context"

"google.golang.org/grpc/credentials"
)

// PerRPCCreds attaches a Bearer JWT to outgoing gRPC calls.
type PerRPCCreds struct {
token string
// If needed later, add refresh hooks.
}

// NewPerRPCCreds constructs credentials with a given token value.
func NewPerRPCCreds(token string) credentials.PerRPCCredentials {
return &PerRPCCreds{token: token}
}

func (c *PerRPCCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + c.token,
}, nil
}

// RequireTransportSecurity returns false for backward compatibility with existing deployments.
// In production, configure TLS separately via server.TLS config to secure JWT transmission.
func (c *PerRPCCreds) RequireTransportSecurity() bool { return false }
77 changes: 77 additions & 0 deletions pkg/rpc/auth/interceptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package auth

import (
"context"
"strings"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// UnaryServerJWTInterceptor returns a unary server interceptor that validates JWT in metadata.
func UnaryServerJWTInterceptor(key string, expectedAudience string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := validateFromMetadata(ctx, key, expectedAudience, info.FullMethod); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

// StreamServerJWTInterceptor returns a stream server interceptor that validates JWT in metadata.
func StreamServerJWTInterceptor(key string, expectedAudience string) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := validateFromMetadata(ss.Context(), key, expectedAudience, info.FullMethod); err != nil {
return err
}
return handler(srv, ss)
}
}
Comment on lines +13 to +31

Copilot AI Dec 6, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interceptors will enforce JWT authentication on ALL gRPC endpoints, including health checks and reflection services. This will break health check probes (e.g., Kubernetes liveness/readiness probes) and gRPC reflection tools that typically don't provide authentication.

Consider allowing certain methods to bypass authentication:

func validateFromMetadata(ctx context.Context, key string, expectedAudience string, info interface{}) error {
    // Skip auth for health checks and reflection
    if method := extractMethodName(info); isPublicMethod(method) {
        return nil
    }
    // ... existing validation logic
}

func isPublicMethod(method string) bool {
    publicMethods := []string{
        "/grpc.health.v1.Health/",
        "/grpc.reflection.v1alpha.ServerReflection/",
        "/grpc.reflection.v1.ServerReflection/",
    }
    for _, prefix := range publicMethods {
        if strings.HasPrefix(method, prefix) {
            return true
        }
    }
    return false
}

Alternatively, use per-service interceptors instead of global ones to exclude health/reflection services.

Copilot uses AI. Check for mistakes.

func validateFromMetadata(ctx context.Context, key string, expectedAudience string, method string) error {
// Skip auth for health checks and gRPC reflection to allow probes and debugging tools
if isPublicMethod(method) {
return nil
}

// If no key is configured, JWT auth is disabled (backward compatible)
if key == "" {
return nil
}

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing metadata")
}
vals := md.Get("authorization")
if len(vals) == 0 {
return status.Error(codes.Unauthenticated, "missing authorization")
}
parts := strings.Fields(vals[0])
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return status.Error(codes.Unauthenticated, "invalid authorization header")
}
token := parts[1]
if _, err := ValidateHS256(key, token, expectedAudience); err != nil {
return status.Error(codes.Unauthenticated, err.Error())
}
return nil
}

// isPublicMethod determines if a gRPC method should bypass JWT authentication.
// Health checks and reflection services are exempt to support infrastructure probes and debugging.
func isPublicMethod(method string) bool {
publicPrefixes := []string{
"/grpc.health.v1.Health/",
"/grpc.reflection.v1alpha.ServerReflection/",
"/grpc.reflection.v1.ServerReflection/",
}
for _, prefix := range publicPrefixes {
if strings.HasPrefix(method, prefix) {
return true
}
}
return false
}
122 changes: 122 additions & 0 deletions pkg/rpc/auth/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package auth

import (
"errors"
"sync"
"time"

jwtlib "github.com/golang-jwt/jwt/v5"
)

// Claims is a minimal JWT claims set used for inter-component gRPC auth.
type Claims struct {
Issuer string `json:"iss"`
Audience string `json:"aud"`
IssuedAt time.Time `json:"iat"`
Expires time.Time `json:"exp"`
}

// Global registry for per-component server keys with thread-safe access.
// This allows components to register their JWT keys at startup for use in interceptors.
var (
serverKeysMu sync.RWMutex
serverKeys = map[string]string{}
)

// SetServerKey sets the shared signing key for a component's server (e.g., "manager", "scheduler").
func SetServerKey(component, key string) {
serverKeysMu.Lock()
defer serverKeysMu.Unlock()
serverKeys[component] = key
}

// GetServerKey retrieves the key for a component server.
func GetServerKey(component string) string {
serverKeysMu.RLock()
defer serverKeysMu.RUnlock()
return serverKeys[component]
}

// SignHS256 signs the provided claims with the given shared secret key using HS256.
func SignHS256(key string, c Claims) (string, error) {
if key == "" {
return "", errors.New("jwt: empty signing key")
}
claims := jwtlib.MapClaims{
"iss": c.Issuer,
"aud": c.Audience,
"iat": c.IssuedAt.Unix(),
"exp": c.Expires.Unix(),
}
token := jwtlib.NewWithClaims(jwtlib.SigningMethodHS256, claims)
return token.SignedString([]byte(key))
}

// ValidateHS256 validates token signature and basic claims. Returns parsed claims.
func ValidateHS256(key string, tokenStr string, expectedAudience string) (Claims, error) {
var out Claims
if key == "" {
return out, errors.New("jwt: empty validation key")
}
parser := jwtlib.NewParser(jwtlib.WithValidMethods([]string{jwtlib.SigningMethodHS256.Alg()}))
token, err := parser.Parse(tokenStr, func(t *jwtlib.Token) (any, error) {
return []byte(key), nil
})
if err != nil || !token.Valid {
return out, errors.New("jwt: invalid token")
}
claims, ok := token.Claims.(jwtlib.MapClaims)
if !ok {
return out, errors.New("jwt: invalid claims type")
}
// Audience check
if audAny, ok := claims["aud"]; ok {
if audStr, ok := audAny.(string); ok {
if expectedAudience != "" && audStr != expectedAudience {
return out, errors.New("jwt: audience mismatch")
}
out.Audience = audStr
}
}
// Issuer
if issAny, ok := claims["iss"]; ok {
if issStr, ok := issAny.(string); ok {
out.Issuer = issStr
}
}
// Time checks
now := time.Now()
if expAny, ok := claims["exp"]; ok {
switch v := expAny.(type) {
case float64:
out.Expires = time.Unix(int64(v), 0)
case int64:
out.Expires = time.Unix(v, 0)
case uint64:
out.Expires = time.Unix(int64(v), 0)
}
if now.After(out.Expires) {
return out, errors.New("jwt: token expired")
}
}
if iatAny, ok := claims["iat"]; ok {
switch v := iatAny.(type) {
case float64:
out.IssuedAt = time.Unix(int64(v), 0)
case int64:
out.IssuedAt = time.Unix(v, 0)
}
}
return out, nil
}

// DurationClaims constructs Claims with given ttl.
func DurationClaims(issuer, audience string, ttl time.Duration) Claims {
now := time.Now()
return Claims{
Issuer: issuer,
Audience: audience,
IssuedAt: now,
Expires: now.Add(ttl),
}
}
4 changes: 4 additions & 0 deletions pkg/rpc/manager/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (

logger "d7y.io/dragonfly/v2/internal/dflog"
"d7y.io/dragonfly/v2/pkg/rpc"
"d7y.io/dragonfly/v2/pkg/rpc/auth"
)

const (
Expand All @@ -55,6 +56,7 @@ const (
// New returns grpc server instance and register service on grpc server.
func New(managerServerV1 managerv1.ManagerServer, managerServerV2 managerv2.ManagerServer, requestRateLimit float64, opts ...grpc.ServerOption) *grpc.Server {
limiter := rpc.NewRateLimiterInterceptor(requestRateLimit, int64(requestRateLimit))
jwtKey := auth.GetServerKey("manager")

grpcServer := grpc.NewServer(append([]grpc.ServerOption{
grpc.MaxRecvMsgSize(math.MaxInt32),
Expand All @@ -69,13 +71,15 @@ func New(managerServerV1 managerv1.ManagerServer, managerServerV2 managerv2.Mana
MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace,
}),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
auth.UnaryServerJWTInterceptor(jwtKey, "manager"),
grpc_ratelimit.UnaryServerInterceptor(limiter),
grpc_prometheus.UnaryServerInterceptor,
grpc_zap.UnaryServerInterceptor(logger.GrpcLogger.Desugar()),
grpc_validator.UnaryServerInterceptor(),
grpc_recovery.UnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
auth.StreamServerJWTInterceptor(jwtKey, "manager"),
grpc_ratelimit.StreamServerInterceptor(limiter),
grpc_prometheus.StreamServerInterceptor,
grpc_zap.StreamServerInterceptor(logger.GrpcLogger.Desugar()),
Expand Down
4 changes: 4 additions & 0 deletions pkg/rpc/scheduler/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (

logger "d7y.io/dragonfly/v2/internal/dflog"
"d7y.io/dragonfly/v2/pkg/rpc"
"d7y.io/dragonfly/v2/pkg/rpc/auth"
)

const (
Expand All @@ -55,6 +56,7 @@ const (
// New returns a grpc server instance and register service on grpc server.
func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedulerv2.SchedulerServer, requestRateLimit float64, opts ...grpc.ServerOption) *grpc.Server {
limiter := rpc.NewRateLimiterInterceptor(requestRateLimit, int64(requestRateLimit))
jwtKey := auth.GetServerKey("scheduler")

grpcServer := grpc.NewServer(append([]grpc.ServerOption{
grpc.MaxRecvMsgSize(math.MaxInt32),
Expand All @@ -69,6 +71,7 @@ func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedu
MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace,
}),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
auth.UnaryServerJWTInterceptor(jwtKey, "scheduler"),
grpc_ratelimit.UnaryServerInterceptor(limiter),
rpc.ConvertErrorUnaryServerInterceptor,
grpc_prometheus.UnaryServerInterceptor,
Expand All @@ -77,6 +80,7 @@ func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedu
grpc_recovery.UnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
auth.StreamServerJWTInterceptor(jwtKey, "scheduler"),
grpc_ratelimit.StreamServerInterceptor(limiter),
rpc.ConvertErrorStreamServerInterceptor,
grpc_prometheus.StreamServerInterceptor,
Expand Down
Loading
Loading