Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ linters:
- "!**/pkg/auth/factory/**"
- "!**/pkg/auth/types/aws_credentials.go"
- "!**/pkg/auth/types/github_oidc_credentials.go"
- "!**/internal/aws_utils/**"
- "$test"
deny:
# AWS: Identity and auth-related SDKs
Expand Down
1 change: 1 addition & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ var (
ErrInvalidTerraformSingleComponentAndMultiComponentFlags = errors.New("the single-component flags (`--from-plan`, `--planfile`) can't be used with the multi-component (bulk operations) flags (`--affected`, `--all`, `--query`, `--components`)")

ErrYamlFuncInvalidArguments = errors.New("invalid number of arguments in the Atmos YAML function")
ErrAwsGetCallerIdentity = errors.New("failed to get AWS caller identity")
ErrDescribeComponent = errors.New("failed to describe component")
ErrReadTerraformState = errors.New("failed to read Terraform state")
ErrEvaluateTerraformBackendVariable = errors.New("failed to evaluate terraform backend variable")
Expand Down
53 changes: 52 additions & 1 deletion internal/aws_utils/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func LoadAWSConfigWithAuth(
baseCfg, err := config.LoadDefaultConfig(ctx, cfgOpts...)
if err != nil {
log.Debug("Failed to load AWS config", "error", err)
return aws.Config{}, fmt.Errorf("%w: %v", errUtils.ErrLoadAwsConfig, err)
return aws.Config{}, fmt.Errorf("%w: %w", errUtils.ErrLoadAwsConfig, err)
}
log.Debug("Successfully loaded AWS SDK config", "region", baseCfg.Region)

Expand Down Expand Up @@ -126,3 +126,54 @@ func LoadAWSConfig(ctx context.Context, region string, roleArn string, assumeRol

return LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, nil)
}

// AWSCallerIdentityResult holds the result of GetAWSCallerIdentity.
type AWSCallerIdentityResult struct {
Account string
Arn string
UserID string
Region string
}

// GetAWSCallerIdentity retrieves AWS caller identity using STS GetCallerIdentity API.
// Returns account ID, ARN, user ID, and region.
// This function keeps AWS SDK STS imports contained within aws_utils package.
func GetAWSCallerIdentity(
ctx context.Context,
region string,
roleArn string,
assumeRoleDuration time.Duration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentityResult, error) {
defer perf.Track(nil, "aws_utils.GetAWSCallerIdentity")()

// Load AWS config.
cfg, err := LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, authContext)
if err != nil {
return nil, err
}

// Create STS client and get caller identity.
stsClient := sts.NewFromConfig(cfg)
output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, fmt.Errorf("%w: %w", errUtils.ErrAwsGetCallerIdentity, err)
}

result := &AWSCallerIdentityResult{
Region: cfg.Region,
}

// Extract values from pointers.
if output.Account != nil {
result.Account = *output.Account
}
if output.Arn != nil {
result.Arn = *output.Arn
}
if output.UserId != nil {
result.UserID = *output.UserId
}

return result, nil
}
164 changes: 164 additions & 0 deletions internal/exec/aws_getter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package exec

import (
"context"
"fmt"
"sync"

awsUtils "github.com/cloudposse/atmos/internal/aws_utils"
log "github.com/cloudposse/atmos/pkg/logger"
"github.com/cloudposse/atmos/pkg/perf"
"github.com/cloudposse/atmos/pkg/schema"
)

// AWSCallerIdentity holds the information returned by AWS STS GetCallerIdentity.
type AWSCallerIdentity struct {
Account string
Arn string
UserID string
Region string // The AWS region from the loaded config.
}

// AWSGetter provides an interface for retrieving AWS caller identity information.
// This interface enables dependency injection and testability.
//
//go:generate go run go.uber.org/mock/[email protected] -source=$GOFILE -destination=mock_aws_getter_test.go -package=exec
type AWSGetter interface {
// GetCallerIdentity retrieves the AWS caller identity for the current credentials.
// Returns the account ID, ARN, and user ID of the calling identity.
GetCallerIdentity(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error)
}

