openzeppelin_relayer/utils/
encryption.rs1use aes_gcm::{
7 aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
8 Aes256Gcm, Key, Nonce,
9};
10use serde::{Deserialize, Serialize};
11use std::env;
12use thiserror::Error;
13use zeroize::Zeroize;
14
15use crate::{
16 models::SecretString,
17 utils::{base64_decode, base64_encode},
18};
19
20#[derive(Error, Debug, Clone)]
21pub enum EncryptionError {
22 #[error("Encryption failed: {0}")]
23 EncryptionFailed(String),
24 #[error("Decryption failed: {0}")]
25 DecryptionFailed(String),
26 #[error("Key derivation failed: {0}")]
27 KeyDerivationFailed(String),
28 #[error("Invalid encrypted data format: {0}")]
29 InvalidFormat(String),
30 #[error("Missing encryption key environment variable: {0}")]
31 MissingKey(String),
32 #[error("Invalid key length: expected 32 bytes, got {0}")]
33 InvalidKeyLength(usize),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct EncryptedData {
39 pub nonce: String,
41 pub ciphertext: String,
43 pub version: u8,
45}
46
47#[derive(Clone)]
49pub struct FieldEncryption {
50 cipher: Aes256Gcm,
51}
52
53impl FieldEncryption {
54 pub fn new() -> Result<Self, EncryptionError> {
60 let key = Self::load_key_from_env()?;
61 let cipher = Aes256Gcm::new(&key);
62 Ok(Self { cipher })
63 }
64
65 pub fn new_with_key(key: &[u8; 32]) -> Result<Self, EncryptionError> {
67 let key = Key::<Aes256Gcm>::from_slice(key);
68 let cipher = Aes256Gcm::new(key);
69 Ok(Self { cipher })
70 }
71
72 fn load_key_from_env() -> Result<Key<Aes256Gcm>, EncryptionError> {
74 let key = env::var("STORAGE_ENCRYPTION_KEY")
75 .map(|v| SecretString::new(&v))
76 .map_err(|_| {
77 EncryptionError::MissingKey("STORAGE_ENCRYPTION_KEY must be set".to_string())
78 })?;
79
80 key.as_str(|key_b64| {
81 let mut key_bytes = base64_decode(key_b64)
82 .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
83 if key_bytes.len() != 32 {
84 key_bytes.zeroize(); return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
86 }
87
88 Ok(*Key::<Aes256Gcm>::from_slice(&key_bytes))
89 })
90 }
91
92 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
94 let mut nonce_bytes = [0u8; 12];
96 OsRng.fill_bytes(&mut nonce_bytes);
97 let nonce = Nonce::from_slice(&nonce_bytes);
98
99 let ciphertext = self
101 .cipher
102 .encrypt(nonce, plaintext)
103 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
104
105 Ok(EncryptedData {
106 nonce: base64_encode(&nonce_bytes),
107 ciphertext: base64_encode(&ciphertext),
108 version: 1,
109 })
110 }
111
112 pub fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
114 if encrypted_data.version != 1 {
115 return Err(EncryptionError::InvalidFormat(format!(
116 "Unsupported encryption version: {}",
117 encrypted_data.version
118 )));
119 }
120
121 let nonce_bytes = base64_decode(&encrypted_data.nonce)
123 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {}", e)))?;
124
125 let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
126 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {}", e)))?;
127
128 if nonce_bytes.len() != 12 {
129 return Err(EncryptionError::InvalidFormat(format!(
130 "Invalid nonce length: expected 12, got {}",
131 nonce_bytes.len()
132 )));
133 }
134
135 let nonce = Nonce::from_slice(&nonce_bytes);
136
137 let plaintext = self
139 .cipher
140 .decrypt(nonce, ciphertext_bytes.as_ref())
141 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
142
143 Ok(plaintext)
144 }
145
146 pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
148 let encrypted_data = self.encrypt(plaintext.as_bytes())?;
149 let json_data = serde_json::to_string(&encrypted_data).map_err(|e| {
150 EncryptionError::EncryptionFailed(format!("Serialization failed: {}", e))
151 })?;
152
153 Ok(base64_encode(json_data.as_bytes()))
155 }
156
157 pub fn decrypt_string(&self, encrypted_base64: &str) -> Result<String, EncryptionError> {
159 let json_bytes = base64_decode(encrypted_base64)
161 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {}", e)))?;
162
163 let encrypted_json = String::from_utf8(json_bytes).map_err(|e| {
164 EncryptionError::InvalidFormat(format!("Invalid UTF-8 in decoded data: {}", e))
165 })?;
166
167 let encrypted_data: EncryptedData = serde_json::from_str(&encrypted_json).map_err(|e| {
168 EncryptionError::InvalidFormat(format!("Invalid JSON structure: {}", e))
169 })?;
170
171 let plaintext_bytes = self.decrypt(&encrypted_data)?;
172 String::from_utf8(plaintext_bytes).map_err(|e| {
173 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {}", e))
174 })
175 }
176
177 pub fn generate_key() -> String {
179 let mut key = [0u8; 32];
180 OsRng.fill_bytes(&mut key);
181 let key_b64 = base64_encode(&key);
182
183 let mut key_zeroize = key;
185 key_zeroize.zeroize();
186
187 key_b64
188 }
189
190 pub fn is_configured() -> bool {
192 env::var("STORAGE_ENCRYPTION_KEY").is_ok()
193 }
194}
195
196static ENCRYPTION_INSTANCE: std::sync::OnceLock<Result<FieldEncryption, EncryptionError>> =
198 std::sync::OnceLock::new();
199
200pub fn get_encryption() -> Result<&'static FieldEncryption, &'static EncryptionError> {
202 ENCRYPTION_INSTANCE
203 .get_or_init(FieldEncryption::new)
204 .as_ref()
205}
206
207pub fn encrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
209 if FieldEncryption::is_configured() {
210 match get_encryption() {
211 Ok(encryption) => encryption.encrypt_string(data),
212 Err(e) => Err(e.clone()),
213 }
214 } else {
215 let json_data = serde_json::to_string(data).map_err(|e| {
218 EncryptionError::EncryptionFailed(format!("JSON encoding failed: {}", e))
219 })?;
220 Ok(base64_encode(json_data.as_bytes()))
221 }
222}
223
224pub fn decrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
226 let json_bytes = base64_decode(data)
228 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {}", e)))?;
229
230 let json_str = String::from_utf8(json_bytes)
231 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
232
233 if FieldEncryption::is_configured() {
235 if let Ok(encryption) = get_encryption() {
236 if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
238 let plaintext_bytes = encryption.decrypt(&encrypted_data)?;
240 return String::from_utf8(plaintext_bytes).map_err(|e| {
241 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {}", e))
242 });
243 }
244 }
245 }
246
247 serde_json::from_str(&json_str)
250 .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {}", e)))
251}
252
253pub fn generate_encryption_key() -> String {
255 FieldEncryption::generate_key()
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    const KEY_VAR: &str = "STORAGE_ENCRYPTION_KEY";

    /// Serialises every test that reads or writes `STORAGE_ENCRYPTION_KEY`.
    ///
    /// The test harness runs tests on parallel threads and the environment is
    /// process-global. Without this lock, one test can observe another test's
    /// key mid-run.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` for the duration of a test. On drop, restores the
    /// variable's previous value, including when the test panics.
    struct EnvKeyGuard {
        previous: Option<String>,
        // Declared last so it is dropped (unlocked) after `drop` restores the var.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvKeyGuard {
        fn acquire() -> Self {
            // A panicking test poisons the mutex; the guarded data is `()`, so
            // recovering the guard is safe.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                previous: env::var(KEY_VAR).ok(),
                _lock: lock,
            }
        }
    }

    impl Drop for EnvKeyGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(KEY_VAR, value),
                None => env::remove_var(KEY_VAR),
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_data() {
        let key = [0u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = b"This is a secret message!";
        let encrypted = encryption.encrypt(plaintext).unwrap();
        let decrypted = encryption.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let key = [1u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "Sensitive API key: sk-1234567890abcdef";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_different_keys_produce_different_results() {
        let key1 = [1u8; 32];
        let key2 = [2u8; 32];
        let encryption1 = FieldEncryption::new_with_key(&key1).unwrap();
        let encryption2 = FieldEncryption::new_with_key(&key2).unwrap();

        let plaintext = "secret";
        let encrypted1 = encryption1.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption2.encrypt_string(plaintext).unwrap();

        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption1.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption2.decrypt_string(&encrypted2).unwrap(), plaintext);

        // A key must not decrypt data produced under a different key.
        assert!(encryption1.decrypt_string(&encrypted2).is_err());
        assert!(encryption2.decrypt_string(&encrypted1).is_err());
    }

    #[test]
    fn test_nonce_uniqueness() {
        let key = [3u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "same message";
        let encrypted1 = encryption.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption.encrypt_string(plaintext).unwrap();

        // Fresh nonces make repeated encryptions of the same plaintext differ.
        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let key = [4u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        assert!(encryption.decrypt_string("invalid base64!").is_err());

        assert!(encryption
            .decrypt_string(&base64_encode(b"not json"))
            .is_err());

        let invalid_json_b64 = base64_encode(b"{\"wrong\": \"structure\"}");
        assert!(encryption.decrypt_string(&invalid_json_b64).is_err());

        assert!(encryption
            .decrypt_string(&base64_encode(
                b"{\"nonce\":\"test\",\"ciphertext\":\"test\",\"version\":1}"
            ))
            .is_err());
    }

    #[test]
    fn test_generate_key() {
        let key1 = FieldEncryption::generate_key();
        let key2 = FieldEncryption::generate_key();

        assert_ne!(key1, key2);

        assert!(base64_decode(&key1).is_ok());
        assert!(base64_decode(&key2).is_ok());

        assert_eq!(base64_decode(&key1).unwrap().len(), 32);
        assert_eq!(base64_decode(&key2).unwrap().len(), 32);
    }

    #[test]
    fn test_env_key_loading() {
        let _env = EnvKeyGuard::acquire();

        let test_key = FieldEncryption::generate_key();
        env::set_var(KEY_VAR, &test_key);

        let encryption = FieldEncryption::new().unwrap();
        let plaintext = "test message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        env::remove_var(KEY_VAR);
        assert!(FieldEncryption::new().is_err());
        // `_env` restores the variable's original value on drop.
    }

    #[test]
    fn test_env_key_wrong_length_rejected() {
        let _env = EnvKeyGuard::acquire();

        env::set_var(KEY_VAR, base64_encode(&[0u8; 16]));
        assert!(matches!(
            FieldEncryption::new(),
            Err(EncryptionError::InvalidKeyLength(_))
        ));
    }

    #[test]
    fn test_high_level_encryption_functions() {
        // Reads the env var (via `is_configured`), so it must not race with
        // tests that change it.
        let _env = EnvKeyGuard::acquire();

        let plaintext = "sensitive data";

        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        assert!(base64_decode(&encoded).is_ok());
    }

    #[test]
    fn test_fallback_when_encryption_disabled() {
        let _env = EnvKeyGuard::acquire();
        env::remove_var(KEY_VAR);

        let plaintext = "fallback test";

        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_core_encryption_methods() {
        let key = [9u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();
        let plaintext = "core encryption test";

        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        assert!(base64_decode(&encrypted).is_ok());
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("{"));
    }

    #[test]
    fn test_base64_encoding_hides_structure() {
        let key = [7u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "secret message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();

        assert!(base64_decode(&encrypted).is_ok());

        // The outer base64 layer hides the JSON envelope's field names and braces.
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("version"));
        assert!(!encrypted.contains("{"));
        assert!(!encrypted.contains("}"));

        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }
}