Skip to content

Commit ef46043

Browse files
nasonnsarrazin
andauthored
Dynamic credential resolution for AWS endpoints (#1419)
* Dynamic credential resolution for AWS endpoints If AWS credentials are provided, they will continue to be used. With this change, they become optional parameters and the AWS SDK will attempt to use the default resolution chain if they are not provided. * cleanup deps * restore optional aws region in endpointAwsParametersSchema --------- Co-authored-by: Nathan Sarrazin <[email protected]>
1 parent 403e040 commit ef46043

File tree

3 files changed

+57
-10
lines changed

3 files changed

+57
-10
lines changed

package-lock.json

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"@resvg/resvg-js": "^2.6.2",
7070
"autoprefixer": "^10.4.14",
7171
"aws4": "^1.13.0",
72+
"aws-sigv4-fetch": "^4.0.1",
7273
"browser-image-resizer": "^2.4.1",
7374
"date-fns": "^2.29.3",
7475
"dotenv": "^16.0.3",

src/lib/server/endpoints/aws/endpointAws.ts

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,20 @@ export const endpointAwsParametersSchema = z.object({
88
model: z.any(),
99
type: z.literal("aws"),
1010
url: z.string().url(),
11-
accessKey: z.string().min(1),
12-
secretKey: z.string().min(1),
11+
accessKey: z
12+
.string({
13+
description:
14+
"An AWS Access Key ID. If not provided, the default AWS identity resolution will be used",
15+
})
16+
.min(1)
17+
.optional(),
18+
secretKey: z
19+
.string({
20+
description:
21+
"An AWS Access Key Secret. If not provided, the default AWS identity resolution will be used",
22+
})
23+
.min(1)
24+
.optional(),
1325
sessionToken: z.string().optional(),
1426
service: z.union([z.literal("sagemaker"), z.literal("lambda")]).default("sagemaker"),
1527
region: z.string().optional(),
@@ -18,22 +30,23 @@ export const endpointAwsParametersSchema = z.object({
1830
export async function endpointAws(
1931
input: z.input<typeof endpointAwsParametersSchema>
2032
): Promise<Endpoint> {
21-
let AwsClient;
33+
let createSignedFetcher;
2234
try {
23-
AwsClient = (await import("aws4fetch")).AwsClient;
35+
createSignedFetcher = (await import("aws-sigv4-fetch")).createSignedFetcher;
2436
} catch (e) {
25-
throw new Error("Failed to import aws4fetch");
37+
throw new Error("Failed to import aws-sigv4-fetch");
2638
}
2739

2840
const { url, accessKey, secretKey, sessionToken, model, region, service } =
2941
endpointAwsParametersSchema.parse(input);
3042

31-
const aws = new AwsClient({
32-
accessKeyId: accessKey,
33-
secretAccessKey: secretKey,
34-
sessionToken,
43+
const signedFetch = createSignedFetcher({
3544
service,
3645
region,
46+
credentials:
47+
accessKey && secretKey
48+
? { accessKeyId: accessKey, secretAccessKey: secretKey, sessionToken }
49+
: undefined,
3750
});
3851

3952
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
@@ -52,7 +65,7 @@ export async function endpointAws(
5265
},
5366
{
5467
use_cache: false,
55-
fetch: aws.fetch.bind(aws) as typeof fetch,
68+
fetch: signedFetch,
5669
}
5770
);
5871
};

0 commit comments

Comments
 (0)