Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 142 additions & 5 deletions pro/auth/azure-ad.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
Expand All @@ -19,6 +21,8 @@ import (
"golang.org/x/oauth2/microsoft"
)

const AzureAD_TIMEOUT = 10 * time.Second

var azure_ad_functions = map[string]interface{}{
init_provider: initAzureAD,
get_user_info: getAzureUserInfo,
Expand All @@ -27,15 +31,31 @@ var azure_ad_functions = map[string]interface{}{
verify_user: verifyAzureUser,
}

var azure_ad_verifier *oidc.IDTokenVerifier

// == handle azure ad authentication here ==

func initAzureAD(redirectURL string, clientID string, clientSecret string) {
tenantID := logic.GetAzureTenant()
if tenantID != "" {
ctx, cancel := context.WithTimeout(context.Background(), AzureAD_TIMEOUT)
defer cancel()

issuer := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", tenantID)
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
logger.Log(1, "error when initializing Azure AD OIDC provider:", err.Error())
} else {
azure_ad_verifier = provider.Verifier(&oidc.Config{ClientID: clientID})
}
}

auth_provider = &oauth2.Config{
RedirectURL: redirectURL,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: []string{"User.Read", "email", "profile", "openid"},
Endpoint: microsoft.AzureADEndpoint(logic.GetAzureTenant()),
Endpoint: microsoft.AzureADEndpoint(tenantID),
}
}

Expand Down Expand Up @@ -68,7 +88,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}

content, err := getAzureUserInfo(rState, rCode)
azureInfo, err := getAzureUserInfoWithToken(rState, rCode)
if err != nil {
logger.Log(1, "error when getting user info from azure:", err.Error())
if strings.Contains(err.Error(), "invalid oauth state") || strings.Contains(err.Error(), "failed to fetch user email from SSO state") {
Expand All @@ -78,6 +98,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
handleOauthNotConfigured(w)
return
}
content := azureInfo.OAuthUser

var inviteExists bool
// check if invite exists for User
in, err := logic.GetUserInvite(content.Email)
Expand Down Expand Up @@ -165,6 +187,16 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}

// Check device claims if user has ExternalIdentityProviderID set (synced from IDP)
// Validate device authorization - allow if device is registered/compliant, block if not
if user.ExternalIdentityProviderID != "" {
if err := checkDeviceClaims(azureInfo.RawIDToken); err != nil {
logger.Log(1, "Device authorization check failed for user with ExternalIdentityProviderID:", err.Error())
handleDeviceClaimsMissing(w)
return
}
}

if user.AccountDisabled {
handleUserAccountDisabled(w)
return
Expand Down Expand Up @@ -214,15 +246,37 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
}

// AzureUserInfo extends OAuthUser with ID token for device claims verification
type AzureUserInfo struct {
*OAuthUser
RawIDToken string
}

func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
azureInfo, err := getAzureUserInfoWithToken(state, code)
if err != nil {
return nil, err
}
return azureInfo.OAuthUser, nil
}

func getAzureUserInfoWithToken(state string, code string) (*AzureUserInfo, error) {
oauth_state_string, isValid := logic.IsStateValid(state)
if (!isValid || state != oauth_state_string) && !isStateCached(state) {
return nil, fmt.Errorf("invalid oauth state")
}
var token, err = auth_provider.Exchange(context.Background(), code, oauth2.SetAuthURLParam("prompt", "login"))

ctx, cancel := context.WithTimeout(context.Background(), AzureAD_TIMEOUT)
defer cancel()

var token, err = auth_provider.Exchange(ctx, code, oauth2.SetAuthURLParam("prompt", "login"))
if err != nil {
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
}

// Extract raw ID token for later device claims verification
rawIDToken, _ := token.Extra("id_token").(string)

var data []byte
data, err = json.Marshal(token)
if err != nil {
Expand Down Expand Up @@ -256,9 +310,92 @@ func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
}
if userInfo.Email == "" {
err = errors.New("failed to fetch user email from SSO state")
return userInfo, err
return &AzureUserInfo{OAuthUser: userInfo, RawIDToken: rawIDToken}, err
}
return userInfo, nil
return &AzureUserInfo{OAuthUser: userInfo, RawIDToken: rawIDToken}, nil
}

