openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs1use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use log::debug;
13use redis::aio::ConnectionManager;
14use redis::AsyncCommands;
15use std::fmt;
16use std::sync::Arc;
17
/// Namespace segment placed between the configured key prefix and the
/// relayer/address parts of every counter key.
const COUNTER_PREFIX: &str = "transaction_counter";
19
20#[derive(Clone)]
21pub struct RedisTransactionCounter {
22 pub client: Arc<ConnectionManager>,
23 pub key_prefix: String,
24}
25
26impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("RedisTransactionCounter")
31 .field("key_prefix", &self.key_prefix)
32 .finish()
33 }
34}
35
36impl RedisTransactionCounter {
37 pub fn new(
38 connection_manager: Arc<ConnectionManager>,
39 key_prefix: String,
40 ) -> Result<Self, RepositoryError> {
41 if key_prefix.is_empty() {
42 return Err(RepositoryError::InvalidData(
43 "Redis key prefix cannot be empty".to_string(),
44 ));
45 }
46
47 Ok(Self {
48 client: connection_manager,
49 key_prefix,
50 })
51 }
52
53 fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55 format!(
56 "{}:{}:{}:{}",
57 self.key_prefix, COUNTER_PREFIX, relayer_id, address
58 )
59 }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65 if relayer_id.is_empty() {
66 return Err(RepositoryError::InvalidData(
67 "Relayer ID cannot be empty".to_string(),
68 ));
69 }
70
71 if address.is_empty() {
72 return Err(RepositoryError::InvalidData(
73 "Address cannot be empty".to_string(),
74 ));
75 }
76
77 let key = self.counter_key(relayer_id, address);
78 debug!(
79 "Getting counter for relayer {} and address {}",
80 relayer_id, address
81 );
82
83 let mut conn = self.client.as_ref().clone();
84
85 let value: Option<u64> = conn
86 .get(&key)
87 .await
88 .map_err(|e| self.map_redis_error(e, "get_counter"))?;
89
90 debug!("Retrieved counter value: {:?}", value);
91 Ok(value)
92 }
93
94 async fn get_and_increment(
95 &self,
96 relayer_id: &str,
97 address: &str,
98 ) -> Result<u64, RepositoryError> {
99 if relayer_id.is_empty() {
100 return Err(RepositoryError::InvalidData(
101 "Relayer ID cannot be empty".to_string(),
102 ));
103 }
104
105 if address.is_empty() {
106 return Err(RepositoryError::InvalidData(
107 "Address cannot be empty".to_string(),
108 ));
109 }
110
111 let key = self.counter_key(relayer_id, address);
112 debug!(
113 "Getting and incrementing counter for relayer {} and address {}",
114 relayer_id, address
115 );
116
117 let mut conn = self.client.as_ref().clone();
118
119 let current_value: Option<u64> = conn
121 .get(&key)
122 .await
123 .map_err(|e| self.map_redis_error(e, "get_current_value"))?;
124
125 let current = current_value.unwrap_or(0);
126
127 let mut pipe = redis::pipe();
129 pipe.atomic();
130 pipe.set(&key, current + 1);
131
132 pipe.exec_async(&mut conn)
133 .await
134 .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
135
136 debug!("Counter incremented from {} to {}", current, current + 1);
137 Ok(current)
138 }
139
140 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
141 if relayer_id.is_empty() {
142 return Err(RepositoryError::InvalidData(
143 "Relayer ID cannot be empty".to_string(),
144 ));
145 }
146
147 if address.is_empty() {
148 return Err(RepositoryError::InvalidData(
149 "Address cannot be empty".to_string(),
150 ));
151 }
152
153 let key = self.counter_key(relayer_id, address);
154 debug!(
155 "Decrementing counter for relayer {} and address {}",
156 relayer_id, address
157 );
158
159 let mut conn = self.client.as_ref().clone();
160
161 let current_value: Option<u64> = conn
163 .get(&key)
164 .await
165 .map_err(|e| self.map_redis_error(e, "get_current_value_for_decrement"))?;
166
167 let current = current_value.ok_or_else(|| {
168 RepositoryError::NotFound(format!(
169 "Counter not found for relayer {} and address {}",
170 relayer_id, address
171 ))
172 })?;
173
174 let new_value = if current > 0 { current - 1 } else { 0 };
176
177 let _: () = conn
178 .set(&key, new_value)
179 .await
180 .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
181
182 debug!("Counter decremented from {} to {}", current, new_value);
183 Ok(new_value)
184 }
185
186 async fn set(
187 &self,
188 relayer_id: &str,
189 address: &str,
190 value: u64,
191 ) -> Result<(), RepositoryError> {
192 if relayer_id.is_empty() {
193 return Err(RepositoryError::InvalidData(
194 "Relayer ID cannot be empty".to_string(),
195 ));
196 }
197
198 if address.is_empty() {
199 return Err(RepositoryError::InvalidData(
200 "Address cannot be empty".to_string(),
201 ));
202 }
203
204 let key = self.counter_key(relayer_id, address);
205 debug!(
206 "Setting counter for relayer {} and address {} to {}",
207 relayer_id, address, value
208 );
209
210 let mut conn = self.client.as_ref().clone();
211
212 let _: () = conn
213 .set(&key, value)
214 .await
215 .map_err(|e| self.map_redis_error(e, "set_counter"))?;
216
217 debug!("Counter set to {}", value);
218 Ok(())
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to the Redis instance named by `REDIS_URL` (default: local 6379).
    async fn setup_test_repo() -> RedisTransactionCounter {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");

        RedisTransactionCounter::new(Arc::new(connection_manager), "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Fresh random identifier so tests never share Redis keys between runs.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let result = repo.get(&unique_id(), "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    /// Concurrent callers must each receive a distinct value (no duplicate nonces).
    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment_returns_unique_values() {
        const TASKS: u64 = 20;

        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(async move { repo.get_and_increment(&relayer_id, &address).await })
            })
            .collect();

        let mut values = Vec::with_capacity(handles.len());
        for handle in handles {
            values.push(handle.await.unwrap().unwrap());
        }
        values.sort_unstable();

        assert_eq!(values, (0..TASKS).collect::<Vec<_>>());
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(TASKS));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_at_zero_stays_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        assert_eq!(repo.decrement(&relayer_id, &address).await.unwrap(), 0);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        let result = repo.decrement(&unique_id(), "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        assert!(matches!(
            repo.get("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("", "0x1234", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }
}