openzeppelin_relayer/repositories/plugin/plugin_in_memory.rs

//! This module provides an in-memory implementation of the plugin repository.
//!
//! The `InMemoryPluginRepository` struct is used to store and retrieve plugins
//! and their script paths for later execution.
5use crate::{
6    models::{PaginationQuery, PluginModel},
7    repositories::{PaginatedResult, PluginRepositoryTrait, RepositoryError},
8};
9
10use async_trait::async_trait;
11
12use std::collections::HashMap;
13use tokio::sync::{Mutex, MutexGuard};
14
15#[derive(Debug)]
16pub struct InMemoryPluginRepository {
17    store: Mutex<HashMap<String, PluginModel>>,
18}
19
20impl Clone for InMemoryPluginRepository {
21    fn clone(&self) -> Self {
22        // Try to get the current data, or use empty HashMap if lock fails
23        let data = self
24            .store
25            .try_lock()
26            .map(|guard| guard.clone())
27            .unwrap_or_else(|_| HashMap::new());
28
29        Self {
30            store: Mutex::new(data),
31        }
32    }
33}
34
35impl InMemoryPluginRepository {
36    pub fn new() -> Self {
37        Self {
38            store: Mutex::new(HashMap::new()),
39        }
40    }
41
42    pub async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
43        let store = Self::acquire_lock(&self.store).await?;
44        Ok(store.get(id).cloned())
45    }
46
47    async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
48        Ok(lock.lock().await)
49    }
50}
51
52impl Default for InMemoryPluginRepository {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58#[async_trait]
59impl PluginRepositoryTrait for InMemoryPluginRepository {
60    async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
61        let store = Self::acquire_lock(&self.store).await?;
62        Ok(store.get(id).cloned())
63    }
64
65    async fn add(&self, plugin: PluginModel) -> Result<(), RepositoryError> {
66        let mut store = Self::acquire_lock(&self.store).await?;
67        store.insert(plugin.id.clone(), plugin);
68        Ok(())
69    }
70
71    async fn list_paginated(
72        &self,
73        query: PaginationQuery,
74    ) -> Result<PaginatedResult<PluginModel>, RepositoryError> {
75        let total = self.count().await?;
76        let start = ((query.page - 1) * query.per_page) as usize;
77
78        let items = self
79            .store
80            .lock()
81            .await
82            .values()
83            .skip(start)
84            .take(query.per_page as usize)
85            .cloned()
86            .collect();
87
88        Ok(PaginatedResult {
89            items,
90            total: total as u64,
91            page: query.page,
92            per_page: query.per_page,
93        })
94    }
95
96    async fn count(&self) -> Result<usize, RepositoryError> {
97        let store = self.store.lock().await;
98        Ok(store.len())
99    }
100
101    async fn has_entries(&self) -> Result<bool, RepositoryError> {
102        let store = Self::acquire_lock(&self.store).await?;
103        Ok(!store.is_empty())
104    }
105
106    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
107        let mut store = Self::acquire_lock(&self.store).await?;
108        store.clear();
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use crate::{config::PluginFileConfig, constants::DEFAULT_PLUGIN_TIMEOUT_SECONDS};

    use super::*;
    use std::{sync::Arc, time::Duration};

    /// Builds a plugin with the given id and path and the default timeout.
    fn make_plugin(id: &str, path: &str) -> PluginModel {
        PluginModel {
            id: id.to_string(),
            path: path.to_string(),
            timeout: Duration::from_secs(DEFAULT_PLUGIN_TIMEOUT_SECONDS),
        }
    }

    #[tokio::test]
    async fn test_in_memory_plugin_repository() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        // Test add and get_by_id
        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_get_nonexistent_plugin() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let result = plugin_repository.get_by_id("test-plugin").await;
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_try_from() {
        let plugin = PluginFileConfig {
            id: "test-plugin".to_string(),
            path: "test-path".to_string(),
            timeout: None,
        };
        let result = PluginModel::try_from(plugin);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), make_plugin("test-plugin", "test-path"));
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_count_and_add_replaces_existing_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        assert_eq!(plugin_repository.count().await.unwrap(), 0);

        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();
        assert_eq!(plugin_repository.count().await.unwrap(), 2);

        // Re-adding an existing id replaces the entry instead of growing the store.
        plugin_repository
            .add(make_plugin("test-plugin1", "other-path"))
            .await
            .unwrap();
        assert_eq!(plugin_repository.count().await.unwrap(), 2);
        assert_eq!(
            plugin_repository.get_by_id("test-plugin1").await.unwrap(),
            Some(make_plugin("test-plugin1", "other-path"))
        );
    }

    #[tokio::test]
    async fn test_list_paginated() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let plugin1 = make_plugin("test-plugin1", "test-path1");
        let plugin2 = make_plugin("test-plugin2", "test-path2");

        plugin_repository.add(plugin1.clone()).await.unwrap();
        plugin_repository.add(plugin2.clone()).await.unwrap();

        let query = PaginationQuery {
            page: 1,
            per_page: 2,
        };

        let result = plugin_repository.list_paginated(query).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 2);
    }

    #[tokio::test]
    async fn test_list_paginated_multiple_pages_cover_all_items() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin3", "test-path3"))
            .await
            .unwrap();

        // Collect ids from every page. Sort afterwards so the assertion does not
        // depend on the order within pages.
        let mut seen = Vec::new();
        for page in 1..=2 {
            let result = plugin_repository
                .list_paginated(PaginationQuery { page, per_page: 2 })
                .await
                .unwrap();
            assert_eq!(result.total, 3);
            assert_eq!(result.page, page);
            assert_eq!(result.per_page, 2);
            seen.extend(result.items.into_iter().map(|plugin| plugin.id));
        }
        seen.sort();
        assert_eq!(seen, vec!["test-plugin1", "test-plugin2", "test-plugin3"]);
    }

    #[tokio::test]
    async fn test_list_paginated_past_last_page_is_empty() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        let result = plugin_repository
            .list_paginated(PaginationQuery {
                page: 3,
                per_page: 1,
            })
            .await
            .unwrap();
        assert!(result.items.is_empty());
        assert_eq!(result.total, 1);
    }

    #[tokio::test]
    async fn test_clone_copies_entries_independently() {
        let original = InMemoryPluginRepository::new();
        original
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        let copy = original.clone();
        assert_eq!(
            copy.get_by_id("test-plugin").await.unwrap(),
            Some(make_plugin("test-plugin", "test-path"))
        );

        // The clone owns its own store, so later writes do not leak back into the original.
        copy.add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();
        assert_eq!(original.count().await.unwrap(), 1);
        assert_eq!(copy.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_has_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        assert!(!plugin_repository.has_entries().await.unwrap());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
    }
}