openzeppelin_relayer/repositories/plugin/
plugin_in_memory.rs1use crate::{
6 models::{PaginationQuery, PluginModel},
7 repositories::{PaginatedResult, PluginRepositoryTrait, RepositoryError},
8};
9
10use async_trait::async_trait;
11
12use std::collections::HashMap;
13use tokio::sync::{Mutex, MutexGuard};
14
15#[derive(Debug)]
16pub struct InMemoryPluginRepository {
17 store: Mutex<HashMap<String, PluginModel>>,
18}
19
20impl Clone for InMemoryPluginRepository {
21 fn clone(&self) -> Self {
22 let data = self
24 .store
25 .try_lock()
26 .map(|guard| guard.clone())
27 .unwrap_or_else(|_| HashMap::new());
28
29 Self {
30 store: Mutex::new(data),
31 }
32 }
33}
34
35impl InMemoryPluginRepository {
36 pub fn new() -> Self {
37 Self {
38 store: Mutex::new(HashMap::new()),
39 }
40 }
41
42 pub async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
43 let store = Self::acquire_lock(&self.store).await?;
44 Ok(store.get(id).cloned())
45 }
46
47 async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
48 Ok(lock.lock().await)
49 }
50}
51
52impl Default for InMemoryPluginRepository {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58#[async_trait]
59impl PluginRepositoryTrait for InMemoryPluginRepository {
60 async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
61 let store = Self::acquire_lock(&self.store).await?;
62 Ok(store.get(id).cloned())
63 }
64
65 async fn add(&self, plugin: PluginModel) -> Result<(), RepositoryError> {
66 let mut store = Self::acquire_lock(&self.store).await?;
67 store.insert(plugin.id.clone(), plugin);
68 Ok(())
69 }
70
71 async fn list_paginated(
72 &self,
73 query: PaginationQuery,
74 ) -> Result<PaginatedResult<PluginModel>, RepositoryError> {
75 let total = self.count().await?;
76 let start = ((query.page - 1) * query.per_page) as usize;
77
78 let items = self
79 .store
80 .lock()
81 .await
82 .values()
83 .skip(start)
84 .take(query.per_page as usize)
85 .cloned()
86 .collect();
87
88 Ok(PaginatedResult {
89 items,
90 total: total as u64,
91 page: query.page,
92 per_page: query.per_page,
93 })
94 }
95
96 async fn count(&self) -> Result<usize, RepositoryError> {
97 let store = self.store.lock().await;
98 Ok(store.len())
99 }
100
101 async fn has_entries(&self) -> Result<bool, RepositoryError> {
102 let store = Self::acquire_lock(&self.store).await?;
103 Ok(!store.is_empty())
104 }
105
106 async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
107 let mut store = Self::acquire_lock(&self.store).await?;
108 store.clear();
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use crate::{config::PluginFileConfig, constants::DEFAULT_PLUGIN_TIMEOUT_SECONDS};

    use super::*;
    use std::{sync::Arc, time::Duration};

    /// Builds a `PluginModel` with the given id and path and the default timeout.
    fn make_plugin(id: &str, path: &str) -> PluginModel {
        PluginModel {
            id: id.to_string(),
            path: path.to_string(),
            timeout: Duration::from_secs(DEFAULT_PLUGIN_TIMEOUT_SECONDS),
        }
    }

    #[tokio::test]
    async fn test_in_memory_plugin_repository() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_get_nonexistent_plugin() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let result = plugin_repository.get_by_id("test-plugin").await;
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_try_from() {
        let plugin = PluginFileConfig {
            id: "test-plugin".to_string(),
            path: "test-path".to_string(),
            timeout: None,
        };
        let result = PluginModel::try_from(plugin);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), make_plugin("test-plugin", "test-path"));
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();

        // Method-call syntax resolves to the inherent method...
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin.clone())
        );
        // ...so call the trait method explicitly to cover it as well.
        assert_eq!(
            PluginRepositoryTrait::get_by_id(plugin_repository.as_ref(), "test-plugin")
                .await
                .unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_add_overwrites_existing_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin", "old-path"))
            .await
            .unwrap();

        let replacement = make_plugin("test-plugin", "new-path");
        plugin_repository.add(replacement.clone()).await.unwrap();

        assert_eq!(plugin_repository.count().await.unwrap(), 1);
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(replacement)
        );
    }

    #[tokio::test]
    async fn test_count() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        assert_eq!(plugin_repository.count().await.unwrap(), 0);

        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();

        assert_eq!(plugin_repository.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_list_paginated() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();

        let query = PaginationQuery {
            page: 1,
            per_page: 2,
        };

        let result = plugin_repository.list_paginated(query).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.total, 2);
        assert_eq!(result.page, 1);
        assert_eq!(result.per_page, 2);
    }

    #[tokio::test]
    async fn test_list_paginated_across_pages() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();

        // Walking every page with per_page = 1 must visit each plugin exactly once.
        let mut seen_ids = Vec::new();
        for page in 1..=2 {
            let result = plugin_repository
                .list_paginated(PaginationQuery { page, per_page: 1 })
                .await
                .unwrap();
            assert_eq!(result.total, 2);
            assert_eq!(result.items.len(), 1);
            seen_ids.extend(result.items.into_iter().map(|plugin| plugin.id));
        }
        seen_ids.sort();
        assert_eq!(
            seen_ids,
            vec!["test-plugin1".to_string(), "test-plugin2".to_string()]
        );

        let past_end = plugin_repository
            .list_paginated(PaginationQuery {
                page: 3,
                per_page: 1,
            })
            .await
            .unwrap();
        assert!(past_end.items.is_empty());
        assert_eq!(past_end.total, 2);
    }

    #[tokio::test]
    async fn test_clone_copies_entries_independently() {
        let original = InMemoryPluginRepository::new();
        let plugin = make_plugin("test-plugin", "test-path");
        original.add(plugin.clone()).await.unwrap();

        let cloned = original.clone();
        assert_eq!(
            cloned.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );

        // Mutating the clone must not leak into the original.
        cloned
            .add(make_plugin("clone-only", "test-path"))
            .await
            .unwrap();
        assert_eq!(cloned.count().await.unwrap(), 2);
        assert_eq!(original.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_has_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        assert!(!plugin_repository.has_entries().await.unwrap());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
        assert_eq!(plugin_repository.count().await.unwrap(), 0);
    }
}