openzeppelin_relayer/services/aws_kms/
mod.rs1use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37 primitives::Blob,
38 types::{MessageType, SigningAlgorithmSpec},
39 Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::collections::HashMap;
44use tokio::sync::RwLock;
45
46use crate::{
47 models::{Address, AwsKmsSignerConfig},
48 utils::{self, derive_ethereum_address_from_der, extract_public_key_from_der},
49};
50
51#[cfg(test)]
52use mockall::{automock, mock};
53
54#[derive(Clone, Debug, thiserror::Error, Serialize)]
55pub enum AwsKmsError {
56 #[error("AWS KMS response parse error: {0}")]
57 ParseError(String),
58 #[error("AWS KMS config error: {0}")]
59 ConfigError(String),
60 #[error("AWS KMS get error: {0}")]
61 GetError(String),
62 #[error("AWS KMS signing error: {0}")]
63 SignError(String),
64 #[error("AWS KMS permissions error: {0}")]
65 PermissionError(String),
66 #[error("AWS KMS public key error: {0}")]
67 RecoveryError(#[from] utils::Secp256k1Error),
68 #[error("AWS KMS conversion error: {0}")]
69 ConvertError(String),
70 #[error("AWS KMS Other error: {0}")]
71 Other(String),
72}
73
74pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
75
76#[async_trait]
77#[cfg_attr(test, automock)]
78pub trait AwsKmsEvmService: Send + Sync {
79 async fn get_evm_address(&self) -> AwsKmsResult<Address>;
81 async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
84}
85
86#[async_trait]
87#[cfg_attr(test, automock)]
88pub trait AwsKmsK256: Send + Sync {
89 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
91 async fn sign_digest<'a, 'b>(
93 &'a self,
94 key_id: &'b str,
95 digest: [u8; 32],
96 ) -> AwsKmsResult<Vec<u8>>;
97}
98
// Test-only mock of the concrete client (`MockAwsKmsClient`). Unlike the
// `automock`-generated `MockAwsKmsK256`, it also implements `Clone`, which
// `AwsKmsService<T>` requires of its client type.
#[cfg(test)]
mock! {
    pub AwsKmsClient {}

    // Required by the `T: AwsKmsK256 + Clone` bound on `AwsKmsService`.
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    // Mirrors the `AwsKmsK256` trait signatures exactly.
    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}

118static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
120 Lazy::new(|| RwLock::new(HashMap::new()));
121
122#[derive(Debug, Clone)]
123pub struct AwsKmsClient {
124 inner: Client,
125}
126
127#[async_trait]
128impl AwsKmsK256 for AwsKmsClient {
129 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
130 let cached = {
132 let cache_read = KMS_DER_PK_CACHE.read().await;
133 cache_read.get(key_id).cloned()
134 };
135 if let Some(cached) = cached {
136 return Ok(cached);
137 }
138
139 let get_output = self
141 .inner
142 .get_public_key()
143 .key_id(key_id)
144 .send()
145 .await
146 .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
147
148 let der_pk_blob = get_output
149 .public_key
150 .ok_or(AwsKmsError::GetError(
151 "No public key blob found".to_string(),
152 ))?
153 .into_inner();
154
155 let mut cache_write = KMS_DER_PK_CACHE.write().await;
157 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
158 drop(cache_write);
159
160 Ok(der_pk_blob)
161 }
162
163 async fn sign_digest<'a, 'b>(
164 &'a self,
165 key_id: &'b str,
166 digest: [u8; 32],
167 ) -> AwsKmsResult<Vec<u8>> {
168 let sign_result = self
170 .inner
171 .sign()
172 .key_id(key_id)
173 .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
174 .message_type(MessageType::Digest)
175 .message(Blob::new(digest))
176 .send()
177 .await;
178
179 let der_signature = sign_result
181 .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
182 .signature
183 .ok_or(AwsKmsError::SignError(
184 "Signature not found in response".to_string(),
185 ))?
186 .into_inner();
187
188 Ok(der_signature)
189 }
190}
191
192#[derive(Debug, Clone)]
193pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
194 pub kms_key_id: String,
195 client: T,
196}
197
198impl AwsKmsService<AwsKmsClient> {
199 pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
200 let region_provider =
201 RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
202
203 let auth_config = aws_config::defaults(BehaviorVersion::latest())
204 .region(region_provider)
205 .load()
206 .await;
207 let client = AwsKmsClient {
208 inner: Client::new(&auth_config),
209 };
210
211 Ok(Self {
212 kms_key_id: config.key_id,
213 client,
214 })
215 }
216}
217
#[cfg(test)]
impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
    /// Test-only constructor that wraps an arbitrary (typically mock) client.
    ///
    /// Only `config.key_id` is used; `config.region` is ignored.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}

228impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
229 pub async fn sign_bytes_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
233 let digest = keccak256(bytes).0;
235
236 let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
239
240 let mut rs = k256::ecdsa::Signature::from_der(&der_signature)
242 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
243
244 if let Some(normalized) = rs.normalize_s() {
246 rs = normalized;
247 }
248 let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
249
250 let pk = extract_public_key_from_der(&der_pk)
252 .map_err(|e| AwsKmsError::ConvertError(e.to_string()))?;
253
254 let v = utils::recover_public_key(&pk, &rs, bytes)?;
256
257 let eth_v = 27 + v;
259
260 let mut sig_bytes = rs.to_vec();
262 sig_bytes.push(eth_v);
263
264 Ok(sig_bytes)
265 }
266}
267
268#[async_trait]
269impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
270 async fn get_evm_address(&self) -> AwsKmsResult<Address> {
271 let der = self.client.get_der_public_key(&self.kms_key_id).await?;
272 let eth_address = derive_ethereum_address_from_der(&der)
273 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
274 Ok(Address::Evm(eth_address))
275 }
276
277 async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
278 self.sign_bytes_evm(message).await
279 }
280}
281
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Builds a `MockAwsKmsClient` backed by a freshly generated secp256k1 key.
    ///
    /// For key id `"test-key-id"`, the public-key lookup returns the key's DER
    /// SPKI, and signing produces a real DER signature over the given digest.
    /// Any other key id fails with `GetError` / `SignError`. Also returns a copy
    /// of the signing key so tests can check results against it.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Service configured for `key_id` on top of a fresh mock client, together
    /// with the mock's signing key.
    fn service_for(key_id: &str) -> (AwsKmsService<MockAwsKmsClient>, SigningKey) {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(
            mock_client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        );
        (kms, key)
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (kms, key) = service_for("test-key-id");

        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();

        // `let ... else` so a non-EVM variant or an error fails the test instead of
        // silently skipping the comparison.
        let Ok(Address::Evm(evm_address)) = kms.get_evm_address().await else {
            panic!("get_evm_address should return an EVM address for a known key");
        };
        assert_eq!(expected_address, evm_address);
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (kms, _) = service_for("invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (kms, key) = service_for("test-key-id");
        let message_eip = eip191_message(b"Hello World!");

        let signature = kms
            .sign_payload_evm(&message_eip)
            .await
            .expect("signing with a known key should succeed");

        // Ethereum layout: r (32 bytes) || s (32 bytes) || v (27 or 28).
        assert_eq!(signature.len(), 65);
        let v = signature[64];
        assert!(v == 27 || v == 28, "unexpected v byte: {v}");

        // The signature must recover to the mock's key over keccak256(message).
        let rs = k256::ecdsa::Signature::from_slice(&signature[..64]).unwrap();
        let recovery_id = k256::ecdsa::RecoveryId::from_byte(v - 27).unwrap();
        let recovered = k256::ecdsa::VerifyingKey::recover_from_prehash(
            &keccak256(&message_eip).0,
            &rs,
            recovery_id,
        )
        .unwrap();
        let recovered_der = recovered.to_public_key_der().unwrap();
        let expected_der = key.verifying_key().to_public_key_der().unwrap();
        assert_eq!(recovered_der.as_bytes(), expected_der.as_bytes());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (kms, _) = service_for("invalid-key-id");
        let message_eip = eip191_message(b"Hello World!");

        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}