Skip to content

Commit 2aecd0d

Browse files
committed
feat: initial implementation for users and markets WS improvements
1 parent 5b064c9 commit 2aecd0d

File tree

9 files changed

+846
-13
lines changed

9 files changed

+846
-13
lines changed

sdk/src/accounts/webSocketDriftClientAccountSubscriberV2.ts

Lines changed: 665 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import {
2+
DataAndSlot,
3+
AccountSubscriber,
4+
NotSubscribedError,
5+
UserAccountEvents,
6+
UserAccountSubscriber,
7+
ResubOpts,
8+
} from './types';
9+
import { Program } from '@coral-xyz/anchor';
10+
import StrictEventEmitter from 'strict-event-emitter-types';
11+
import { EventEmitter } from 'events';
12+
import { Commitment, Context, PublicKey } from '@solana/web3.js';
13+
import { WebSocketProgramAccountSubscriber } from './webSocketProgramAccountSubscriber';
14+
import { UserAccount } from '../types';
15+
16+
export class WebSocketProgramUserAccountSubscriber implements UserAccountSubscriber {
17+
isSubscribed: boolean;
18+
eventEmitter: StrictEventEmitter<EventEmitter, UserAccountEvents>;
19+
20+
private userAccountPublicKey: PublicKey;
21+
private program: Program;
22+
private programSubscriber: WebSocketProgramAccountSubscriber<UserAccount>;
23+
private userAccountAndSlot?: DataAndSlot<UserAccount>;
24+
25+
public constructor(
26+
program: Program,
27+
userAccountPublicKey: PublicKey,
28+
programSubscriber: WebSocketProgramAccountSubscriber<UserAccount>
29+
) {
30+
this.isSubscribed = false;
31+
this.program = program;
32+
this.userAccountPublicKey = userAccountPublicKey;
33+
this.eventEmitter = new EventEmitter();
34+
this.programSubscriber = programSubscriber;
35+
}
36+
37+
async subscribe(userAccount?: UserAccount): Promise<boolean> {
38+
if (this.isSubscribed) {
39+
return true;
40+
}
41+
42+
if (userAccount) {
43+
this.updateData(userAccount, 0);
44+
}
45+
46+
this.programSubscriber.onChange = (
47+
accountId: PublicKey,
48+
data: UserAccount,
49+
context: Context
50+
) => {
51+
if (accountId.equals(this.userAccountPublicKey)) {
52+
this.updateData(data, context.slot);
53+
this.eventEmitter.emit('userAccountUpdate', data);
54+
this.eventEmitter.emit('update');
55+
}
56+
};
57+
58+
this.isSubscribed = true;
59+
return true;
60+
}
61+
62+
async fetch(): Promise<void> {
63+
if (!this.isSubscribed) {
64+
throw new NotSubscribedError(
65+
'Must subscribe before fetching account updates'
66+
);
67+
}
68+
69+
const account = await this.program.account.user.fetch(
70+
this.userAccountPublicKey
71+
);
72+
this.updateData(account as UserAccount, 0);
73+
}
74+
75+
updateData(userAccount: UserAccount, slot: number): void {
76+
this.userAccountAndSlot = {
77+
data: userAccount,
78+
slot,
79+
};
80+
}
81+
82+
async unsubscribe(): Promise<void> {
83+
this.isSubscribed = false;
84+
}
85+
86+
getUserAccountAndSlot(): DataAndSlot<UserAccount> {
87+
if (!this.userAccountAndSlot) {
88+
throw new NotSubscribedError(
89+
'Must subscribe before getting user account data'
90+
);
91+
}
92+
return this.userAccountAndSlot;
93+
}
94+
}

sdk/src/driftClient.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ import { getOracleId } from './oracles/oracleId';
193193
import { SignedMsgOrderParams } from './types';
194194
import { sha256 } from '@noble/hashes/sha256';
195195
import { getOracleConfidenceFromMMOracleData } from './oracles/utils';
196+
import { WebSocketDriftClientAccountSubscriberV2 } from './accounts/webSocketDriftClientAccountSubscriberV2';
196197

