openzeppelin_relayer/services/aws_kms/
mod.rs

1//! # AWS KMS Service Module
2//!
3//! This module provides integration with AWS KMS for secure key management
4//! and cryptographic operations such as public key retrieval and message signing.
5//!
6//! Currently only EVM is supported.
7//!
8//! ## Features
9//!
//! - Authentication through the AWS SDK default credential provider chain
11//! - Public key retrieval from KMS
12//! - Message signing via KMS
13//!
14//! ## Architecture
15//!
16//! ```text
17//! AwsKmsService (implements AwsKmsEvmService)
18//!   ├── Authentication (via AwsKmsClient)
19//!   ├── Public Key Retrieval (via AwsKmsClient)
20//!   └── Message Signing (via AwsKmsClient)
21//! ```
//! which is built on top of
23//! ```text
24//! AwsKmsClient (implements AwsKmsK256)
25//!   ├── Authentication (via shared credentials)
26//!   ├── Public Key Retrieval in DER Encoding
27//!   └── Message Digest Signing in DER Encoding
28//! ```
29//! `AwsKmsK256` is mocked with `mockall` for unit testing
30//! and injected into `AwsKmsService`
31//!
32
33use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37    primitives::Blob,
38    types::{MessageType, SigningAlgorithmSpec},
39    Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::collections::HashMap;
44use tokio::sync::RwLock;
45
46use crate::{
47    models::{Address, AwsKmsSignerConfig},
48    utils::{self, derive_ethereum_address_from_der, extract_public_key_from_der},
49};
50
51#[cfg(test)]
52use mockall::{automock, mock};
53
/// Errors produced by the AWS KMS integration.
#[derive(Clone, Debug, thiserror::Error, Serialize)]
pub enum AwsKmsError {
    /// A KMS response could not be interpreted, e.g. the DER signature could not
    /// be decoded or no EVM address could be derived from the public key.
    #[error("AWS KMS response parse error: {0}")]
    ParseError(String),
    /// Invalid KMS configuration. Not constructed by the code in this file.
    #[error("AWS KMS config error: {0}")]
    ConfigError(String),
    /// Retrieving the public key from KMS failed, or the response carried no key.
    #[error("AWS KMS get error: {0}")]
    GetError(String),
    /// Signing produced no usable signature (e.g. the response had no signature blob).
    #[error("AWS KMS signing error: {0}")]
    SignError(String),
    /// The KMS `Sign` request itself failed. The production client maps every
    /// SDK error of that call to this variant.
    #[error("AWS KMS permissions error: {0}")]
    PermissionError(String),
    /// Recovering the signature's recovery id against the public key failed.
    #[error("AWS KMS public key error: {0}")]
    RecoveryError(#[from] utils::Secp256k1Error),
    /// The DER public key could not be converted into the raw form used for recovery.
    #[error("AWS KMS conversion error: {0}")]
    ConvertError(String),
    /// Any other failure. Not constructed by the code in this file.
    #[error("AWS KMS Other error: {0}")]
    Other(String),
}

/// Result type used throughout the AWS KMS module.
pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
75
/// EVM-facing operations backed by a private key held in AWS KMS.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEvmService: Send + Sync {
    /// Returns the EVM address derived from the configured public key.
    ///
    /// # Errors
    /// Fails if the public key cannot be fetched from KMS or no EVM address can
    /// be derived from it.
    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
    /// Signs a payload using the EVM signing scheme.
    ///
    /// The payload is pre-hashed with keccak-256 by the implementation, so
    /// callers pass the message itself, not a digest.
    ///
    /// # Errors
    /// Fails if KMS signing, public-key retrieval, or signature decoding/recovery fails.
    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
85
/// Low-level secp256k1 operations on a KMS key.
///
/// Kept behind a trait so `AwsKmsService` can be unit-tested with a mocked
/// client instead of a real AWS connection.
// NOTE(review): the explicit `'a`/`'b` lifetimes are presumably needed for
// mockall + async_trait compatibility; confirm before eliding them.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsK256: Send + Sync {
    /// Fetches the DER-encoded public key from AWS KMS.
    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a digest using the EcdsaSha256 spec and returns the DER-encoded signature.
    ///
    /// `digest` is the 32-byte message hash computed by the caller; it is
    /// signed as given, not hashed again.
    async fn sign_digest<'a, 'b>(
        &'a self,
        key_id: &'b str,
        digest: [u8; 32],
    ) -> AwsKmsResult<Vec<u8>>;
}
98
// Test-only mock of the concrete client, generated as `MockAwsKmsClient`.
// Unlike the `automock`-generated `MockAwsKmsK256`, it also implements `Clone`,
// which `AwsKmsService<T>` requires of `T`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }

}
117
118// Global cache - HashMap keyed by kms_key_id
119static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
120    Lazy::new(|| RwLock::new(HashMap::new()));
121
122#[derive(Debug, Clone)]
123pub struct AwsKmsClient {
124    inner: Client,
125}
126
127#[async_trait]
128impl AwsKmsK256 for AwsKmsClient {
129    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
130        // Try cache first with minimal lock time
131        let cached = {
132            let cache_read = KMS_DER_PK_CACHE.read().await;
133            cache_read.get(key_id).cloned()
134        };
135        if let Some(cached) = cached {
136            return Ok(cached);
137        }
138
139        // Fetch from AWS KMS
140        let get_output = self
141            .inner
142            .get_public_key()
143            .key_id(key_id)
144            .send()
145            .await
146            .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
147
148        let der_pk_blob = get_output
149            .public_key
150            .ok_or(AwsKmsError::GetError(
151                "No public key blob found".to_string(),
152            ))?
153            .into_inner();
154
155        // Cache the result
156        let mut cache_write = KMS_DER_PK_CACHE.write().await;
157        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
158        drop(cache_write);
159
160        Ok(der_pk_blob)
161    }
162
163    async fn sign_digest<'a, 'b>(
164        &'a self,
165        key_id: &'b str,
166        digest: [u8; 32],
167    ) -> AwsKmsResult<Vec<u8>> {
168        // Sign the digest with the AWS KMS
169        let sign_result = self
170            .inner
171            .sign()
172            .key_id(key_id)
173            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
174            .message_type(MessageType::Digest)
175            .message(Blob::new(digest))
176            .send()
177            .await;
178
179        // Process the result, extract DER signature
180        let der_signature = sign_result
181            .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
182            .signature
183            .ok_or(AwsKmsError::SignError(
184                "Signature not found in response".to_string(),
185            ))?
186            .into_inner();
187
188        Ok(der_signature)
189    }
190}
191
/// AWS KMS-backed EVM signer bound to a single KMS key.
///
/// Generic over the low-level client so tests can inject a mock; defaults to
/// [`AwsKmsClient`].
#[derive(Debug, Clone)]
pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
    /// KMS key identifier used for every operation, taken from `AwsKmsSignerConfig::key_id`.
    pub kms_key_id: String,
    /// Client used to reach KMS (real or mocked).
    client: T,
}
197
198impl AwsKmsService<AwsKmsClient> {
199    pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
200        let region_provider =
201            RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
202
203        let auth_config = aws_config::defaults(BehaviorVersion::latest())
204            .region(region_provider)
205            .load()
206            .await;
207        let client = AwsKmsClient {
208            inner: Client::new(&auth_config),
209        };
210
211        Ok(Self {
212            kms_key_id: config.key_id,
213            client,
214        })
215    }
216}
217
#[cfg(test)]
impl<T> AwsKmsService<T>
where
    T: AwsKmsK256 + Clone,
{
    /// Builds a service around an injected (typically mocked) client,
    /// taking only the key id from `config`.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
227
228impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
229    /// Signs a bytes with the private key stored in AWS KMS.
230    ///
231    /// Pre-hashes the message with keccak256.
232    pub async fn sign_bytes_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
233        // Create a digest of a message payload
234        let digest = keccak256(bytes).0;
235
236        // Sign the digest with the AWS KMS
237        // Process the result, extract DER signature
238        let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
239
240        // Parse DER into Secp256k1 format
241        let mut rs = k256::ecdsa::Signature::from_der(&der_signature)
242            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
243
244        // Normalize to low-s if necessary
245        if let Some(normalized) = rs.normalize_s() {
246            rs = normalized;
247        }
248        let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
249
250        // Extract public key from AWS KMS and convert it to an uncompressed 64 pk
251        let pk = extract_public_key_from_der(&der_pk)
252            .map_err(|e| AwsKmsError::ConvertError(e.to_string()))?;
253
254        // Extract v value from the public key recovery
255        let v = utils::recover_public_key(&pk, &rs, bytes)?;
256
257        // Adjust v value for Ethereum legacy transaction.
258        let eth_v = 27 + v;
259
260        // Append `v` to a signature bytes
261        let mut sig_bytes = rs.to_vec();
262        sig_bytes.push(eth_v);
263
264        Ok(sig_bytes)
265    }
266}
267
268#[async_trait]
269impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
270    async fn get_evm_address(&self) -> AwsKmsResult<Address> {
271        let der = self.client.get_der_public_key(&self.kms_key_id).await?;
272        let eth_address = derive_ethereum_address_from_der(&der)
273            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
274        Ok(Address::Evm(eth_address))
275    }
276
277    async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
278        self.sign_bytes_evm(message).await
279    }
280}
281
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey},
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Builds a [`MockAwsKmsClient`] backed by a freshly generated secp256k1 key.
    ///
    /// For the key id `"test-key-id"` the mock returns the key's DER-encoded
    /// public key and signs digests with it. For any other key id, public-key
    /// retrieval fails with [`AwsKmsError::GetError`] and signing fails with
    /// [`AwsKmsError::SignError`]. The mock can be cloned once, yielding an
    /// empty mock.
    ///
    /// Returns the mock together with the signing key so tests can compute
    /// expected values.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let der_public_key = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(der_public_key));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Builds a service over `client` configured with `key_id`.
    fn service_with_key(
        client: MockAwsKmsClient,
        key_id: &str,
    ) -> AwsKmsService<MockAwsKmsClient> {
        AwsKmsService::new_for_testing(
            client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        )
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();

        // Fail loudly instead of silently passing on an unexpected result.
        let Ok(Address::Evm(evm_address)) = kms.get_evm_address().await else {
            panic!("expected Ok(Address::Evm(_)) for a known key id");
        };
        assert_eq!(expected_address, evm_address);
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let sig_bytes = kms
            .sign_payload_evm(&message_eip)
            .await
            .expect("signing with a known key id should succeed");

        // Layout: r (32) || s (32) || v (1), with legacy v in {27, 28}.
        assert_eq!(sig_bytes.len(), 65);
        let v = sig_bytes[64];
        assert!(v == 27 || v == 28, "unexpected v value {v}");

        // The (r, s) part must already be in low-s form.
        let signature = Signature::from_slice(&sig_bytes[..64]).unwrap();
        assert!(signature.normalize_s().is_none());

        // Recovering with the returned v must yield the signer's public key.
        let digest = keccak256(&message_eip).0;
        let recovery_id = RecoveryId::from_byte(v - 27).unwrap();
        let recovered =
            VerifyingKey::recover_from_prehash(&digest, &signature, recovery_id).unwrap();
        assert_eq!(&recovered, key.verifying_key());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}