// openzeppelin_relayer/services/plugins/runner.rs

//! Orchestrates plugin execution.
//!
//! 1. Binds a socket and starts listening on it (backed by `RelayerApi`); the plugin script connects to it - socket.rs
//! 2. Executes the plugin script, bounded by a timeout - script_executor.rs
//! 3. Sends the shutdown signal to the socket listener - socket.rs
//! 4. Waits for the socket listener to finish and collects its traces - socket.rs
//! 5. Returns the script output together with those traces - script_executor.rs
//!
9use std::{sync::Arc, time::Duration};
10
11use crate::services::plugins::{RelayerApi, ScriptExecutor, ScriptResult, SocketService};
12use crate::{
13    jobs::JobProducerTrait,
14    models::{
15        NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel,
16        ThinDataAppState, TransactionRepoModel,
17    },
18    repositories::{
19        NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
20        TransactionCounterTrait, TransactionRepository,
21    },
22};
23
24use super::PluginError;
25use async_trait::async_trait;
26use tokio::{sync::oneshot, time::timeout};
27
28#[cfg(test)]
29use mockall::automock;
30
31#[cfg_attr(test, automock)]
32#[async_trait]
33pub trait PluginRunnerTrait {
34    #[allow(clippy::type_complexity)]
35    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR>(
36        &self,
37        socket_path: &str,
38        script_path: String,
39        timeout_duration: Duration,
40        script_params: String,
41        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
42    ) -> Result<ScriptResult, PluginError>
43    where
44        J: JobProducerTrait + Send + Sync + 'static,
45        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
46        TR: TransactionRepository
47            + Repository<TransactionRepoModel, String>
48            + Send
49            + Sync
50            + 'static,
51        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
52        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
53        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
54        TCR: TransactionCounterTrait + Send + Sync + 'static,
55        PR: PluginRepositoryTrait + Send + Sync + 'static;
56}
57
/// Default, stateless [`PluginRunnerTrait`] implementation.
///
/// Everything a run needs (socket path, script, timeout, parameters, app state)
/// is passed to `run`, so the type carries no data and is trivially copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct PluginRunner;
60
61#[allow(clippy::type_complexity)]
62impl PluginRunner {
63    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR>(
64        &self,
65        socket_path: &str,
66        script_path: String,
67        timeout_duration: Duration,
68        script_params: String,
69        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
70    ) -> Result<ScriptResult, PluginError>
71    where
72        J: JobProducerTrait + 'static,
73        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
74        TR: TransactionRepository
75            + Repository<TransactionRepoModel, String>
76            + Send
77            + Sync
78            + 'static,
79        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
80        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
81        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
82        TCR: TransactionCounterTrait + Send + Sync + 'static,
83        PR: PluginRepositoryTrait + Send + Sync + 'static,
84    {
85        let socket_service = SocketService::new(socket_path)?;
86        let socket_path_clone = socket_service.socket_path().to_string();
87
88        let (shutdown_tx, shutdown_rx) = oneshot::channel();
89
90        let server_handle = tokio::spawn(async move {
91            let relayer_api = Arc::new(RelayerApi);
92            socket_service.listen(shutdown_rx, state, relayer_api).await
93        });
94
95        let mut script_result = match timeout(
96            timeout_duration,
97            ScriptExecutor::execute_typescript(script_path, socket_path_clone, script_params),
98        )
99        .await
100        {
101            Ok(result) => result?,
102            Err(_) => {
103                // ensures the socket gets closed.
104                let _ = shutdown_tx.send(());
105                return Err(PluginError::ScriptTimeout(timeout_duration.as_secs()));
106            }
107        };
108
109        let _ = shutdown_tx.send(());
110
111        let server_handle = server_handle
112            .await
113            .map_err(|e| PluginError::SocketError(e.to_string()))?;
114
115        match server_handle {
116            Ok(traces) => {
117                script_result.trace = traces;
118            }
119            Err(e) => {
120                return Err(PluginError::SocketError(e.to_string()));
121            }
122        }
123
124        Ok(script_result)
125    }
126}
127
128#[async_trait]
129impl PluginRunnerTrait for PluginRunner {
130    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR>(
131        &self,
132        socket_path: &str,
133        script_path: String,
134        timeout_duration: Duration,
135        script_params: String,
136        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR>>,
137    ) -> Result<ScriptResult, PluginError>
138    where
139        J: JobProducerTrait + Send + Sync + 'static,
140        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
141        TR: TransactionRepository
142            + Repository<TransactionRepoModel, String>
143            + Send
144            + Sync
145            + 'static,
146        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
147        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
148        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
149        TCR: TransactionCounterTrait + Send + Sync + 'static,
150        PR: PluginRepositoryTrait + Send + Sync + 'static,
151    {
152        self.run(
153            socket_path,
154            script_path,
155            timeout_duration,
156            script_params,
157            state,
158        )
159        .await
160    }
161}
162
#[cfg(test)]
mod tests {
    use actix_web::web;
    use std::fs;

