openzeppelin_relayer/repositories/transaction_counter/
mod.rs1pub mod transaction_counter_in_memory;
22pub mod transaction_counter_redis;
23
24use redis::aio::ConnectionManager;
25pub use transaction_counter_in_memory::InMemoryTransactionCounter;
26pub use transaction_counter_redis::RedisTransactionCounter;
27
28use async_trait::async_trait;
29use serde::Serialize;
30use std::sync::Arc;
31use thiserror::Error;
32
33#[cfg(test)]
34use mockall::automock;
35
36use crate::models::RepositoryError;
37
38#[derive(Error, Debug, Serialize)]
39pub enum TransactionCounterError {
40 #[error("No sequence found for relayer {relayer_id} and address {address}")]
41 SequenceNotFound { relayer_id: String, address: String },
42 #[error("Counter not found for {0}")]
43 NotFound(String),
44}
45
46#[allow(dead_code)]
47#[async_trait]
48#[cfg_attr(test, automock)]
49pub trait TransactionCounterTrait {
50 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError>;
51
52 async fn get_and_increment(
53 &self,
54 relayer_id: &str,
55 address: &str,
56 ) -> Result<u64, RepositoryError>;
57
58 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError>;
59
60 async fn set(&self, relayer_id: &str, address: &str, value: u64)
61 -> Result<(), RepositoryError>;
62}
63
64#[derive(Debug, Clone)]
66pub enum TransactionCounterRepositoryStorage {
67 InMemory(InMemoryTransactionCounter),
68 Redis(RedisTransactionCounter),
69}
70
71impl TransactionCounterRepositoryStorage {
72 pub fn new_in_memory() -> Self {
73 Self::InMemory(InMemoryTransactionCounter::new())
74 }
75 pub fn new_redis(
76 connection_manager: Arc<ConnectionManager>,
77 key_prefix: String,
78 ) -> Result<Self, RepositoryError> {
79 Ok(Self::Redis(RedisTransactionCounter::new(
80 connection_manager,
81 key_prefix,
82 )?))
83 }
84}
85
86#[async_trait]
87impl TransactionCounterTrait for TransactionCounterRepositoryStorage {
88 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
89 match self {
90 TransactionCounterRepositoryStorage::InMemory(counter) => {
91 counter.get(relayer_id, address).await
92 }
93 TransactionCounterRepositoryStorage::Redis(counter) => {
94 counter.get(relayer_id, address).await
95 }
96 }
97 }
98
99 async fn get_and_increment(
100 &self,
101 relayer_id: &str,
102 address: &str,
103 ) -> Result<u64, RepositoryError> {
104 match self {
105 TransactionCounterRepositoryStorage::InMemory(counter) => {
106 counter.get_and_increment(relayer_id, address).await
107 }
108 TransactionCounterRepositoryStorage::Redis(counter) => {
109 counter.get_and_increment(relayer_id, address).await
110 }
111 }
112 }
113
114 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
115 match self {
116 TransactionCounterRepositoryStorage::InMemory(counter) => {
117 counter.decrement(relayer_id, address).await
118 }
119 TransactionCounterRepositoryStorage::Redis(counter) => {
120 counter.decrement(relayer_id, address).await
121 }
122 }
123 }
124
125 async fn set(
126 &self,
127 relayer_id: &str,
128 address: &str,
129 value: u64,
130 ) -> Result<(), RepositoryError> {
131 match self {
132 TransactionCounterRepositoryStorage::InMemory(counter) => {
133 counter.set(relayer_id, address, value).await
134 }
135 TransactionCounterRepositoryStorage::Redis(counter) => {
136 counter.set(relayer_id, address, value).await
137 }
138 }
139 }
140}
141
#[cfg(test)]
mod tests {

    use super::*;

    /// `new_in_memory` must produce the `InMemory` variant.
    #[tokio::test]
    async fn test_in_memory_repository_creation() {
        let repo = TransactionCounterRepositoryStorage::new_in_memory();

        // `matches!` only evaluates to a bool; it has to be wrapped in `assert!`
        // for the test to fail when the wrong variant is constructed.
        assert!(matches!(repo, TransactionCounterRepositoryStorage::InMemory(_)));
    }

    /// Exercises all four trait methods through the enum wrapper using the
    /// in-memory backend.
    #[tokio::test]
    async fn test_enum_wrapper_delegation() {
        let repo = TransactionCounterRepositoryStorage::new_in_memory();

        // Nothing stored yet for this relayer / address pair.
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, None);

        // A stored value is read back unchanged.
        repo.set("test_relayer", "0x1234", 100).await.unwrap();
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, Some(100));

        // `get_and_increment` hands back the pre-increment value...
        let current = repo
            .get_and_increment("test_relayer", "0x1234")
            .await
            .unwrap();
        assert_eq!(current, 100);

        // ...and leaves the stored counter one higher.
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, Some(101));

        // `decrement` hands back the post-decrement value.
        let new_value = repo.decrement("test_relayer", "0x1234").await.unwrap();
        assert_eq!(new_value, 100);
    }
}