197198
type RemainingAccountParams = {
198199
userAccounts: UserAccount[];
@@ -370,6 +371,7 @@ export class DriftClient {
370371
resubTimeoutMs: config.accountSubscription?.resubTimeoutMs,
371372
logResubMessages: config.accountSubscription?.logResubMessages,
372373
commitment: config.accountSubscription?.commitment,
374+
programUserAccountSubscriber: config.accountSubscription?.programUserAccountSubscriber,
373375
};
374376
this.userStatsAccountSubscriptionConfig = {
375377
type: 'websocket',
@@ -435,7 +437,7 @@ export class DriftClient {
435437
}
436438
);
437439
} else {
438-
this.accountSubscriber = new WebSocketDriftClientAccountSubscriber(
440+
this.accountSubscriber = new WebSocketDriftClientAccountSubscriberV2(
439441
this.program,
440442
config.perpMarketIndexes ?? [],
441443
config.spotMarketIndexes ?? [],
@@ -607,7 +609,7 @@ export class DriftClient {
607609
public getSpotMarketAccount(
608610
marketIndex: number
609611
): SpotMarketAccount | undefined {
610-
return this.accountSubscriber.getSpotMarketAccountAndSlot(marketIndex).data;
612+
return this.accountSubscriber.getSpotMarketAccountAndSlot(marketIndex)?.data;
611613
}
612614

613615
/**
@@ -618,7 +620,7 @@ export class DriftClient {
618620
marketIndex: number
619621
): Promise<SpotMarketAccount | undefined> {
620622
await this.accountSubscriber.fetch();
621-
return this.accountSubscriber.getSpotMarketAccountAndSlot(marketIndex).data;
623+
return this.accountSubscriber.getSpotMarketAccountAndSlot(marketIndex)?.data;
622624
}
623625

624626
public getSpotMarketAccounts(): SpotMarketAccount[] {
@@ -927,6 +929,8 @@ export class DriftClient {
927929
authority?: PublicKey,
928930
userAccount?: UserAccount
929931
): Promise<boolean> {
932+
933+
930934
authority = authority ?? this.authority;
931935
const userKey = this.getUserMapKey(subAccountId, authority);
932936

@@ -954,6 +958,7 @@ export class DriftClient {
954958
* Adds and subscribes to users based on params set by the constructor or by updateWallet.
955959
*/
956960
public async addAndSubscribeToUsers(authority?: PublicKey): Promise<boolean> {
961+
console.log('adding and subscribing to users', this.users.size);
957962
// save the rpc calls if driftclient is initialized without a real wallet
958963
if (this.skipLoadUsers) return true;
959964

sdk/src/driftClientConfig.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ import {
55
PublicKey,
66
TransactionVersion,
77
} from '@solana/web3.js';
8-
import { IWallet, TxParams } from './types';
8+
import { IWallet, TxParams, UserAccount } from './types';
99
import { OracleInfo } from './oracles/types';
1010
import { BulkAccountLoader } from './accounts/bulkAccountLoader';
1111
import { DriftEnv } from './config';
1212
import { TxSender } from './tx/types';
1313
import { TxHandler, TxHandlerConfig } from './tx/txHandler';
1414
import { DelistedMarketSetting, GrpcConfigs } from './accounts/types';
1515
import { Coder } from '@coral-xyz/anchor';
16+
import { WebSocketProgramAccountSubscriber } from './accounts/webSocketProgramAccountSubscriber';
1617

1718
export type DriftClientConfig = {
1819
connection: Connection;
@@ -57,6 +58,7 @@ export type DriftClientSubscriptionConfig =
5758
resubTimeoutMs?: number;
5859
logResubMessages?: boolean;
5960
commitment?: Commitment;
61+
programUserAccountSubscriber?: WebSocketProgramAccountSubscriber<UserAccount>;
6062
}
6163
| {
6264
type: 'polling';

sdk/src/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ export * from './accounts/fetch';
1212
export * from './accounts/webSocketDriftClientAccountSubscriber';
1313
export * from './accounts/webSocketInsuranceFundStakeAccountSubscriber';
1414
export * from './accounts/webSocketHighLeverageModeConfigAccountSubscriber';
15+
export { WebSocketProgramAccountSubscriber } from './accounts/webSocketProgramAccountSubscriber';
16+
export { WebSocketProgramUserAccountSubscriber } from './accounts/websocketProgramUserAccountSubscriber';
1517
export * from './accounts/bulkAccountLoader';
1618
export * from './accounts/bulkUserSubscription';
1719
export * from './accounts/bulkUserStatsSubscription';
@@ -133,5 +135,6 @@ export * from './clock/clockSubscriber';
133135
export * from './math/userStatus';
134136
export * from './constants/txConstants';
135137
export * from './indicative-quotes/indicativeQuotesSender';
138+
export { default as driftIDL } from './idl/drift.json';
136139

137140
export { BN, PublicKey, pyth };

sdk/src/memcmp.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,20 @@ export function getSignedMsgUserOrdersFilter(): MemcmpFilter {
112112
},
113113
};
114114
}
115+
116+
export function getPerpMarketAccountsFilter(): MemcmpFilter {
117+
return {
118+
memcmp: {
119+
offset: 0,
120+
bytes: bs58.encode(BorshAccountsCoder.accountDiscriminator('PerpMarket')),
121+
},
122+
};
123+
}
124+
export function getSpotMarketAccountsFilter(): MemcmpFilter {
125+
return {
126+
memcmp: {
127+
offset: 0,
128+
bytes: bs58.encode(BorshAccountsCoder.accountDiscriminator('SpotMarket')),
129+
},
130+
};
131+
}

sdk/src/oracles/oracleId.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,43 @@ export function getOracleSourceNum(source: OracleSource): number {
2424
throw new Error('Invalid oracle source');
2525
}
2626

27+
export function getOracleSourceFromNum(sourceNum: number): OracleSource {
28+
if (sourceNum === OracleSourceNum.PYTH) return 'pyth';
29+
if (sourceNum === OracleSourceNum.PYTH_1K) return 'pyth1K';
30+
if (sourceNum === OracleSourceNum.PYTH_1M) return 'pyth1M';
31+
if (sourceNum === OracleSourceNum.PYTH_PULL) return 'pythPull';
32+
if (sourceNum === OracleSourceNum.PYTH_1K_PULL) return 'pyth1KPull';
33+
if (sourceNum === OracleSourceNum.PYTH_1M_PULL) return 'pyth1MPull';
34+
if (sourceNum === OracleSourceNum.SWITCHBOARD) return 'switchboard';
35+
if (sourceNum === OracleSourceNum.QUOTE_ASSET) return 'quoteAsset';
36+
if (sourceNum === OracleSourceNum.PYTH_STABLE_COIN) return 'pythStableCoin';
37+
if (sourceNum === OracleSourceNum.PYTH_STABLE_COIN_PULL)
38+
return 'pythStableCoinPull';
39+
if (sourceNum === OracleSourceNum.PRELAUNCH) return 'prelaunch';
40+
if (sourceNum === OracleSourceNum.SWITCHBOARD_ON_DEMAND)
41+
return 'switchboardOnDemand';
42+
if (sourceNum === OracleSourceNum.PYTH_LAZER) return 'pythLazer';
43+
if (sourceNum === OracleSourceNum.PYTH_LAZER_1K) return 'pythLazer1K';
44+
if (sourceNum === OracleSourceNum.PYTH_LAZER_1M) return 'pythLazer1M';
45+
if (sourceNum === OracleSourceNum.PYTH_LAZER_STABLE_COIN)
46+
return 'pythLazerStableCoin';
47+
throw new Error('Invalid oracle source');
48+
}
49+
2750
export function getOracleId(
2851
publicKey: PublicKey,
2952
source: OracleSource
3053
): string {
3154
return `${publicKey.toBase58()}-${getOracleSourceNum(source)}`;
3255
}
56+
57+
export function getPublicKeyAndSourceFromOracleId(oracleId: string): {
58+
publicKey: PublicKey;
59+
source: OracleSource;
60+
} {
61+
const [publicKey, source] = oracleId.split('-');
62+
return {
63+
publicKey: new PublicKey(publicKey),
64+
source: getOracleSourceFromNum(parseInt(source)),
65+
};
66+
}

sdk/src/user.ts

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ import {
7575
} from './types';
7676
import { standardizeBaseAssetAmount } from './math/orders';
7777
import { UserStats } from './userStats';
78+
import { WebSocketProgramUserAccountSubscriber } from './accounts/websocketProgramUserAccountSubscriber';
7879
import {
7980
calculateAssetWeight,
8081
calculateLiabilityWeight,
@@ -149,15 +150,23 @@ export class User {
149150
}
150151
);
151152
} else {
152-
this.accountSubscriber = new WebSocketUserAccountSubscriber(
153-
config.driftClient.program,
154-
config.userAccountPublicKey,
155-
{
156-
resubTimeoutMs: config.accountSubscription?.resubTimeoutMs,
157-
logResubMessages: config.accountSubscription?.logResubMessages,
158-
},
159-
config.accountSubscription?.commitment
160-
);
153+
if (config.accountSubscription?.type === 'websocket' && config.accountSubscription?.programUserAccountSubscriber) {
154+
this.accountSubscriber = new WebSocketProgramUserAccountSubscriber(
155+
config.driftClient.program,
156+
config.userAccountPublicKey,
157+
config.accountSubscription.programUserAccountSubscriber
158+
);
159+
} else {
160+
this.accountSubscriber = new WebSocketUserAccountSubscriber(
161+
config.driftClient.program,
162+
config.userAccountPublicKey,
163+
{
164+
resubTimeoutMs: config.accountSubscription?.resubTimeoutMs,
165+
logResubMessages: config.accountSubscription?.logResubMessages,
166+
},
167+
config.accountSubscription?.commitment
168+
);
169+
}
161170
}
162171
this.eventEmitter = this.accountSubscriber.eventEmitter;
163172
}

sdk/src/userConfig.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ import { DriftClient } from './driftClient';
22
import { Commitment, PublicKey } from '@solana/web3.js';
33
import { BulkAccountLoader } from './accounts/bulkAccountLoader';
44
import { GrpcConfigs, UserAccountSubscriber } from './accounts/types';
5+
import { WebSocketProgramAccountSubscriber } from './accounts/webSocketProgramAccountSubscriber';
6+
import { UserAccount } from '@drift-labs/sdk';
7+
import { WebSocketProgramUserAccountSubscriber } from './accounts/websocketProgramUserAccountSubscriber';
58

69
export type UserConfig = {
710
accountSubscription?: UserSubscriptionConfig;
@@ -21,6 +24,7 @@ export type UserSubscriptionConfig =
2124
resubTimeoutMs?: number;
2225
logResubMessages?: boolean;
2326
commitment?: Commitment;
27+
programUserAccountSubscriber?: WebSocketProgramAccountSubscriber<UserAccount>;
2428
}
2529
| {
2630
type: 'polling';

0 commit comments

Comments
 (0)