openzeppelin_relayer/services/plugins/
socket.rs

//! This module is responsible for creating a socket connection to the relayer server.
//! It is used to send requests to the relayer server and to process the responses.
//! It also intercepts the logs, errors and return values.
//!
//! The socket connection is created using the `UnixListener`.
//!
//! 1. The server creates a socket using the `UnixListener`.
//! 2. The client stringifies each request payload and writes it to the socket as a new line.
//! 3. The server reads the requests from the socket and processes them.
//! 4. The server sends the responses back to the client in the same format, by writing a new
//!    line to the socket.
//! 5. When the client sends the socket shutdown signal, the server closes the socket connection.
//!
//! Example:
//! 1. Create a new socket connection using `/tmp/socket.sock`.
//! 2. The client sends a request (writes to `/tmp/socket.sock`):
//! ```json
//! {
//!   "request_id": "123",
//!   "relayer_id": "relayer1",
//!   "method": "sendTransaction",
//!   "payload": {
//!     "to": "0x1234567890123456789012345678901234567890",
//!     "value": "1000000000000000000"
//!   }
//! }
//! ```
//! 3. The server processes the request, calls the relayer API and sends back the response
//!    (writes to `/tmp/socket.sock`):
//! ```json
//! {
//!   "request_id": "123",
//!   "result": {
//!     "id": "123",
//!     "status": "success"
//!   }
//! }
//! ```
//! 4. The client reads the response (reads from `/tmp/socket.sock`):
//! ```json
//! {
//!   "request_id": "123",
//!   "result": {
//!     "id": "123",
//!     "status": "success"
//!   }
//! }
//! ```
//! 5. Once the client finishes execution, it sends a shutdown signal to the server.
//! 6. The server closes the socket connection.
//!
50
51use crate::jobs::JobProducerTrait;
52use crate::models::{
53    NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
54    TransactionRepoModel,
55};
56use crate::repositories::{
57    NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
58    TransactionCounterTrait, TransactionRepository,
59};
60use std::sync::Arc;
61use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
62use tokio::net::{UnixListener, UnixStream};
63use tokio::sync::oneshot;
64
65use super::{
66    relayer_api::{RelayerApiTrait, Request},
67    PluginError,
68};
69
70pub struct SocketService {
71    socket_path: String,
72    listener: UnixListener,
73}
74
75impl SocketService {
76    /// Creates a new socket service.
77    ///
78    /// # Arguments
79    ///
80    /// * `socket_path` - The path to the socket file.
81    pub fn new(socket_path: &str) -> Result<Self, PluginError> {
82        // Remove existing socket file if it exists
83        let _ = std::fs::remove_file(socket_path);
84
85        let listener =
86            UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
87
88        Ok(Self {
89            socket_path: socket_path.to_string(),
90            listener,
91        })
92    }
93
94    pub fn socket_path(&self) -> &str {
95        &self.socket_path
96    }
97
98    /// Listens for incoming connections and processes the requests.
99    ///
100    /// # Arguments
101    ///
102    /// * `shutdown_rx` - A receiver for the shutdown signal.
103    /// * `state` - The application state.
104    /// * `relayer_api` - The relayer API.
105    ///
106    /// # Returns
107    ///
108    /// A vector of traces.
109    #[allow(clippy::type_complexity)]
110    pub async fn listen<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(
111        self,
112        shutdown_rx: oneshot::Receiver<()>,
113        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
114        relayer_api: Arc<RA>,
115    ) -> Result<Vec<serde_json::Value>, PluginError>
116    where
117        RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR> + 'static + Send + Sync,
118        J: JobProducerTrait + Send + Sync + 'static,
119        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
120        TR: TransactionRepository
121            + Repository<TransactionRepoModel, String>
122            + Send
123            + Sync
124            + 'static,
125        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
126        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
127        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
128        TCR: TransactionCounterTrait + Send + Sync + 'static,
129        PR: PluginRepositoryTrait + Send + Sync + 'static,
130    {
131        let mut shutdown = shutdown_rx;
132
133        let mut traces = Vec::new();
134
135        loop {
136            let state = Arc::clone(&state);
137            let relayer_api = Arc::clone(&relayer_api);
138            tokio::select! {
139                Ok((stream, _)) = self.listener.accept() => {
140                    let result = tokio::spawn(Self::handle_connection::<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(stream, state, relayer_api))
141                        .await
142                        .map_err(|e| PluginError::SocketError(e.to_string()))?;
143
144                    match result {
145                        Ok(trace) => traces.extend(trace),
146                        Err(e) => return Err(e),
147                    }
148                }
149                _ = &mut shutdown => {
150                    println!("Shutdown signal received. Closing listener.");
151                    break;
152                }
153            }
154        }
155
156        Ok(traces)
157    }
158
159    /// Handles a new connection.
160    ///
161    /// # Arguments
162    ///
163    /// * `stream` - The stream to the client.
164    /// * `state` - The application state.
165    /// * `relayer_api` - The relayer API.
166    ///
167    /// # Returns
168    ///
169    /// A vector of traces.
170    #[allow(clippy::type_complexity)]
171    async fn handle_connection<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(
172        stream: UnixStream,
173        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
174        relayer_api: Arc<RA>,
175    ) -> Result<Vec<serde_json::Value>, PluginError>
176    where
177        RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR> + 'static + Send + Sync,
178        J: JobProducerTrait + 'static,
179        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
180        TR: TransactionRepository
181            + Repository<TransactionRepoModel, String>
182            + Send
183            + Sync
184            + 'static,
185        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
186        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
187        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
188        TCR: TransactionCounterTrait + Send + Sync + 'static,
189        PR: PluginRepositoryTrait + Send + Sync + 'static,
190    {
191        let (r, mut w) = stream.into_split();
192        let mut reader = BufReader::new(r).lines();
193        let mut traces = Vec::new();
194
195        while let Ok(Some(line)) = reader.next_line().await {
196            let trace: serde_json::Value = serde_json::from_str(&line)
197                .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {}", e)))?;
198            traces.push(trace);
199
200            let request: Request =
201                serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
202
203            let response = relayer_api.handle_request(request, &state).await;
204
205            let response_str = serde_json::to_string(&response)
206                .map_err(|e| PluginError::PluginError(e.to_string()))?
207                + "\n";
208
209            let _ = w.write_all(response_str.as_bytes()).await;
210        }
211
212        Ok(traces)
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };
    use actix_web::web;
    use std::time::Duration;
    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// Starting the listener and signalling shutdown right away must let the
    /// spawned `listen` task finish promptly without panicking.
    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        let dir = tempdir().unwrap();
        let sock = dir.path().join("test.sock");

        let relayer = MockRelayerApiTrait::default();
        let service = SocketService::new(sock.to_str().unwrap()).unwrap();

        let app_state = create_mock_app_state(None, None, None, None, None).await;
        let (stop_tx, stop_rx) = oneshot::channel();

        let server = tokio::spawn(async move {
            let shared_state = Arc::new(web::ThinData(app_state));
            service
                .listen(stop_rx, shared_state, Arc::new(relayer))
                .await
        });

        stop_tx.send(()).unwrap();

        let joined = timeout(Duration::from_millis(100), server).await;
        assert!(joined.is_ok(), "Listen handle timed out");
        let join_result = joined.unwrap();
        assert!(join_result.is_ok(), "Listen handle returned error");
    }

    /// A request line written by a client must be answered on the same
    /// connection and recorded as exactly one trace equal to the request JSON.
    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let dir = tempdir().unwrap();
        let sock = dir.path().join("test.sock");

        let mut relayer = MockRelayerApiTrait::default();
        relayer.expect_handle_request().returning(|_, _| {
            Box::pin(async move {
                Response {
                    request_id: "test".to_string(),
                    result: Some(serde_json::json!("test")),
                    error: None,
                }
            })
        });

        let service = SocketService::new(sock.to_str().unwrap()).unwrap();

        let app_state = create_mock_app_state(None, None, None, None, None).await;
        let (stop_tx, stop_rx) = oneshot::channel();

        let server = tokio::spawn(async move {
            let shared_state = Arc::new(web::ThinData(app_state));
            service
                .listen(stop_rx, shared_state, Arc::new(relayer))
                .await
        });

        // Short pause to let the spawned listener task start before connecting.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut client = UnixStream::connect(sock.to_str().unwrap())
            .await
            .unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
        };
        let request_line = format!("{}\n", serde_json::to_string(&request).unwrap());

        client.write_all(request_line.as_bytes()).await.unwrap();

        let mut reply_line = String::new();
        let mut client_reader = BufReader::new(&mut client);
        let read_outcome = timeout(
            Duration::from_millis(1000),
            client_reader.read_line(&mut reply_line),
        )
        .await;

        assert!(
            read_outcome.is_ok(),
            "Reading response timed out: {:?}",
            read_outcome
        );
        let bytes_received = read_outcome.unwrap().unwrap();
        assert!(bytes_received > 0, "No data received");

        // The listener only notices this signal once the current connection
        // ends, which happens after the client shuts down below.
        stop_tx.send(()).unwrap();

        let reply: Response = serde_json::from_str(&reply_line).unwrap();
        assert!(reply.error.is_none(), "Error should be none");
        assert!(reply.result.is_some(), "Result should be some");
        assert_eq!(
            reply.request_id, request.request_id,
            "Request id mismatch"
        );

        client.shutdown().await.unwrap();

        let traces = server.await.unwrap().unwrap();
        assert_eq!(traces.len(), 1);

        let expected: serde_json::Value = serde_json::from_str(&request_line).unwrap();
        let trace_text = serde_json::to_string(&traces[0]).unwrap();
        let actual: serde_json::Value = serde_json::from_str(&trace_text).unwrap();
        assert_eq!(expected, actual, "Request json mismatch with trace");
    }
}