Skip to content

Commit 14d13d9

Browse files
committed
Add support for hardware.list endpoint
1 parent b14eeaa commit 14d13d9

File tree

4 files changed

+76
-24
lines changed

4 files changed

+76
-24
lines changed

index.d.ts

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ declare module 'replicate' {
1414
models?: Model[];
1515
}
1616

17+
export interface Hardware {
18+
sku: string;
19+
name: string
20+
}
21+
1722
export interface Model {
1823
url: string;
1924
owner: string;
@@ -115,6 +120,25 @@ declare module 'replicate' {
115120
get(collection_slug: string): Promise<Collection>;
116121
};
117122

123+
deployments: {
124+
predictions: {
125+
create(
126+
deployment_owner: string,
127+
deployment_name: string,
128+
options: {
129+
input: object;
130+
stream?: boolean;
131+
webhook?: string;
132+
webhook_events_filter?: WebhookEventType[];
133+
}
134+
): Promise<Prediction>;
135+
};
136+
};
137+
138+
hardware: {
139+
list(): Promise<Hardware[]>
140+
}
141+
118142
models: {
119143
get(model_owner: string, model_name: string): Promise<Model>;
120144
list(): Promise<Page<Model>>;
@@ -157,20 +181,5 @@ declare module 'replicate' {
157181
cancel(training_id: string): Promise<Training>;
158182
list(): Promise<Page<Training>>;
159183
};
160-
161-
deployments: {
162-
predictions: {
163-
create(
164-
deployment_owner: string,
165-
deployment_name: string,
166-
options: {
167-
input: object;
168-
stream?: boolean;
169-
webhook?: string;
170-
webhook_events_filter?: WebhookEventType[];
171-
}
172-
): Promise<Prediction>;
173-
};
174-
};
175184
}
176185
}

index.js

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util');
33

44
const collections = require('./lib/collections');
55
const deployments = require('./lib/deployments');
6+
const hardware = require('./lib/hardware');
67
const models = require('./lib/models');
78
const predictions = require('./lib/predictions');
89
const trainings = require('./lib/trainings');
@@ -49,6 +50,16 @@ class Replicate {
4950
get: collections.get.bind(this),
5051
};
5152

53+
this.deployments = {
54+
predictions: {
55+
create: deployments.predictions.create.bind(this),
56+
}
57+
};
58+
59+
this.hardware = {
60+
list: hardware.list.bind(this),
61+
};
62+
5263
this.models = {
5364
get: models.get.bind(this),
5465
list: models.list.bind(this),
@@ -71,12 +82,6 @@ class Replicate {
7182
cancel: trainings.cancel.bind(this),
7283
list: trainings.list.bind(this),
7384
};
74-
75-
this.deployments = {
76-
predictions: {
77-
create: deployments.predictions.create.bind(this),
78-
}
79-
};
8085
}
8186

8287
/**

index.test.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,20 @@ describe('Replicate client', () => {
136136
nock(BASE_URL)
137137
.get('/models')
138138
.reply(200, {
139-
results: [{ url: 'https://replicate.com/some-user/model-1' }],
139+
results: [ { url: 'https://replicate.com/some-user/model-1' } ],
140140
next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
141141
})
142142
.get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw')
143143
.reply(200, {
144-
results: [{ url: 'https://replicate.com/some-user/model-2' }],
144+
results: [ { url: 'https://replicate.com/some-user/model-2' } ],
145145
next: null,
146146
});
147147

148148
const results: Model[] = [];
149149
for await (const batch of client.paginate(client.models.list)) {
150150
results.push(...batch);
151151
}
152-
expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]);
152+
expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]);
153153

154154
// Add more tests for error handling, edge cases, etc.
155155
});
@@ -662,6 +662,28 @@ describe('Replicate client', () => {
662662
// Add more tests for error handling, edge cases, etc.
663663
});
664664

665+
describe('hardware.list', () => {
666+
test('Calls the correct API route', async () => {
667+
nock(BASE_URL)
668+
.get('/hardware')
669+
.reply(200, [
670+
{ name: "CPU", sku: "cpu" },
671+
{ name: "Nvidia T4 GPU", sku: "gpu-t4" },
672+
{ name: "Nvidia A40 GPU", sku: "gpu-a40-small" },
673+
{ name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" },
674+
{ name: "Nvidia A40 (Large) GPU (8x)", sku: "gpu-a40-large-8x" },
675+
{ name: "Nvidia A100 (40GB) GPU", sku: "gpu-a100-small" },
676+
{ name: "Nvidia A100 (80GB) GPU", sku: "gpu-a100-large" },
677+
]);
678+
679+
const hardware = await client.hardware.list();
680+
expect(hardware.length).toBe(7);
681+
expect(hardware[ 0 ].name).toBe('CPU');
682+
expect(hardware[ 0 ].sku).toBe('cpu');
683+
});
684+
// Add more tests for error handling, edge cases, etc.
685+
});
686+
665687
describe('run', () => {
666688
test('Calls the correct API routes', async () => {
667689
let firstPollingRequest = true;

lib/hardware.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/**
2+
* List hardware
3+
*
4+
* @returns {Promise<object[]>} Resolves with the array of hardware
5+
*/
6+
async function listHardware() {
7+
const response = await this.request('/hardware', {
8+
method: 'GET',
9+
});
10+
11+
return response.json();
12+
}
13+
14+
module.exports = {
15+
list: listHardware,
16+
};

0 commit comments

Comments
 (0)