// defaultAWSGetter is the production implementation that uses real AWS SDK calls.
type defaultAWSGetter struct{}

// GetCallerIdentity retrieves the AWS caller identity using the STS GetCallerIdentity API.
func (d *defaultAWSGetter) GetCallerIdentity(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error) {
defer perf.Track(atmosConfig, "exec.AWSGetter.GetCallerIdentity")()

log.Debug("Getting AWS caller identity")

// Use the aws_utils helper to get caller identity (keeps AWS SDK imports in aws_utils).
result, err := awsUtils.GetAWSCallerIdentity(ctx, "", "", 0, authContext)
if err != nil {
return nil, err // Error already wrapped by aws_utils.
}

identity := &AWSCallerIdentity{
Account: result.Account,
Arn: result.Arn,
UserID: result.UserID,
Region: result.Region,
}

log.Debug("Retrieved AWS caller identity",
"account", identity.Account,
"arn", identity.Arn,
"user_id", identity.UserID,
"region", identity.Region,
)

return identity, nil
}

// awsGetter is the global instance used by YAML functions.
// This allows test code to replace it with a mock.
var awsGetter AWSGetter = &defaultAWSGetter{}

// SetAWSGetter allows tests to inject a mock AWSGetter.
// Returns a function to restore the original getter.
func SetAWSGetter(getter AWSGetter) func() {
defer perf.Track(nil, "exec.SetAWSGetter")()

original := awsGetter
awsGetter = getter
return func() {
awsGetter = original
}
}

// cachedAWSIdentity holds the cached AWS caller identity.
// The cache is per-CLI-invocation (stored in memory) to avoid repeated STS calls.
type cachedAWSIdentity struct {
identity *AWSCallerIdentity
err error
}

var (
awsIdentityCache map[string]*cachedAWSIdentity
awsIdentityCacheMu sync.RWMutex
)

func init() {
awsIdentityCache = make(map[string]*cachedAWSIdentity)
}

// getCacheKey generates a cache key based on the auth context.
// Different auth contexts (different credentials) get different cache entries.
// Includes Profile, CredentialsFile, and ConfigFile since all three affect AWS config loading.
func getCacheKey(authContext *schema.AWSAuthContext) string {
if authContext == nil {
return "default"
}
return fmt.Sprintf("%s:%s:%s", authContext.Profile, authContext.CredentialsFile, authContext.ConfigFile)
}

// getAWSCallerIdentityCached retrieves the AWS caller identity with caching.
// Results are cached per auth context to avoid repeated STS calls within the same CLI invocation.
func getAWSCallerIdentityCached(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error) {
defer perf.Track(atmosConfig, "exec.getAWSCallerIdentityCached")()

cacheKey := getCacheKey(authContext)

// Check cache first (read lock).
awsIdentityCacheMu.RLock()
if cached, ok := awsIdentityCache[cacheKey]; ok {
awsIdentityCacheMu.RUnlock()
log.Debug("Using cached AWS caller identity", "cache_key", cacheKey)
return cached.identity, cached.err
}
awsIdentityCacheMu.RUnlock()

// Cache miss - acquire write lock and fetch.
awsIdentityCacheMu.Lock()
defer awsIdentityCacheMu.Unlock()

// Double-check after acquiring write lock.
if cached, ok := awsIdentityCache[cacheKey]; ok {
log.Debug("Using cached AWS caller identity (double-check)", "cache_key", cacheKey)
return cached.identity, cached.err
}

// Fetch from AWS.
identity, err := awsGetter.GetCallerIdentity(ctx, atmosConfig, authContext)

// Cache the result (including errors to avoid repeated failed calls).
awsIdentityCache[cacheKey] = &cachedAWSIdentity{
identity: identity,
err: err,
}

return identity, err
}

