openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs

//! Redis implementation of the transaction counter.
//!
//! This module provides a Redis-based implementation of `TransactionCounterTrait`,
//! allowing transaction counters to be stored in and retrieved from a Redis database.
//! The implementation validates its inputs, maps Redis failures to `RepositoryError`,
//! logs each operation, and uses atomic operations to keep counters consistent when
//! they are incremented and decremented.
7
8use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use log::debug;
13use redis::aio::ConnectionManager;
14use redis::AsyncCommands;
15use std::fmt;
16use std::sync::Arc;
17
/// Key segment placed between the configured prefix and the counter identifiers.
const COUNTER_PREFIX: &str = "transaction_counter";

/// Transaction counter backed by Redis.
///
/// Each counter is stored under the key
/// `{key_prefix}:transaction_counter:{relayer_id}:{address}`.
#[derive(Clone)]
pub struct RedisTransactionCounter {
    /// Shared Redis connection manager; each operation clones it to get a connection handle.
    pub client: Arc<ConnectionManager>,
    /// Namespace prepended to every counter key; `new` rejects an empty value.
    pub key_prefix: String,
}

// Marker impl; presumably supplies the `map_redis_error` helper (via default trait
// methods) that the `TransactionCounterTrait` implementation calls — confirm in redis_base.
impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("RedisTransactionCounter")
31            .field("key_prefix", &self.key_prefix)
32            .finish()
33    }
34}
35
36impl RedisTransactionCounter {
37    pub fn new(
38        connection_manager: Arc<ConnectionManager>,
39        key_prefix: String,
40    ) -> Result<Self, RepositoryError> {
41        if key_prefix.is_empty() {
42            return Err(RepositoryError::InvalidData(
43                "Redis key prefix cannot be empty".to_string(),
44            ));
45        }
46
47        Ok(Self {
48            client: connection_manager,
49            key_prefix,
50        })
51    }
52
53    /// Generate key for transaction counter: {prefix}:transaction_counter:{relayer_id}:{address}
54    fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55        format!(
56            "{}:{}:{}:{}",
57            self.key_prefix, COUNTER_PREFIX, relayer_id, address
58        )
59    }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65        if relayer_id.is_empty() {
66            return Err(RepositoryError::InvalidData(
67                "Relayer ID cannot be empty".to_string(),
68            ));
69        }
70
71        if address.is_empty() {
72            return Err(RepositoryError::InvalidData(
73                "Address cannot be empty".to_string(),
74            ));
75        }
76
77        let key = self.counter_key(relayer_id, address);
78        debug!(
79            "Getting counter for relayer {} and address {}",
80            relayer_id, address
81        );
82
83        let mut conn = self.client.as_ref().clone();
84
85        let value: Option<u64> = conn
86            .get(&key)
87            .await
88            .map_err(|e| self.map_redis_error(e, "get_counter"))?;
89
90        debug!("Retrieved counter value: {:?}", value);
91        Ok(value)
92    }
93
94    async fn get_and_increment(
95        &self,
96        relayer_id: &str,
97        address: &str,
98    ) -> Result<u64, RepositoryError> {
99        if relayer_id.is_empty() {
100            return Err(RepositoryError::InvalidData(
101                "Relayer ID cannot be empty".to_string(),
102            ));
103        }
104
105        if address.is_empty() {
106            return Err(RepositoryError::InvalidData(
107                "Address cannot be empty".to_string(),
108            ));
109        }
110
111        let key = self.counter_key(relayer_id, address);
112        debug!(
113            "Getting and incrementing counter for relayer {} and address {}",
114            relayer_id, address
115        );
116
117        let mut conn = self.client.as_ref().clone();
118
119        // Get current value (or 0 if not exists)
120        let current_value: Option<u64> = conn
121            .get(&key)
122            .await
123            .map_err(|e| self.map_redis_error(e, "get_current_value"))?;
124
125        let current = current_value.unwrap_or(0);
126
127        // Use a pipeline to atomically set the incremented value
128        let mut pipe = redis::pipe();
129        pipe.atomic();
130        pipe.set(&key, current + 1);
131
132        pipe.exec_async(&mut conn)
133            .await
134            .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
135
136        debug!("Counter incremented from {} to {}", current, current + 1);
137        Ok(current)
138    }
139
140    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
141        if relayer_id.is_empty() {
142            return Err(RepositoryError::InvalidData(
143                "Relayer ID cannot be empty".to_string(),
144            ));
145        }
146
147        if address.is_empty() {
148            return Err(RepositoryError::InvalidData(
149                "Address cannot be empty".to_string(),
150            ));
151        }
152
153        let key = self.counter_key(relayer_id, address);
154        debug!(
155            "Decrementing counter for relayer {} and address {}",
156            relayer_id, address
157        );
158
159        let mut conn = self.client.as_ref().clone();
160
161        // Check if counter exists
162        let current_value: Option<u64> = conn
163            .get(&key)
164            .await
165            .map_err(|e| self.map_redis_error(e, "get_current_value_for_decrement"))?;
166
167        let current = current_value.ok_or_else(|| {
168            RepositoryError::NotFound(format!(
169                "Counter not found for relayer {} and address {}",
170                relayer_id, address
171            ))
172        })?;
173
174        // Only decrement if current value is greater than 0
175        let new_value = if current > 0 { current - 1 } else { 0 };
176
177        let _: () = conn
178            .set(&key, new_value)
179            .await
180            .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
181
182        debug!("Counter decremented from {} to {}", current, new_value);
183        Ok(new_value)
184    }
185
186    async fn set(
187        &self,
188        relayer_id: &str,
189        address: &str,
190        value: u64,
191    ) -> Result<(), RepositoryError> {
192        if relayer_id.is_empty() {
193            return Err(RepositoryError::InvalidData(
194                "Relayer ID cannot be empty".to_string(),
195            ));
196        }
197
198        if address.is_empty() {
199            return Err(RepositoryError::InvalidData(
200                "Address cannot be empty".to_string(),
201            ));
202        }
203
204        let key = self.counter_key(relayer_id, address);
205        debug!(
206            "Setting counter for relayer {} and address {} to {}",
207            relayer_id, address, value
208        );
209
210        let mut conn = self.client.as_ref().clone();
211
212        let _: () = conn
213            .set(&key, value)
214            .await
215            .map_err(|e| self.map_redis_error(e, "set_counter"))?;
216
217        debug!("Counter set to {}", value);
218        Ok(())
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to the Redis instance named by `REDIS_URL`
    /// (default `redis://127.0.0.1:6379`).
    async fn connect() -> Arc<ConnectionManager> {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");
        Arc::new(connection_manager)
    }

    /// Builds a counter under the `test_counter` key prefix.
    async fn setup_test_repo() -> RedisTransactionCounter {
        RedisTransactionCounter::new(connect().await, "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Returns a fresh identifier so each test works on its own Redis keys.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_new_rejects_empty_prefix() {
        let result = RedisTransactionCounter::new(connect().await, String::new());
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let result = repo.get(&unique_id(), "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // First increment should return 0 and set to 1
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        // Second increment should return 1 and set to 2
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // Set initial value
        repo.set(&relayer_id, &address, 5).await.unwrap();

        // Decrement should return 4
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_floors_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 1).await.unwrap();

        assert_eq!(repo.decrement(&relayer_id, &address).await.unwrap(), 0);
        // A counter already at zero must stay at zero rather than underflow.
        assert_eq!(repo.decrement(&relayer_id, &address).await.unwrap(), 0);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        // Fresh identifiers guarantee the key cannot exist from an earlier run.
        let result = repo.decrement(&unique_id(), &unique_id()).await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        // Empty relayer_id / address must be rejected by every operation.
        assert!(matches!(
            repo.get("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("", "0x1234", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("relayer", "", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        // Set different values for different relayer/address combinations
        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        // Verify independent counters
        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }
}