diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index c2b47e6e944f..61f055e99b9a 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -190,7 +190,7 @@ private static OAuthTokenResponse refreshToken( config.credential(), config.scope(), config.oauth2ServerUri(), - ImmutableMap.of()); + optionalOAuthParams); } } @@ -563,7 +563,7 @@ private OAuthTokenResponse refreshExpiredToken(RESTClient client) { client, config, basicHeaders, token(), tokenType(), optionalOAuthParams()); } else { return fetchToken( - client, Map.of(), credential(), scope(), oauth2ServerUri(), ImmutableMap.of()); + client, Map.of(), credential(), scope(), oauth2ServerUri(), optionalOAuthParams()); } } diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Util.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Util.java index fbcb87fb06e2..b2991870e35e 100644 --- a/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Util.java +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Util.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.junit.jupiter.api.Test; @@ -135,4 +136,79 @@ public void testCredentialFlowForSessionRefresh() throws IOException { any()); } } + + @Test + void refreshExpiredTokenShouldIncludeOptionalOAuthParams() throws IOException { + // expired token triggers refreshExpiredToken() which should forward optionalOAuthParams + String audience = "https://my-catalog.example.com"; + AuthConfig authConfig = + ImmutableAuthConfig.builder() + .keepRefreshed(true) + .exchangeEnabled(false) + .token("expired_token") + .credential("testClientId:testClientSecret") + .oauth2ServerUri("/v1/token") + .expiresAtMillis(System.currentTimeMillis() - 10_000) + .optionalOAuthParams(ImmutableMap.of("audience", audience, "scope", "catalog")) + .build(); + + OAuthTokenResponse response = + OAuthTokenResponse.builder().withToken("refreshed_token").withTokenType(BEARER).build(); + + try (RESTClient client = Mockito.mock(RESTClient.class); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(Map.of(), authConfig)) { + Mockito.when(client.postForm(any(), anyMap(), any(), anyMap(), any())).thenReturn(response); + + session.refresh(client); + + Mockito.verify(client) + .postForm( + any(), + argThat( + formData -> + CLIENT_CREDENTIALS.equals(formData.get(GRANT_TYPE)) + && audience.equals(formData.get("audience")) + && "catalog".equals(formData.get("scope"))), + Mockito.eq(OAuthTokenResponse.class), + anyMap(), + any()); + } + } + + @Test + void refreshCurrentTokenNonExchangeShouldIncludeOptionalOAuthParams() throws IOException { + String audience = "https://my-catalog.example.com"; + AuthConfig authConfig = + ImmutableAuthConfig.builder() + .keepRefreshed(true) + .exchangeEnabled(false) + .token("valid_token") + .credential("testClientId:testClientSecret") + .oauth2ServerUri("/v1/token") + .expiresAtMillis(System.currentTimeMillis() + 300_000) + .optionalOAuthParams(ImmutableMap.of("audience", audience, "scope", "catalog")) + .build(); + + OAuthTokenResponse response = + OAuthTokenResponse.builder().withToken("refreshed_token").withTokenType(BEARER).build(); + + try (RESTClient client = Mockito.mock(RESTClient.class); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(Map.of(), authConfig)) { + Mockito.when(client.postForm(any(), anyMap(), any(), anyMap(), any())).thenReturn(response); + + session.refresh(client); + + Mockito.verify(client) + .postForm( + any(), + argThat( + formData -> + CLIENT_CREDENTIALS.equals(formData.get(GRANT_TYPE)) + && audience.equals(formData.get("audience")) + && "catalog".equals(formData.get("scope"))), + Mockito.eq(OAuthTokenResponse.class), + anyMap(), + any()); + } + } }