Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ See [action.yml](./action.yml) for more detail.
| retry-max-attempts | Limits the number of retry attempts before giving up. Defaults to 12. | No |
| special-characters-workaround | Uncommonly, some environments cannot tolerate special characters in a secret key. This option will retry fetching credentials until the secret access key does not contain special characters. This option overrides disable-retry and retry-max-attempts. | No |
| use-existing-credentials | When set, the action will check if existing credentials are valid and exit if they are. Defaults to false. | No |
| allowed-account-ids | A comma-delimited list of expected AWS account IDs. The action will fail if we receive credentials for the wrong account. | No |
</details>

#### Adjust the retry mechanism
Expand Down
5 changes: 5 additions & 0 deletions action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ inputs:
description: Some environments do not support special characters in AWS_SECRET_ACCESS_KEY. This option will retry fetching credentials until the secret access key does not contain special characters. This option overrides disable-retry and retry-max-attempts. This option is disabled by default
required: false
use-existing-credentials:
required: false
description: When enabled, this option will check if there are already valid credentials in the environment. If there are, new credentials will not be fetched. If there are not, the action will run as normal.
allowed-account-ids:
required: false
description: An option comma-delimited list of expected AWS account IDs. The action will fail if we receive credentials for the wrong account.

outputs:
aws-account-id:
description: The AWS account ID for the provided credentials
Expand Down
26 changes: 22 additions & 4 deletions src/CredentialsClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { STSClient } from '@aws-sdk/client-sts';
import type { AwsCredentialIdentity } from '@aws-sdk/types';
import { NodeHttpHandler } from '@smithy/node-http-handler';
import { HttpsProxyAgent } from 'https-proxy-agent';
import { errorMessage } from './helpers';
import { errorMessage, getCallerIdentity } from './helpers';

const USER_AGENT = 'configure-aws-credentials-for-github-actions';

Expand Down Expand Up @@ -40,7 +40,11 @@ export class CredentialsClient {
return this._stsClient;
}

public async validateCredentials(expectedAccessKeyId?: string, roleChaining?: boolean) {
public async validateCredentials(
expectedAccessKeyId?: string,
roleChaining?: boolean,
expectedAccountIds?: string[],
) {
let credentials: AwsCredentialIdentity;
try {
credentials = await this.loadCredentials();
Expand All @@ -50,13 +54,27 @@ export class CredentialsClient {
} catch (error) {
throw new Error(`Credentials could not be loaded, please check your action inputs: ${errorMessage(error)}`);
}
if (expectedAccountIds && expectedAccountIds.length > 0 && expectedAccountIds[0] !== '') {
let callerIdentity: Awaited<ReturnType<typeof getCallerIdentity>>;
try {
callerIdentity = await getCallerIdentity(this.stsClient);
} catch (error) {
throw new Error(`Could not validate account ID of credentials: ${errorMessage(error)}`);
}
if (!callerIdentity.Account || !expectedAccountIds.includes(callerIdentity.Account)) {
throw new Error(
`The account ID of the provided credentials (${
callerIdentity.Account ?? 'unknown'
}) does not match any of the expected account IDs: ${expectedAccountIds.join(', ')}`,
);
}
}

if (!roleChaining) {
const actualAccessKeyId = credentials.accessKeyId;

if (expectedAccessKeyId && expectedAccessKeyId !== actualAccessKeyId) {
throw new Error(
'Unexpected failure: Credentials loaded by the SDK do not match the access key ID configured by the action',
'Credentials loaded by the SDK do not match the expected access key ID configured by the action',
);
}
}
Expand Down
16 changes: 10 additions & 6 deletions src/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as core from '@actions/core';
import type { Credentials } from '@aws-sdk/client-sts';
import type { Credentials, STSClient } from '@aws-sdk/client-sts';
import { GetCallerIdentityCommand } from '@aws-sdk/client-sts';
import type { CredentialsClient } from './CredentialsClient';

Expand Down Expand Up @@ -109,15 +109,19 @@ export function exportRegion(region: string, outputEnvCredentials?: boolean) {
}
}

export async function getCallerIdentity(client: STSClient): Promise<{ Account: string; Arn: string; UserId?: string }> {
const identity = await client.send(new GetCallerIdentityCommand({}));
if (!identity.Account || !identity.Arn) {
throw new Error('Could not get Account ID or ARN from STS. Did you set credentials?');
}
return { Account: identity.Account, Arn: identity.Arn, UserId: identity.UserId };
}

// Obtains account ID from STS Client and sets it as output
export async function exportAccountId(credentialsClient: CredentialsClient, maskAccountId?: boolean) {
const client = credentialsClient.stsClient;
const identity = await client.send(new GetCallerIdentityCommand({}));
const identity = await getCallerIdentity(credentialsClient.stsClient);
const accountId = identity.Account;
const arn = identity.Arn;
if (!accountId || !arn) {
throw new Error('Could not get Account ID or ARN from STS. Did you set credentials?');
}
if (maskAccountId) {
core.setSecret(accountId);
core.setSecret(arn);
Expand Down
14 changes: 11 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ export async function run() {
const specialCharacterWorkaround = getBooleanInput('special-characters-workaround', { required: false });
const useExistingCredentials = core.getInput('use-existing-credentials', { required: false });
let maxRetries = Number.parseInt(core.getInput('retry-max-attempts', { required: false })) || 12;
const expectedAccountIds = core
.getInput('allowed-account-ids', { required: false })
.split(',')
.map((s) => s.trim());

if (specialCharacterWorkaround) {
// 😳
Expand Down Expand Up @@ -136,15 +140,15 @@ export async function run() {
exportCredentials({ AccessKeyId, SecretAccessKey, SessionToken }, outputCredentials, outputEnvCredentials);
} else if (!webIdentityTokenFile && !roleChaining) {
// Proceed only if credentials can be picked up
await credentialsClient.validateCredentials();
await credentialsClient.validateCredentials(undefined, roleChaining, expectedAccountIds);
sourceAccountId = await exportAccountId(credentialsClient, maskAccountId);
}

if (AccessKeyId || roleChaining) {
// Validate that the SDK can actually pick up credentials.
// This validates cases where this action is using existing environment credentials,
// and cases where the user intended to provide input credentials but the secrets inputs resolved to empty strings.
await credentialsClient.validateCredentials(AccessKeyId, roleChaining);
await credentialsClient.validateCredentials(AccessKeyId, roleChaining, expectedAccountIds);
sourceAccountId = await exportAccountId(credentialsClient, maskAccountId);
}

Expand Down Expand Up @@ -179,7 +183,11 @@ export async function run() {
// is set to `true` then we are NOT in a self-hosted runner.
// Second: Customer provided credentials manually (IAM User keys stored in GH Secrets)
if (!process.env.GITHUB_ACTIONS || AccessKeyId) {
await credentialsClient.validateCredentials(roleCredentials.Credentials?.AccessKeyId);
await credentialsClient.validateCredentials(
roleCredentials.Credentials?.AccessKeyId,
roleChaining,
expectedAccountIds,
);
}
if (outputEnvCredentials) {
await exportAccountId(credentialsClient, maskAccountId);
Expand Down
150 changes: 150 additions & 0 deletions test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,156 @@ describe('Configure AWS Credentials', {}, () => {
});
});

describe('Account ID Validation', {}, () => {
beforeEach(() => {
vi.clearAllMocks();
mockedSTSClient.reset();
});

it('succeeds when account ID matches allowed list', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': '111111111111'
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).not.toHaveBeenCalled();
expect(core.info).toHaveBeenCalledWith('Proceeding with IAM user credentials');
});

it('succeeds with multiple allowed account IDs when account matches', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': '999999999999,111111111111,222222222222'
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).not.toHaveBeenCalled();
});

it('fails when account ID does not match allowed list', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': '999999999999'
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).toHaveBeenCalledWith(
'The account ID of the provided credentials (111111111111) does not match any of the expected account IDs: 999999999999'
);
});

it('fails when account ID does not match any in multiple allowed accounts', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': '999999999999,888888888888'
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).toHaveBeenCalledWith(
'The account ID of the provided credentials (111111111111) does not match any of the expected account IDs: 999999999999, 888888888888'
);
});

