diff --git a/go.mod b/go.mod index 29b545aa081..3359ec4bc95 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/go-redis/redis_rate/v10 v10.0.1 github.com/go-redis/redismock/v9 v9.2.0 github.com/go-sql-driver/mysql v1.9.2 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gomodule/redigo v1.9.3 github.com/google/go-github v17.0.0+incompatible github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index a1bd50ce6f0..56fe46ed878 100644 --- a/go.sum +++ b/go.sum @@ -298,6 +298,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= diff --git a/manager/config/config.go b/manager/config/config.go index 556b6e4f8e9..80319a4bc3c 100644 --- a/manager/config/config.go +++ b/manager/config/config.go @@ -503,20 +503,17 @@ func (cfg *Config) Validate() error { } } - if cfg.Auth.JWT.Realm == "" { - return errors.New("jwt requires parameter realm") - } - - if cfg.Auth.JWT.Key == "" { - return errors.New("jwt requires parameter key") - } - - if cfg.Auth.JWT.Timeout == 0 { - return errors.New("jwt requires parameter timeout") - } - - if cfg.Auth.JWT.MaxRefresh == 0 { - return errors.New("jwt requires parameter maxRefresh") + // Auth validation: only validate JWT fields if a key is configured (JWT is optional for backward compatibility) + if cfg.Auth.JWT.Key != "" { + if cfg.Auth.JWT.Realm == "" { + return errors.New("jwt requires parameter realm when key is set") + } + if cfg.Auth.JWT.Timeout == 0 { + return errors.New("jwt requires parameter timeout when key is set") + } + if cfg.Auth.JWT.MaxRefresh == 0 { + return errors.New("jwt requires parameter maxRefresh when key is set") + } } if cfg.Database.Type == "" { diff --git a/manager/rpcserver/rpcserver.go b/manager/rpcserver/rpcserver.go index 9c5f7a843d9..68f13ac1a0b 100644 --- a/manager/rpcserver/rpcserver.go +++ b/manager/rpcserver/rpcserver.go @@ -26,6 +26,7 @@ import ( "d7y.io/dragonfly/v2/manager/database" "d7y.io/dragonfly/v2/manager/models" "d7y.io/dragonfly/v2/manager/searcher" + "d7y.io/dragonfly/v2/pkg/rpc/auth" managerserver "d7y.io/dragonfly/v2/pkg/rpc/manager/server" ) @@ -59,6 +60,8 @@ func New( searcher: searcher, } + // Provide JWT key from config to manager server via auth package. + auth.SetServerKey("manager", cfg.Auth.JWT.Key) return s, managerserver.New( newManagerServerV1(s.config, database, s.cache, s.searcher), newManagerServerV2(s.config, database, s.cache, s.searcher), diff --git a/pkg/rpc/auth/credentials.go b/pkg/rpc/auth/credentials.go new file mode 100644 index 00000000000..5ccf7f4e184 --- /dev/null +++ b/pkg/rpc/auth/credentials.go @@ -0,0 +1,28 @@ +package auth + +import ( + "context" + + "google.golang.org/grpc/credentials" +) + +// PerRPCCreds attaches a Bearer JWT to outgoing gRPC calls. +type PerRPCCreds struct { + token string + // If needed later, add refresh hooks. +} + +// NewPerRPCCreds constructs credentials with a given token value. +func NewPerRPCCreds(token string) credentials.PerRPCCredentials { + return &PerRPCCreds{token: token} +} + +func (c *PerRPCCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return map[string]string{ + "authorization": "Bearer " + c.token, + }, nil +} + +// RequireTransportSecurity returns false for backward compatibility with existing deployments. +// In production, configure TLS separately via server.TLS config to secure JWT transmission. +func (c *PerRPCCreds) RequireTransportSecurity() bool { return false } diff --git a/pkg/rpc/auth/interceptors.go b/pkg/rpc/auth/interceptors.go new file mode 100644 index 00000000000..53336425e92 --- /dev/null +++ b/pkg/rpc/auth/interceptors.go @@ -0,0 +1,77 @@ +package auth + +import ( + "context" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// UnaryServerJWTInterceptor returns a unary server interceptor that validates JWT in metadata. +func UnaryServerJWTInterceptor(key string, expectedAudience string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + if err := validateFromMetadata(ctx, key, expectedAudience, info.FullMethod); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// StreamServerJWTInterceptor returns a stream server interceptor that validates JWT in metadata. +func StreamServerJWTInterceptor(key string, expectedAudience string) grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := validateFromMetadata(ss.Context(), key, expectedAudience, info.FullMethod); err != nil { + return err + } + return handler(srv, ss) + } +} + +func validateFromMetadata(ctx context.Context, key string, expectedAudience string, method string) error { + // Skip auth for health checks and gRPC reflection to allow probes and debugging tools + if isPublicMethod(method) { + return nil + } + + // If no key is configured, JWT auth is disabled (backward compatible) + if key == "" { + return nil + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "missing metadata") + } + vals := md.Get("authorization") + if len(vals) == 0 { + return status.Error(codes.Unauthenticated, "missing authorization") + } + parts := strings.Fields(vals[0]) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return status.Error(codes.Unauthenticated, "invalid authorization header") + } + token := parts[1] + if _, err := ValidateHS256(key, token, expectedAudience); err != nil { + return status.Error(codes.Unauthenticated, err.Error()) + } + return nil +} + +// isPublicMethod determines if a gRPC method should bypass JWT authentication. +// Health checks and reflection services are exempt to support infrastructure probes and debugging. +func isPublicMethod(method string) bool { + publicPrefixes := []string{ + "/grpc.health.v1.Health/", + "/grpc.reflection.v1alpha.ServerReflection/", + "/grpc.reflection.v1.ServerReflection/", + } + for _, prefix := range publicPrefixes { + if strings.HasPrefix(method, prefix) { + return true + } + } + return false +} diff --git a/pkg/rpc/auth/jwt.go b/pkg/rpc/auth/jwt.go new file mode 100644 index 00000000000..ecceea2275b --- /dev/null +++ b/pkg/rpc/auth/jwt.go @@ -0,0 +1,122 @@ +package auth + +import ( + "errors" + "sync" + "time" + + jwtlib "github.com/golang-jwt/jwt/v5" +) + +// Claims is a minimal JWT claims set used for inter-component gRPC auth. +type Claims struct { + Issuer string `json:"iss"` + Audience string `json:"aud"` + IssuedAt time.Time `json:"iat"` + Expires time.Time `json:"exp"` +} + +// Global registry for per-component server keys with thread-safe access. +// This allows components to register their JWT keys at startup for use in interceptors. +var ( + serverKeysMu sync.RWMutex + serverKeys = map[string]string{} +) + +// SetServerKey sets the shared signing key for a component's server (e.g., "manager", "scheduler"). +func SetServerKey(component, key string) { + serverKeysMu.Lock() + defer serverKeysMu.Unlock() + serverKeys[component] = key +} + +// GetServerKey retrieves the key for a component server. +func GetServerKey(component string) string { + serverKeysMu.RLock() + defer serverKeysMu.RUnlock() + return serverKeys[component] +} + +// SignHS256 signs the provided claims with the given shared secret key using HS256. +func SignHS256(key string, c Claims) (string, error) { + if key == "" { + return "", errors.New("jwt: empty signing key") + } + claims := jwtlib.MapClaims{ + "iss": c.Issuer, + "aud": c.Audience, + "iat": c.IssuedAt.Unix(), + "exp": c.Expires.Unix(), + } + token := jwtlib.NewWithClaims(jwtlib.SigningMethodHS256, claims) + return token.SignedString([]byte(key)) +} + +// ValidateHS256 validates token signature and basic claims. Returns parsed claims. +func ValidateHS256(key string, tokenStr string, expectedAudience string) (Claims, error) { + var out Claims + if key == "" { + return out, errors.New("jwt: empty validation key") + } + parser := jwtlib.NewParser(jwtlib.WithValidMethods([]string{jwtlib.SigningMethodHS256.Alg()})) + token, err := parser.Parse(tokenStr, func(t *jwtlib.Token) (any, error) { + return []byte(key), nil + }) + if err != nil || !token.Valid { + return out, errors.New("jwt: invalid token") + } + claims, ok := token.Claims.(jwtlib.MapClaims) + if !ok { + return out, errors.New("jwt: invalid claims type") + } + // Audience check + if audAny, ok := claims["aud"]; ok { + if audStr, ok := audAny.(string); ok { + if expectedAudience != "" && audStr != expectedAudience { + return out, errors.New("jwt: audience mismatch") + } + out.Audience = audStr + } + } + // Issuer + if issAny, ok := claims["iss"]; ok { + if issStr, ok := issAny.(string); ok { + out.Issuer = issStr + } + } + // Time checks + now := time.Now() + if expAny, ok := claims["exp"]; ok { + switch v := expAny.(type) { + case float64: + out.Expires = time.Unix(int64(v), 0) + case int64: + out.Expires = time.Unix(v, 0) + case uint64: + out.Expires = time.Unix(int64(v), 0) + } + if now.After(out.Expires) { + return out, errors.New("jwt: token expired") + } + } + if iatAny, ok := claims["iat"]; ok { + switch v := iatAny.(type) { + case float64: + out.IssuedAt = time.Unix(int64(v), 0) + case int64: + out.IssuedAt = time.Unix(v, 0) + } + } + return out, nil +} + +// DurationClaims constructs Claims with given ttl. +func DurationClaims(issuer, audience string, ttl time.Duration) Claims { + now := time.Now() + return Claims{ + Issuer: issuer, + Audience: audience, + IssuedAt: now, + Expires: now.Add(ttl), + } +} diff --git a/pkg/rpc/manager/server/server.go b/pkg/rpc/manager/server/server.go index 2349384d8fd..01fca47b299 100644 --- a/pkg/rpc/manager/server/server.go +++ b/pkg/rpc/manager/server/server.go @@ -39,6 +39,7 @@ import ( logger "d7y.io/dragonfly/v2/internal/dflog" "d7y.io/dragonfly/v2/pkg/rpc" + "d7y.io/dragonfly/v2/pkg/rpc/auth" ) const ( @@ -55,6 +56,7 @@ const ( // New returns grpc server instance and register service on grpc server. func New(managerServerV1 managerv1.ManagerServer, managerServerV2 managerv2.ManagerServer, requestRateLimit float64, opts ...grpc.ServerOption) *grpc.Server { limiter := rpc.NewRateLimiterInterceptor(requestRateLimit, int64(requestRateLimit)) + jwtKey := auth.GetServerKey("manager") grpcServer := grpc.NewServer(append([]grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), @@ -69,6 +71,7 @@ func New(managerServerV1 managerv1.ManagerServer, managerServerV2 managerv2.Mana MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace, }), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + auth.UnaryServerJWTInterceptor(jwtKey, "manager"), grpc_ratelimit.UnaryServerInterceptor(limiter), grpc_prometheus.UnaryServerInterceptor, grpc_zap.UnaryServerInterceptor(logger.GrpcLogger.Desugar()), @@ -76,6 +79,7 @@ func New(managerServerV1 managerv1.ManagerServer, managerServerV2 managerv2.Mana grpc_recovery.UnaryServerInterceptor(), )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + auth.StreamServerJWTInterceptor(jwtKey, "manager"), grpc_ratelimit.StreamServerInterceptor(limiter), grpc_prometheus.StreamServerInterceptor, grpc_zap.StreamServerInterceptor(logger.GrpcLogger.Desugar()), diff --git a/pkg/rpc/scheduler/server/server.go b/pkg/rpc/scheduler/server/server.go index e0cea79e33b..4317889801a 100644 --- a/pkg/rpc/scheduler/server/server.go +++ b/pkg/rpc/scheduler/server/server.go @@ -39,6 +39,7 @@ import ( logger "d7y.io/dragonfly/v2/internal/dflog" "d7y.io/dragonfly/v2/pkg/rpc" + "d7y.io/dragonfly/v2/pkg/rpc/auth" ) const ( @@ -55,6 +56,7 @@ const ( // New returns a grpc server instance and register service on grpc server. func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedulerv2.SchedulerServer, requestRateLimit float64, opts ...grpc.ServerOption) *grpc.Server { limiter := rpc.NewRateLimiterInterceptor(requestRateLimit, int64(requestRateLimit)) + jwtKey := auth.GetServerKey("scheduler") grpcServer := grpc.NewServer(append([]grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), @@ -69,6 +71,7 @@ func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedu MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace, }), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + auth.UnaryServerJWTInterceptor(jwtKey, "scheduler"), grpc_ratelimit.UnaryServerInterceptor(limiter), rpc.ConvertErrorUnaryServerInterceptor, grpc_prometheus.UnaryServerInterceptor, @@ -77,6 +80,7 @@ func New(schedulerServerV1 schedulerv1.SchedulerServer, schedulerServerV2 schedu grpc_recovery.UnaryServerInterceptor(), )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + auth.StreamServerJWTInterceptor(jwtKey, "scheduler"), grpc_ratelimit.StreamServerInterceptor(limiter), rpc.ConvertErrorStreamServerInterceptor, grpc_prometheus.StreamServerInterceptor, diff --git a/scheduler/config/config.go b/scheduler/config/config.go index 112b9babb62..6d0e87cd909 100644 --- a/scheduler/config/config.go +++ b/scheduler/config/config.go @@ -64,6 +64,9 @@ type Config struct { // Network configuration. Network NetworkConfig `yaml:"network" mapstructure:"network"` + + // Auth configuration. + Auth AuthConfig `yaml:"auth" mapstructure:"auth"` } type ServerConfig struct { @@ -317,6 +320,32 @@ type NetworkConfig struct { EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"` } +type AuthConfig struct { + // JWT configuration. + JWT JWTConfig `yaml:"jwt" mapstructure:"jwt"` +} + +type JWTConfig struct { + // Realm name to display to the user, default value is Dragonfly. + Realm string `yaml:"realm" mapstructure:"realm"` + + // Key is the secret key used for signing JWT tokens. + // SECURITY: This key must be kept secret and should be loaded from a secure secret store in production + // (e.g., HashiCorp Vault, AWS Secrets Manager, Kubernetes Secrets, or environment variables with restricted access). + // Use a strong random key (minimum 32 bytes recommended). + // Example generation: openssl rand -base64 32 + // If empty, JWT authentication is disabled (not recommended for production). + Key string `yaml:"key" mapstructure:"key"` + + // Timeout is the duration that a JWT token remains valid. + // For inter-component authentication, use a longer duration (e.g., 24h). + // For user-facing authentication, use a shorter duration (e.g., 2h). + Timeout time.Duration `yaml:"timeout" mapstructure:"timeout"` + + // MaxRefresh allows clients to refresh their token until MaxRefresh has passed. + MaxRefresh time.Duration `yaml:"maxRefresh" mapstructure:"maxRefresh"` +} + // New default configuration. func New() *Config { return &Config{ @@ -389,6 +418,15 @@ func New() *Config { Network: NetworkConfig{ EnableIPv6: DefaultNetworkEnableIPv6, }, + Auth: AuthConfig{ + JWT: JWTConfig{ + Realm: "Dragonfly", + // Default timeout of 24 hours for inter-component authentication. + // Tokens are long-lived since services are trusted and restart frequently. + Timeout: 24 * time.Hour, + MaxRefresh: 12 * time.Hour, + }, + }, } } @@ -552,6 +590,19 @@ func (cfg *Config) Validate() error { } } + // Auth validation: only validate JWT fields if a key is configured (JWT is optional for backward compatibility) + if cfg.Auth.JWT.Key != "" { + if cfg.Auth.JWT.Realm == "" { + return errors.New("jwt requires parameter realm when key is set") + } + if cfg.Auth.JWT.Timeout == 0 { + return errors.New("jwt requires parameter timeout when key is set") + } + if cfg.Auth.JWT.MaxRefresh == 0 { + return errors.New("jwt requires parameter maxRefresh when key is set") + } + } + return nil } diff --git a/scheduler/rpcserver/rpcserver.go b/scheduler/rpcserver/rpcserver.go index d3c61f6434b..8bc163fd334 100644 --- a/scheduler/rpcserver/rpcserver.go +++ b/scheduler/rpcserver/rpcserver.go @@ -19,6 +19,7 @@ package rpcserver import ( "google.golang.org/grpc" + "d7y.io/dragonfly/v2/pkg/rpc/auth" "d7y.io/dragonfly/v2/pkg/rpc/scheduler/server" "d7y.io/dragonfly/v2/scheduler/config" "d7y.io/dragonfly/v2/scheduler/job" @@ -39,6 +40,8 @@ func New( dynconfig config.DynconfigInterface, opts ...grpc.ServerOption, ) *grpc.Server { + // Provide JWT key from config to scheduler server via auth package. + auth.SetServerKey("scheduler", cfg.Auth.JWT.Key) return server.New( newSchedulerServerV1(cfg, resource, scheduling, dynconfig), newSchedulerServerV2(cfg, resource, persistentResource, persistentCacheResource, scheduling, job, dynconfig), diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 609c9b18a6d..8549a8477d4 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -36,7 +36,9 @@ import ( "d7y.io/dragonfly/v2/pkg/gc" pkgredis "d7y.io/dragonfly/v2/pkg/redis" "d7y.io/dragonfly/v2/pkg/rpc" + "d7y.io/dragonfly/v2/pkg/rpc/auth" managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" + "d7y.io/dragonfly/v2/pkg/types" "d7y.io/dragonfly/v2/scheduler/announcer" "d7y.io/dragonfly/v2/scheduler/config" "d7y.io/dragonfly/v2/scheduler/job" @@ -96,6 +98,17 @@ func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, err // Initialize dial options of manager grpc client. managerDialOptions := []grpc.DialOption{grpc.WithStatsHandler(otelgrpc.NewClientHandler())} + // Attach JWT per-RPC creds for inter-component calls if a key is provided. + if key := cfg.Auth.JWT.Key; key != "" { + // Use configured JWT timeout instead of hardcoded value to match server validation expectations. + claims := auth.DurationClaims(types.SchedulerName, types.ManagerName, cfg.Auth.JWT.Timeout) + token, err := auth.SignHS256(key, claims) + if err != nil { + logger.Errorf("failed to sign JWT for manager client: %v", err) + return nil, err + } + managerDialOptions = append(managerDialOptions, grpc.WithPerRPCCredentials(auth.NewPerRPCCreds(token))) + } if cfg.Manager.TLS != nil { clientTransportCredentials, err := rpc.NewClientCredentials(cfg.Manager.TLS.CACert, cfg.Manager.TLS.Cert, cfg.Manager.TLS.Key) if err != nil {