// checkDeviceClaims validates device authorization from device claims in the ID token
// Returns an error if device claims are present but device is NOT authorized (to block authentication)
// Returns nil if device claims are NOT present OR device IS authorized (to allow authentication)
func checkDeviceClaims(rawIDToken string) error {
if azure_ad_verifier == nil {
// If verifier not available, allow authentication
return nil
}

if rawIDToken == "" {
// If ID token not available, allow authentication
return nil
}

ctx, cancel := context.WithTimeout(context.Background(), AzureAD_TIMEOUT)
defer cancel()

idToken, err := azure_ad_verifier.Verify(ctx, rawIDToken)
if err != nil {
// If token verification fails, allow authentication
logger.Log(1, "Failed to verify Azure AD ID token for device claims check:", err.Error())
return nil
}

var claims map[string]interface{}
if err := idToken.Claims(&claims); err != nil {
// If claims extraction fails, allow authentication
logger.Log(1, "Failed to extract claims from ID token:", err.Error())
return nil
}

// Check for device claims in the ID token
hasDeviceClaims := false
deviceID := ""
deviceRegStatus := ""
fmt.Printf("==> DEVICE CLAIMS: %+v\n", claims)
// Extract device information from claims
if val, exists := claims["deviceid"]; exists && val != nil {
deviceID = fmt.Sprintf("%v", val)
hasDeviceClaims = true
logger.Log(3, fmt.Sprintf("Found deviceid claim: %s", deviceID))
}

if val, exists := claims["deviceregstatus"]; exists && val != nil {
deviceRegStatus = fmt.Sprintf("%v", val)
logger.Log(3, fmt.Sprintf("Found deviceregstatus claim: %s", deviceRegStatus))
}

// If device claims ARE present, validate device authorization
if hasDeviceClaims {
// Check device registration status - allow if device is registered/compliant
// Azure AD sets deviceregstatus to values like "Registered", "Compliant", etc.
// If status indicates the device is registered/compliant, allow authentication
if deviceRegStatus != "" {
// Normalize status to lowercase for comparison
statusLower := strings.ToLower(deviceRegStatus)
// Allow if device is registered, compliant, or managed (authorized devices)
if strings.Contains(statusLower, "registered") ||
strings.Contains(statusLower, "compliant") ||
strings.Contains(statusLower, "managed") {
logger.Log(3, fmt.Sprintf("Device authorization check: Device %s has status '%s' - device is authorized, allowing authentication", deviceID, deviceRegStatus))
return nil
} else {
// Device claims present but device is NOT registered/compliant - block authentication
logger.Log(1, fmt.Sprintf("Device authorization check: Device %s has status '%s' - device is not authorized, blocking authentication", deviceID, deviceRegStatus))
return fmt.Errorf("device claims found but device is not registered/compliant - authentication not allowed")
}
} else if compliant, ok := claims["xms_compliant"].(string); ok {
logger.Log(3, "Azure Device: id=%s compliant=%s", deviceID, compliant)
if compliant != "true" {
return errors.New("access denied: device not compliant")
}
} else if deviceID != "" {
// If device ID exists but no status, allow authentication (device is registered)
logger.Log(3, fmt.Sprintf("Device authorization check: Device %s found but no registration status - allowing authentication", deviceID))
return nil
}
}

// If device claims are NOT present, allow authentication
return nil
}

func verifyAzureUser(token *oauth2.Token) bool {
Expand Down
9 changes: 9 additions & 0 deletions pro/auth/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ var authTypeMismatch = fmt.Sprintf(htmlBaseTemplate, `<h2>It looks like you alre

var userAccountDisabled = fmt.Sprintf(htmlBaseTemplate, `<h2>Your account has been disabled. Please contact your administrator for more information about your account.</h2>`)

var deviceClaimsMissing = fmt.Sprintf(htmlBaseTemplate, `<h2>Device authentication required.</h2>
<p>Your organization requires device-based authentication. Please sign in from a registered Azure AD device.</p>`)

func handleOauthUserNotFound(response http.ResponseWriter) {
response.Header().Set("Content-Type", "text/html; charset=utf-8")
response.WriteHeader(http.StatusNotFound)
Expand Down Expand Up @@ -174,3 +177,9 @@ func handleUserAccountDisabled(response http.ResponseWriter) {
response.WriteHeader(http.StatusUnauthorized)
response.Write([]byte(userAccountDisabled))
}

func handleDeviceClaimsMissing(response http.ResponseWriter) {
response.Header().Set("Content-Type", "text/html; charset=utf-8")
response.WriteHeader(http.StatusForbidden)
response.Write([]byte(deviceClaimsMissing))
}