1use crate::jobs::JobProducerTrait;
52use crate::models::{
53 NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
54 TransactionRepoModel,
55};
56use crate::repositories::{
57 NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
58 TransactionCounterTrait, TransactionRepository,
59};
60use std::sync::Arc;
61use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
62use tokio::net::{UnixListener, UnixStream};
63use tokio::sync::oneshot;
64
65use super::{
66 relayer_api::{RelayerApiTrait, Request},
67 PluginError,
68};
69
70pub struct SocketService {
71 socket_path: String,
72 listener: UnixListener,
73}
74
75impl SocketService {
76 pub fn new(socket_path: &str) -> Result<Self, PluginError> {
82 let _ = std::fs::remove_file(socket_path);
84
85 let listener =
86 UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
87
88 Ok(Self {
89 socket_path: socket_path.to_string(),
90 listener,
91 })
92 }
93
94 pub fn socket_path(&self) -> &str {
95 &self.socket_path
96 }
97
98 #[allow(clippy::type_complexity)]
110 pub async fn listen<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(
111 self,
112 shutdown_rx: oneshot::Receiver<()>,
113 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
114 relayer_api: Arc<RA>,
115 ) -> Result<Vec<serde_json::Value>, PluginError>
116 where
117 RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR> + 'static + Send + Sync,
118 J: JobProducerTrait + Send + Sync + 'static,
119 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
120 TR: TransactionRepository
121 + Repository<TransactionRepoModel, String>
122 + Send
123 + Sync
124 + 'static,
125 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
126 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
127 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
128 TCR: TransactionCounterTrait + Send + Sync + 'static,
129 PR: PluginRepositoryTrait + Send + Sync + 'static,
130 {
131 let mut shutdown = shutdown_rx;
132
133 let mut traces = Vec::new();
134
135 loop {
136 let state = Arc::clone(&state);
137 let relayer_api = Arc::clone(&relayer_api);
138 tokio::select! {
139 Ok((stream, _)) = self.listener.accept() => {
140 let result = tokio::spawn(Self::handle_connection::<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(stream, state, relayer_api))
141 .await
142 .map_err(|e| PluginError::SocketError(e.to_string()))?;
143
144 match result {
145 Ok(trace) => traces.extend(trace),
146 Err(e) => return Err(e),
147 }
148 }
149 _ = &mut shutdown => {
150 println!("Shutdown signal received. Closing listener.");
151 break;
152 }
153 }
154 }
155
156 Ok(traces)
157 }
158
159 #[allow(clippy::type_complexity)]
171 async fn handle_connection<RA, J, RR, TR, NR, NFR, SR, TCR, PR>(
172 stream: UnixStream,
173 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
174 relayer_api: Arc<RA>,
175 ) -> Result<Vec<serde_json::Value>, PluginError>
176 where
177 RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR> + 'static + Send + Sync,
178 J: JobProducerTrait + 'static,
179 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
180 TR: TransactionRepository
181 + Repository<TransactionRepoModel, String>
182 + Send
183 + Sync
184 + 'static,
185 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
186 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
187 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
188 TCR: TransactionCounterTrait + Send + Sync + 'static,
189 PR: PluginRepositoryTrait + Send + Sync + 'static,
190 {
191 let (r, mut w) = stream.into_split();
192 let mut reader = BufReader::new(r).lines();
193 let mut traces = Vec::new();
194
195 while let Ok(Some(line)) = reader.next_line().await {
196 let trace: serde_json::Value = serde_json::from_str(&line)
197 .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {}", e)))?;
198 traces.push(trace);
199
200 let request: Request =
201 serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
202
203 let response = relayer_api.handle_request(request, &state).await;
204
205 let response_str = serde_json::to_string(&response)
206 .map_err(|e| PluginError::PluginError(e.to_string()))?
207 + "\n";
208
209 let _ = w.write_all(response_str.as_bytes()).await;
210 }
211
212 Ok(traces)
213 }
214}
215
#[cfg(test)]
mod tests {
    use crate::{
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };
    use actix_web::web;
    use std::time::Duration;

    use super::*;

    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// With no client connected, a shutdown signal makes `listen` finish
    /// within the timeout and without the task failing.
    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        let dir = tempdir().unwrap();
        let sock_file = dir.path().join("test.sock");
        let sock_path = sock_file.to_str().unwrap();

        let relayer_api = MockRelayerApiTrait::default();
        let service = SocketService::new(sock_path).unwrap();

        let app_state = create_mock_app_state(None, None, None, None, None).await;
        let (stop_tx, stop_rx) = oneshot::channel();

        let server = tokio::spawn(async move {
            let shared_state = Arc::new(web::ThinData(app_state));
            let shared_api = Arc::new(relayer_api);
            service.listen(stop_rx, shared_state, shared_api).await
        });

        stop_tx.send(()).unwrap();

        let joined = timeout(Duration::from_millis(100), server).await;
        assert!(joined.is_ok(), "Listen handle timed out");
        assert!(joined.unwrap().is_ok(), "Listen handle returned error");
    }

    /// A request written by a client is answered with the mocked response and
    /// is reported back by `listen` as exactly one trace equal to the request.
    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let dir = tempdir().unwrap();
        let sock_file = dir.path().join("test.sock");
        let sock_path = sock_file.to_str().unwrap();

        // The mocked API answers every request with the same canned response.
        let mut relayer_api = MockRelayerApiTrait::default();
        relayer_api.expect_handle_request().returning(|_, _| {
            Box::pin(async move {
                Response {
                    request_id: "test".to_string(),
                    result: Some(serde_json::json!("test")),
                    error: None,
                }
            })
        });

        let service = SocketService::new(sock_path).unwrap();

        let app_state = create_mock_app_state(None, None, None, None, None).await;
        let (stop_tx, stop_rx) = oneshot::channel();

        let server = tokio::spawn(async move {
            let shared_state = Arc::new(web::ThinData(app_state));
            let shared_api = Arc::new(relayer_api);
            service.listen(stop_rx, shared_state, shared_api).await
        });

        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut client = UnixStream::connect(sock_path).await.unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
        };

        // One request per line: the newline terminates the message.
        let request_json = serde_json::to_string(&request).unwrap() + "\n";
        client.write_all(request_json.as_bytes()).await.unwrap();

        let mut reader = BufReader::new(&mut client);
        let mut response_line = String::new();
        let read_outcome = timeout(
            Duration::from_millis(1000),
            reader.read_line(&mut response_line),
        )
        .await;

        assert!(
            read_outcome.is_ok(),
            "Reading response timed out: {:?}",
            read_outcome
        );
        let received = read_outcome.unwrap().unwrap();
        assert!(received > 0, "No data received");
        stop_tx.send(()).unwrap();

        let response: Response = serde_json::from_str(&response_line).unwrap();

        assert!(response.error.is_none(), "Error should be none");
        assert!(response.result.is_some(), "Result should be some");
        assert_eq!(
            response.request_id, request.request_id,
            "Request id mismatch"
        );

        // Closing the client ends the connection handler, which lets the
        // listener loop observe the shutdown signal sent above.
        client.shutdown().await.unwrap();

        let traces = server.await.unwrap().unwrap();
        assert_eq!(traces.len(), 1);

        let expected: serde_json::Value = serde_json::from_str(&request_json).unwrap();
        let trace_json = serde_json::to_string(&traces[0]).unwrap();
        let actual: serde_json::Value = serde_json::from_str(&trace_json).unwrap();
        assert_eq!(expected, actual, "Request json mismatch with trace");
    }
}