3232package com .google .auth .oauth2 ;
3333
3434import com .google .api .client .http .GenericUrl ;
35+ import com .google .api .client .http .HttpContent ;
36+ import com .google .api .client .http .HttpHeaders ;
37+ import com .google .api .client .http .HttpMethods ;
3538import com .google .api .client .http .HttpRequest ;
3639import com .google .api .client .http .HttpRequestFactory ;
3740import com .google .api .client .http .HttpResponse ;
4851import java .util .Map ;
4952import java .util .regex .Matcher ;
5053import java .util .regex .Pattern ;
54+ import javax .annotation .Nullable ;
5155
5256/**
5357 * AWS credentials representing a third-party identity for calling Google APIs.
5660 */
5761public class AwsCredentials extends ExternalAccountCredentials {
5862
63+ static final String AWS_IMDSV2_SESSION_TOKEN_HEADER = "x-aws-ec2-metadata-token" ;
64+ static final String AWS_IMDSV2_SESSION_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds" ;
65+ static final String AWS_IMDSV2_SESSION_TOKEN_TTL = "300" ;
66+
5967 /**
6068 * The AWS credential source. Stores data required to retrieve the AWS credential from the AWS
6169 * metadata server.
6270 */
6371 static class AwsCredentialSource extends CredentialSource {
6472
73+ private static final String IMDSV2_SESSION_TOKEN_URL_FIELD_NAME = "imdsv2_session_token_url" ;
74+
6575 private final String regionUrl ;
6676 private final String url ;
6777 private final String regionalCredentialVerificationUrl ;
78+ private final String imdsv2SessionTokenUrl ;
6879
6980 /**
7081 * The source of the AWS credential. The credential source map must contain the
@@ -107,6 +118,13 @@ static class AwsCredentialSource extends CredentialSource {
107118 this .url = (String ) credentialSourceMap .get ("url" );
108119 this .regionalCredentialVerificationUrl =
109120 (String ) credentialSourceMap .get ("regional_cred_verification_url" );
121+
122+ if (credentialSourceMap .containsKey (IMDSV2_SESSION_TOKEN_URL_FIELD_NAME )) {
123+ this .imdsv2SessionTokenUrl =
124+ (String ) credentialSourceMap .get (IMDSV2_SESSION_TOKEN_URL_FIELD_NAME );
125+ } else {
126+ this .imdsv2SessionTokenUrl = null ;
127+ }
110128 }
111129 }
112130
@@ -135,11 +153,13 @@ public AccessToken refreshAccessToken() throws IOException {
135153
136154 @ Override
137155 public String retrieveSubjectToken () throws IOException {
156+ Map <String , Object > metadataRequestHeaders = createMetadataRequestHeaders (awsCredentialSource );
157+
138158 // The targeted region is required to generate the signed request. The regional
139159 // endpoint must also be used.
140- String region = getAwsRegion ();
160+ String region = getAwsRegion (metadataRequestHeaders );
141161
142- AwsSecurityCredentials credentials = getAwsSecurityCredentials ();
162+ AwsSecurityCredentials credentials = getAwsSecurityCredentials (metadataRequestHeaders );
143163
144164 // Generate the signed request to the AWS STS GetCallerIdentity API.
145165 Map <String , String > headers = new HashMap <>();
@@ -164,10 +184,28 @@ public GoogleCredentials createScoped(Collection<String> newScopes) {
164184 return new AwsCredentials ((AwsCredentials .Builder ) newBuilder (this ).setScopes (newScopes ));
165185 }
166186
167- private String retrieveResource (String url , String resourceName ) throws IOException {
187+ private String retrieveResource (String url , String resourceName , Map <String , Object > headers )
188+ throws IOException {
189+ return retrieveResource (url , resourceName , HttpMethods .GET , headers , /* content= */ null );
190+ }
191+
192+ private String retrieveResource (
193+ String url ,
194+ String resourceName ,
195+ String requestMethod ,
196+ Map <String , Object > headers ,
197+ @ Nullable HttpContent content )
198+ throws IOException {
168199 try {
169200 HttpRequestFactory requestFactory = transportFactory .create ().createRequestFactory ();
170- HttpRequest request = requestFactory .buildGetRequest (new GenericUrl (url ));
201+ HttpRequest request =
202+ requestFactory .buildRequest (requestMethod , new GenericUrl (url ), content );
203+
204+ HttpHeaders requestHeaders = request .getHeaders ();
205+ for (Map .Entry <String , Object > header : headers .entrySet ()) {
206+ requestHeaders .set (header .getKey (), header .getValue ());
207+ }
208+
171209 HttpResponse response = request .execute ();
172210 return response .parseAsString ();
173211 } catch (IOException e ) {
@@ -200,8 +238,42 @@ private String buildSubjectToken(AwsRequestSignature signature)
200238 return URLEncoder .encode (token .toString (), "UTF-8" );
201239 }
202240
241+ Map <String , Object > createMetadataRequestHeaders (AwsCredentialSource awsCredentialSource )
242+ throws IOException {
243+ Map <String , Object > metadataRequestHeaders = new HashMap <>();
244+
245+ // AWS IDMSv2 introduced a requirement for a session token to be present
246+ // with the requests made to metadata endpoints. This requirement is to help
247+ // prevent SSRF attacks.
248+ // Presence of "imdsv2_session_token_url" in Credential Source of config file
249+ // will trigger a flow with session token, else there will not be a session
250+ // token with the metadata requests.
251+ // Both flows work for IDMS v1 and v2. But if IDMSv2 is enabled, then if
252+ // session token is not present, Unauthorized exception will be thrown.
253+ if (awsCredentialSource .imdsv2SessionTokenUrl != null ) {
254+ Map <String , Object > tokenRequestHeaders =
255+ new HashMap <String , Object >() {
256+ {
257+ put (AWS_IMDSV2_SESSION_TOKEN_TTL_HEADER , AWS_IMDSV2_SESSION_TOKEN_TTL );
258+ }
259+ };
260+
261+ String imdsv2SessionToken =
262+ retrieveResource (
263+ awsCredentialSource .imdsv2SessionTokenUrl ,
264+ "Session Token" ,
265+ HttpMethods .PUT ,
266+ tokenRequestHeaders ,
267+ /* content= */ null );
268+
269+ metadataRequestHeaders .put (AWS_IMDSV2_SESSION_TOKEN_HEADER , imdsv2SessionToken );
270+ }
271+
272+ return metadataRequestHeaders ;
273+ }
274+
203275 @ VisibleForTesting
204- String getAwsRegion () throws IOException {
276+ String getAwsRegion (Map < String , Object > metadataRequestHeaders ) throws IOException {
205277 // For AWS Lambda, the region is retrieved through the AWS_REGION environment variable.
206278 String region = getEnvironmentProvider ().getEnv ("AWS_REGION" );
207279 if (region != null ) {
@@ -218,15 +290,16 @@ String getAwsRegion() throws IOException {
218290 "Unable to determine the AWS region. The credential source does not contain the region URL." );
219291 }
220292
221- region = retrieveResource (awsCredentialSource .regionUrl , "region" );
293+ region = retrieveResource (awsCredentialSource .regionUrl , "region" , metadataRequestHeaders );
222294
223295 // There is an extra appended character that must be removed. If `us-east-1b` is returned,
224296 // we want `us-east-1`.
225297 return region .substring (0 , region .length () - 1 );
226298 }
227299
228300 @ VisibleForTesting
229- AwsSecurityCredentials getAwsSecurityCredentials () throws IOException {
301+ AwsSecurityCredentials getAwsSecurityCredentials (Map <String , Object > metadataRequestHeaders )
302+ throws IOException {
230303 // Check environment variables for credentials first.
231304 String accessKeyId = getEnvironmentProvider ().getEnv ("AWS_ACCESS_KEY_ID" );
232305 String secretAccessKey = getEnvironmentProvider ().getEnv ("AWS_SECRET_ACCESS_KEY" );
@@ -243,12 +316,13 @@ AwsSecurityCredentials getAwsSecurityCredentials() throws IOException {
243316 "Unable to determine the AWS IAM role name. The credential source does not contain the"
244317 + " url field." );
245318 }
246- String roleName = retrieveResource (awsCredentialSource .url , "IAM role" );
319+ String roleName = retrieveResource (awsCredentialSource .url , "IAM role" , metadataRequestHeaders );
247320
248321 // Retrieve the AWS security credentials by calling the endpoint specified by the credential
249322 // source.
250323 String awsCredentials =
251- retrieveResource (awsCredentialSource .url + "/" + roleName , "credentials" );
324+ retrieveResource (
325+ awsCredentialSource .url + "/" + roleName , "credentials" , metadataRequestHeaders );
252326
253327 JsonParser parser = OAuth2Utils .JSON_FACTORY .createJsonParser (awsCredentials );
254328 GenericJson genericJson = parser .parseAndClose (GenericJson .class );
0 commit comments