it('works with assume role when account ID matches', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_ASSUMEROLE_INPUTS,
'allowed-account-ids': '111111111111'
}));
mockedSTSClient.on(AssumeRoleCommand).resolves(mocks.outputs.STS_CREDENTIALS);
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials')
.mockResolvedValueOnce({ accessKeyId: 'MYAWSACCESSKEYID' })
.mockResolvedValueOnce({ accessKeyId: 'STSAWSACCESSKEYID' });

await run();
expect(core.setFailed).not.toHaveBeenCalled();
expect(core.info).toHaveBeenCalledWith('Authenticated as assumedRoleId AROAFAKEASSUMEDROLEID');
});

it('works with OIDC when account ID matches', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.GH_OIDC_INPUTS,
'allowed-account-ids': '111111111111'
}));
vi.spyOn(core, 'getIDToken').mockResolvedValue('testoidctoken');
mockedSTSClient.on(AssumeRoleWithWebIdentityCommand).resolves(mocks.outputs.STS_CREDENTIALS);
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
process.env.ACTIONS_ID_TOKEN_REQUEST_TOKEN = 'fake-token';

await run();
expect(core.setFailed).not.toHaveBeenCalled();
expect(core.info).toHaveBeenCalledWith('Authenticated as assumedRoleId AROAFAKEASSUMEDROLEID');
});

it('handles GetCallerIdentity API failure gracefully', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': '111111111111'
}));
mockedSTSClient.on(GetCallerIdentityCommand).rejects(new Error('API Error'));
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).toHaveBeenCalledWith('Could not validate account ID of credentials: API Error');
});

it('ignores validation when allowed-account-ids is empty', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': ''
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).not.toHaveBeenCalled();
expect(core.info).toHaveBeenCalledWith('Proceeding with IAM user credentials');
});

it('handles whitespace in allowed-account-ids input', async () => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput({
...mocks.IAM_USER_INPUTS,
'allowed-account-ids': ' 111111111111 , 222222222222 '
}));
mockedSTSClient.on(GetCallerIdentityCommand).resolves({ ...mocks.outputs.GET_CALLER_IDENTITY });
// biome-ignore lint/suspicious/noExplicitAny: any required to mock private method
vi.spyOn(CredentialsClient.prototype as any, 'loadCredentials').mockResolvedValue({
accessKeyId: 'MYAWSACCESSKEYID',
});

await run();
expect(core.setFailed).not.toHaveBeenCalled();
});
});

describe('HTTP Proxy Configuration', {}, () => {
beforeEach(() => {
vi.spyOn(core, 'getInput').mockImplementation(mocks.getInput(mocks.GH_OIDC_INPUTS));
Expand Down
Loading