    use crate::{
        jobs::MockJobProducerTrait,
        repositories::{
            NetworkRepositoryStorage, NotificationRepositoryStorage, PluginRepositoryStorage,
            RelayerRepositoryStorage, SignerRepositoryStorage, TransactionCounterRepositoryStorage,
            TransactionRepositoryStorage,
        },
        services::plugins::LogLevel,
        utils::mocks::mockutils::create_mock_app_state,
    };
    use tempfile::tempdir;

    use super::*;

    static TS_CONFIG: &str = r#"
        {
            "compilerOptions": {
              "target": "es2016",
              "module": "commonjs",
              "esModuleInterop": true,
              "forceConsistentCasingInFileNames": true,
              "strict": true,
              "skipLibCheck": true
            }
          }
    "#;

    /// Writes `<stem>.ts` (containing `script`) followed by `tsconfig.json` into
    /// `dir`. Returns the script path and the `<stem>.sock` socket path as strings.
    fn prepare_plugin_dir(dir: &tempfile::TempDir, stem: &str, script: &str) -> (String, String) {
        let root = dir.path();
        let script_file = root.join(format!("{}.ts", stem));
        fs::write(&script_file, script).unwrap();
        fs::write(root.join("tsconfig.json"), TS_CONFIG.as_bytes()).unwrap();
        let socket_file = root.join(format!("{}.sock", stem));
        (
            script_file.display().to_string(),
            socket_file.display().to_string(),
        )
    }

    /// Runs `script_path` through [`PluginRunner`] using a freshly mocked app state.
    async fn run_with_mock_state(
        socket_path: &str,
        script_path: String,
        timeout_duration: Duration,
        params: &str,
    ) -> Result<ScriptResult, PluginError> {
        let state = create_mock_app_state(None, None, None, None, None).await;
        let runner = PluginRunner;
        runner
            .run::<
                MockJobProducerTrait,
                RelayerRepositoryStorage,
                TransactionRepositoryStorage,
                NetworkRepositoryStorage,
                NotificationRepositoryStorage,
                SignerRepositoryStorage,
                TransactionCounterRepositoryStorage,
                PluginRepositoryStorage,
            >(
                socket_path,
                script_path,
                timeout_duration,
                params.to_string(),
                Arc::new(web::ThinData(state)),
            )
            .await
    }

    #[tokio::test]
    async fn test_run() {
        let temp_dir = tempdir().unwrap();

        let content = r#"
            export async function handler(api: any, params: any) {
                console.log('test');
                console.error('test-error');
                return 'test-result';
            }
        "#;
        let (script_path, socket_path) = prepare_plugin_dir(&temp_dir, "test_run", content);

        let output = run_with_mock_state(
            &socket_path,
            script_path,
            Duration::from_secs(10),
            "{ \"test\": \"test\" }",
        )
        .await
        .expect("plugin run should succeed");

        assert_eq!(output.logs[0].level, LogLevel::Log);
        assert_eq!(output.logs[0].message, "test");
        assert_eq!(output.logs[1].level, LogLevel::Error);
        assert_eq!(output.logs[1].message, "test-error");
        assert_eq!(output.return_value, "test-result");
    }

    #[tokio::test]
    async fn test_run_timeout() {
        let temp_dir = tempdir().unwrap();

        // Script that takes 200ms
        let content = r#"
            function sleep(ms) {
                return new Promise(resolve => setTimeout(resolve, ms));
            }

            async function main() {
                await sleep(200); // 200ms
                console.log(JSON.stringify({ level: 'result', message: 'Should not reach here' }));
            }

            main();
        "#;
        let (script_path, socket_path) =
            prepare_plugin_dir(&temp_dir, "test_simple_timeout", content);

        // Use 100ms timeout for a 200ms script
        let outcome =
            run_with_mock_state(&socket_path, script_path, Duration::from_millis(100), "{}").await;

        // Should timeout
        let err = outcome.expect_err("a 100ms timeout must abort a 200ms script");
        assert!(err.to_string().contains("Script execution timed out after"));
    }
}