openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_in_memory.rs1use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::repositories::{RepositoryError, TransactionCounterTrait};
11
12#[derive(Debug, Default, Clone)]
13pub struct InMemoryTransactionCounter {
14 store: DashMap<(String, String), u64>, }
16
17impl InMemoryTransactionCounter {
18 pub fn new() -> Self {
19 Self {
20 store: DashMap::new(),
21 }
22 }
23}
24
25#[async_trait]
26impl TransactionCounterTrait for InMemoryTransactionCounter {
27 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
28 Ok(self
29 .store
30 .get(&(relayer_id.to_string(), address.to_string()))
31 .map(|n| *n))
32 }
33
34 async fn get_and_increment(
35 &self,
36 relayer_id: &str,
37 address: &str,
38 ) -> Result<u64, RepositoryError> {
39 let mut entry = self
40 .store
41 .entry((relayer_id.to_string(), address.to_string()))
42 .or_insert(0);
43 let current = *entry;
44 *entry += 1;
45 Ok(current)
46 }
47
48 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
49 let mut entry = self
50 .store
51 .get_mut(&(relayer_id.to_string(), address.to_string()))
52 .ok_or_else(|| {
53 RepositoryError::NotFound(format!("Counter not found for {}", address))
54 })?;
55 if *entry > 0 {
56 *entry -= 1;
57 }
58 Ok(*entry)
59 }
60
61 async fn set(
62 &self,
63 relayer_id: &str,
64 address: &str,
65 value: u64,
66 ) -> Result<(), RepositoryError> {
67 self.store
68 .insert((relayer_id.to_string(), address.to_string()), value);
69 Ok(())
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Decrementing a pair that was never stored reports `NotFound`.
    #[tokio::test]
    async fn test_decrement_not_found() {
        let counter = InMemoryTransactionCounter::new();
        let outcome = counter.decrement("nonexistent", "0x1234").await;
        assert!(matches!(outcome, Err(RepositoryError::NotFound(_))));
    }

    /// Walks one pair through get, set, get_and_increment and decrement.
    #[tokio::test]
    async fn test_nonce_store() {
        let counter = InMemoryTransactionCounter::new();
        let relayer = "relayer_1";
        let addr = "0x1234";

        // Nothing is stored before the first write.
        assert_eq!(counter.get(relayer, addr).await.unwrap(), None);

        // A written value is read back unchanged.
        counter.set(relayer, addr, 100).await.unwrap();
        assert_eq!(counter.get(relayer, addr).await.unwrap(), Some(100));

        // get_and_increment hands back the old value and stores old + 1.
        let before_increment = counter.get_and_increment(relayer, addr).await.unwrap();
        assert_eq!(before_increment, 100);
        assert_eq!(counter.get(relayer, addr).await.unwrap(), Some(101));

        // decrement hands back the new, lowered value.
        let after_decrement = counter.decrement(relayer, addr).await.unwrap();
        assert_eq!(after_decrement, 100);
        assert_eq!(counter.get(relayer, addr).await.unwrap(), Some(100));
    }

    /// Counters for different relayer/address pairs do not affect each other.
    #[tokio::test]
    async fn test_multiple_relayers() {
        let counter = InMemoryTransactionCounter::new();

        let seeded = [
            ("relayer_1", "0x1234", 100),
            ("relayer_1", "0x5678", 200),
            ("relayer_2", "0x1234", 300),
        ];

        for &(relayer, addr, value) in seeded.iter() {
            counter.set(relayer, addr, value).await.unwrap();
        }
        for &(relayer, addr, value) in seeded.iter() {
            assert_eq!(counter.get(relayer, addr).await.unwrap(), Some(value));
        }

        // Two consecutive increments per relayer_1 address: start, then start + 1.
        let relayer_1_addresses = [("0x1234", 100), ("0x5678", 200)];
        for &(addr, start) in relayer_1_addresses.iter() {
            assert_eq!(
                counter.get_and_increment("relayer_1", addr).await.unwrap(),
                start
            );
            assert_eq!(
                counter.get_and_increment("relayer_1", addr).await.unwrap(),
                start + 1
            );
        }

        // relayer_2 shares an address with relayer_1 but keeps its own value.
        assert_eq!(counter.get("relayer_2", "0x1234").await.unwrap(), Some(300));
    }
}