// ClearAWSIdentityCache clears the AWS identity cache.
// This is useful in tests or when credentials change during execution.
func ClearAWSIdentityCache() {
defer perf.Track(nil, "exec.ClearAWSIdentityCache")()

awsIdentityCacheMu.Lock()
defer awsIdentityCacheMu.Unlock()
awsIdentityCache = make(map[string]*cachedAWSIdentity)
}
151 changes: 151 additions & 0 deletions internal/exec/yaml_func_aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package exec

import (
"context"

errUtils "github.com/cloudposse/atmos/errors"
log "github.com/cloudposse/atmos/pkg/logger"
"github.com/cloudposse/atmos/pkg/perf"
"github.com/cloudposse/atmos/pkg/schema"
u "github.com/cloudposse/atmos/pkg/utils"
)

const (
execAWSYAMLFunction = "Executing Atmos YAML function"
invalidYAMLFunction = "Invalid YAML function"
failedGetIdentity = "Failed to get AWS caller identity"
functionKey = "function"
)

// processTagAwsValue is a shared helper for AWS YAML functions.
// It validates the input tag, retrieves AWS caller identity, and returns the requested value.
func processTagAwsValue(
atmosConfig *schema.AtmosConfiguration,
input string,
expectedTag string,
stackInfo *schema.ConfigAndStacksInfo,
extractor func(*AWSCallerIdentity) string,
) any {
log.Debug(execAWSYAMLFunction, functionKey, input)

// Validate the tag matches expected.
if input != expectedTag {
log.Error(invalidYAMLFunction, functionKey, input, "expected", expectedTag)
errUtils.CheckErrorPrintAndExit(errUtils.ErrYamlFuncInvalidArguments, "", "")
return nil
}

// Get auth context from stack info if available.
var authContext *schema.AWSAuthContext
if stackInfo != nil && stackInfo.AuthContext != nil && stackInfo.AuthContext.AWS != nil {
authContext = stackInfo.AuthContext.AWS
}

// Get the AWS caller identity (cached).
ctx := context.Background()
identity, err := getAWSCallerIdentityCached(ctx, atmosConfig, authContext)
if err != nil {
log.Error(failedGetIdentity, "error", err)
errUtils.CheckErrorPrintAndExit(err, "", "")
return nil
}

// Extract the requested value.
return extractor(identity)
}

// processTagAwsAccountID processes the !aws.account_id YAML function.
// It returns the AWS account ID of the current caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// account_id: !aws.account_id
func processTagAwsAccountID(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsAccountID")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsAccountID, stackInfo, func(id *AWSCallerIdentity) string {
return id.Account
})

if result != nil {
log.Debug("Resolved !aws.account_id", "account_id", result)
}
return result
}

// processTagAwsCallerIdentityArn processes the !aws.caller_identity_arn YAML function.
// It returns the ARN of the current AWS caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// caller_arn: !aws.caller_identity_arn
func processTagAwsCallerIdentityArn(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityArn")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityArn, stackInfo, func(id *AWSCallerIdentity) string {
return id.Arn
})

if result != nil {
log.Debug("Resolved !aws.caller_identity_arn", "arn", result)
}
return result
}

// processTagAwsCallerIdentityUserID processes the !aws.caller_identity_user_id YAML function.
// It returns the unique user ID of the current AWS caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// user_id: !aws.caller_identity_user_id
func processTagAwsCallerIdentityUserID(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityUserID")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityUserID, stackInfo, func(id *AWSCallerIdentity) string {
return id.UserID
})

if result != nil {
log.Debug("Resolved !aws.caller_identity_user_id", "user_id", result)
}
return result
}

// processTagAwsRegion processes the !aws.region YAML function.
// It returns the AWS region from the current configuration.
// The function takes no parameters.
//
// Usage in YAML:
//
// region: !aws.region
func processTagAwsRegion(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsRegion")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsRegion, stackInfo, func(id *AWSCallerIdentity) string {
return id.Region
})

if result != nil {
log.Debug("Resolved !aws.region", "region", result)
}
return result
}
Loading
Loading