Skip to content

Commit 4cc2e7f

Browse files
added post client
1 parent b4ab3f5 commit 4cc2e7f

File tree

1 file changed

+314
-0
lines changed

1 file changed

+314
-0
lines changed

src/ws/ws_post_client.rs

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
use crate::{
2+
exchange::Actions,
3+
helpers::next_nonce,
4+
signature::sign_l1_action,
5+
BaseUrl, BulkCancelCloid, BulkOrder, Error, ExchangeResponseStatus,
6+
};
7+
use ethers::{
8+
signers::LocalWallet,
9+
types::{H160, H256},
10+
};
11+
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
12+
use log::{debug, error};
13+
use serde::{Deserialize, Serialize};
14+
use std::{
15+
collections::HashMap,
16+
sync::{
17+
atomic::{AtomicU64, Ordering},
18+
Arc,
19+
},
20+
time::Duration,
21+
};
22+
use tokio::{
23+
net::TcpStream,
24+
spawn,
25+
sync::{Mutex, oneshot},
26+
time::timeout,
27+
};
28+
use tokio_tungstenite::{
29+
connect_async_with_config,
30+
tungstenite::protocol::{self, WebSocketConfig},
31+
MaybeTlsStream, WebSocketStream,
32+
};
33+
34+
#[derive(Serialize, Debug)]
35+
#[serde(rename_all = "camelCase")]
36+
struct WsPostRequest<T> {
37+
method: String,
38+
id: u64,
39+
request: WsRequestData<T>,
40+
}
41+
42+
#[derive(Serialize, Debug)]
43+
#[serde(rename_all = "camelCase")]
44+
struct WsRequestData<T> {
45+
#[serde(rename = "type")]
46+
request_type: String,
47+
payload: T,
48+
}
49+
50+
#[derive(Deserialize, Debug)]
51+
#[serde(rename_all = "camelCase")]
52+
struct WsPostResponse {
53+
channel: String,
54+
data: WsResponseData,
55+
}
56+
57+
#[derive(Deserialize, Debug)]
58+
#[serde(rename_all = "camelCase")]
59+
struct WsResponseData {
60+
id: u64,
61+
response: WsResponse,
62+
}
63+
64+
#[derive(Deserialize, Debug)]
65+
#[serde(tag = "type")]
66+
#[serde(rename_all = "camelCase")]
67+
enum WsResponse {
68+
Action { payload: ExchangeResponseStatus },
69+
Error { payload: String },
70+
}
71+
72+
#[derive(Serialize, Debug)]
73+
#[serde(rename_all = "camelCase")]
74+
struct WsExchangePayload {
75+
action: serde_json::Value,
76+
signature: WsSignature,
77+
nonce: u64,
78+
vault_address: Option<H160>,
79+
}
80+
81+
#[derive(Serialize, Debug)]
82+
struct WsSignature {
83+
r: String,
84+
s: String,
85+
v: u8,
86+
}
87+
88+
type ResponseSender = oneshot::Sender<Result<ExchangeResponseStatus, Error>>;
89+
90+
#[derive(Debug)]
91+
pub struct WsPostClient {
92+
writer: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, protocol::Message>>>,
93+
pending_requests: Arc<Mutex<HashMap<u64, ResponseSender>>>,
94+
request_id_counter: AtomicU64,
95+
}
96+
97+
impl WsPostClient {
98+
pub async fn new(base_url: BaseUrl) -> Result<Self, Error> {
99+
let url = match base_url {
100+
BaseUrl::Mainnet => "wss://api.hyperliquid.xyz/ws",
101+
BaseUrl::Testnet => "wss://api.hyperliquid-testnet.xyz/ws",
102+
BaseUrl::Localhost => "ws://localhost:3001/ws",
103+
};
104+
105+
let (ws_stream, _) = connect_async_with_config(
106+
url,
107+
Some(create_optimized_websocket_config()),
108+
true,
109+
)
110+
.await
111+
.map_err(|e| Error::Websocket(e.to_string()))?;
112+
113+
let (writer, mut reader) = ws_stream.split();
114+
let writer = Arc::new(Mutex::new(writer));
115+
let pending_requests: Arc<Mutex<HashMap<u64, ResponseSender>>> =
116+
Arc::new(Mutex::new(HashMap::new()));
117+
118+
// Spawn reader task to handle responses
119+
let pending_requests_clone = pending_requests.clone();
120+
spawn(async move {
121+
while let Some(msg) = reader.next().await {
122+
match msg {
123+
Ok(protocol::Message::Text(text)) => {
124+
if let Err(e) = Self::handle_response(text.to_string(), &pending_requests_clone).await {
125+
error!("Error handling websocket response: {}", e);
126+
}
127+
}
128+
Ok(protocol::Message::Pong(_)) => {
129+
debug!("Received pong");
130+
}
131+
Ok(_) => {
132+
debug!("Received non-text message");
133+
}
134+
Err(e) => {
135+
error!("WebSocket error: {}", e);
136+
// Notify all pending requests about the error
137+
let mut pending = pending_requests_clone.lock().await;
138+
for (_, sender) in pending.drain() {
139+
let _ = sender.send(Err(Error::Websocket(e.to_string())));
140+
}
141+
break;
142+
}
143+
}
144+
}
145+
});
146+
147+
Ok(Self {
148+
writer,
149+
pending_requests,
150+
request_id_counter: AtomicU64::new(1),
151+
})
152+
}
153+
154+
async fn handle_response(
155+
text: String,
156+
pending_requests: &Arc<Mutex<HashMap<u64, ResponseSender>>>,
157+
) -> Result<(), Error> {
158+
// First try to parse as a proper response
159+
if let Ok(response) = serde_json::from_str::<WsPostResponse>(&text) {
160+
if response.channel == "post" {
161+
let mut pending = pending_requests.lock().await;
162+
if let Some(sender) = pending.remove(&response.data.id) {
163+
let result = match response.data.response {
164+
WsResponse::Action { payload } => Ok(payload),
165+
WsResponse::Error { payload } => Err(Error::GenericRequest(payload)),
166+
};
167+
let _ = sender.send(result);
168+
}
169+
}
170+
return Ok(());
171+
}
172+
173+
// If that fails, it might be an error string - log it
174+
error!("Received non-standard response: {}", text);
175+
176+
// For now, we can't correlate this to a specific request, so we'll ignore it
177+
// In a production system, you might want to handle this differently
178+
Ok(())
179+
}
180+
181+
async fn send_request<T: Serialize>(
182+
&self,
183+
payload: T,
184+
timeout_duration: Duration,
185+
) -> Result<ExchangeResponseStatus, Error> {
186+
let request_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
187+
let (tx, rx) = oneshot::channel();
188+
189+
// Store the response sender
190+
{
191+
let mut pending = self.pending_requests.lock().await;
192+
pending.insert(request_id, tx);
193+
}
194+
195+
// Create and send the request
196+
let request = WsPostRequest {
197+
method: "post".to_string(),
198+
id: request_id,
199+
request: WsRequestData {
200+
request_type: "action".to_string(),
201+
payload,
202+
},
203+
};
204+
205+
let message_text =
206+
serde_json::to_string(&request).map_err(|e| Error::JsonParse(e.to_string()))?;
207+
208+
{
209+
let mut writer = self.writer.lock().await;
210+
writer
211+
.send(protocol::Message::Text(message_text.into()))
212+
.await
213+
.map_err(|e| Error::Websocket(e.to_string()))?;
214+
}
215+
216+
// Wait for response with timeout
217+
match timeout(timeout_duration, rx).await {
218+
Ok(Ok(result)) => result,
219+
Ok(Err(_)) => Err(Error::GenericRequest("Response channel closed".to_string())),
220+
Err(_) => {
221+
// Remove the pending request on timeout
222+
let mut pending = self.pending_requests.lock().await;
223+
pending.remove(&request_id);
224+
Err(Error::GenericRequest("Request timeout".to_string()))
225+
}
226+
}
227+
}
228+
229+
pub async fn bulk_order(
230+
&self,
231+
action: BulkOrder,
232+
wallet: &LocalWallet,
233+
is_mainnet: bool,
234+
vault_address: Option<H160>,
235+
) -> Result<ExchangeResponseStatus, Error> {
236+
let timestamp = next_nonce();
237+
let full_action = Actions::Order(action);
238+
let connection_id = self.calculate_action_hash(&full_action, timestamp, vault_address)?;
239+
let signature = sign_l1_action(wallet, connection_id, is_mainnet)?;
240+
241+
let r = format!("0x{:x}", signature.r);
242+
let s = format!("0x{:x}", signature.s);
243+
let v = signature.v as u8;
244+
245+
let payload = WsExchangePayload {
246+
action: serde_json::to_value(&full_action)
247+
.map_err(|e| Error::JsonParse(e.to_string()))?,
248+
signature: WsSignature { r, s, v },
249+
nonce: timestamp,
250+
vault_address,
251+
};
252+
253+
self.send_request(payload, Duration::from_secs(15)).await
254+
}
255+
256+
pub async fn bulk_cancel_by_cloid(
257+
&self,
258+
action: BulkCancelCloid,
259+
wallet: &LocalWallet,
260+
is_mainnet: bool,
261+
vault_address: Option<H160>,
262+
) -> Result<ExchangeResponseStatus, Error> {
263+
let timestamp = next_nonce();
264+
let full_action = Actions::CancelByCloid(action);
265+
let connection_id = self.calculate_action_hash(&full_action, timestamp, vault_address)?;
266+
let signature = sign_l1_action(wallet, connection_id, is_mainnet)?;
267+
268+
let r = format!("0x{:x}", signature.r);
269+
let s = format!("0x{:x}", signature.s);
270+
let v = signature.v as u8;
271+
272+
let payload = WsExchangePayload {
273+
action: serde_json::to_value(&full_action)
274+
.map_err(|e| Error::JsonParse(e.to_string()))?,
275+
signature: WsSignature { r, s, v },
276+
nonce: timestamp,
277+
vault_address,
278+
};
279+
280+
self.send_request(payload, Duration::from_secs(15)).await
281+
}
282+
283+
fn calculate_action_hash<T: Serialize>(
284+
&self,
285+
action: &T,
286+
timestamp: u64,
287+
vault_address: Option<H160>,
288+
) -> Result<H256, Error> {
289+
let mut bytes =
290+
rmp_serde::to_vec_named(action).map_err(|e| Error::RmpParse(e.to_string()))?;
291+
bytes.extend(timestamp.to_be_bytes());
292+
if let Some(vault_address) = vault_address {
293+
bytes.push(1);
294+
bytes.extend(vault_address.to_fixed_bytes());
295+
} else {
296+
bytes.push(0);
297+
}
298+
Ok(H256(ethers::utils::keccak256(bytes)))
299+
}
300+
}
301+
302+
/// Create optimized WebSocket configuration for low-latency trading
303+
fn create_optimized_websocket_config() -> WebSocketConfig {
304+
let mut config = WebSocketConfig::default();
305+
306+
config.read_buffer_size = 64 * 1024;
307+
config.write_buffer_size = 0;
308+
config.max_write_buffer_size = 512 * 1024;
309+
config.max_message_size = None;
310+
config.max_frame_size = Some(128 * 1024);
311+
config.accept_unmasked_frames = false;
312+
313+
config
314+
}

0 commit comments

Comments
 (0)