openzeppelin_relayer/utils/
encryption.rs

1//! Field-level encryption utilities for sensitive data protection
2//!
3//! This module provides secure encryption and decryption of sensitive fields using AES-256-GCM.
4//! It's designed to be used transparently in the repository layer to protect data at rest.
5
6use aes_gcm::{
7    aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
8    Aes256Gcm, Key, Nonce,
9};
10use serde::{Deserialize, Serialize};
11use std::env;
12use thiserror::Error;
13use zeroize::Zeroize;
14
15use crate::{
16    models::SecretString,
17    utils::{base64_decode, base64_encode},
18};
19
20#[derive(Error, Debug, Clone)]
21pub enum EncryptionError {
22    #[error("Encryption failed: {0}")]
23    EncryptionFailed(String),
24    #[error("Decryption failed: {0}")]
25    DecryptionFailed(String),
26    #[error("Key derivation failed: {0}")]
27    KeyDerivationFailed(String),
28    #[error("Invalid encrypted data format: {0}")]
29    InvalidFormat(String),
30    #[error("Missing encryption key environment variable: {0}")]
31    MissingKey(String),
32    #[error("Invalid key length: expected 32 bytes, got {0}")]
33    InvalidKeyLength(usize),
34}
35
36/// Encrypted data container that holds the nonce and ciphertext
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct EncryptedData {
39    /// Base64-encoded nonce (12 bytes for GCM)
40    pub nonce: String,
41    /// Base64-encoded ciphertext with authentication tag
42    pub ciphertext: String,
43    /// Version for future compatibility
44    pub version: u8,
45}
46
47/// Main encryption service for field-level encryption
48#[derive(Clone)]
49pub struct FieldEncryption {
50    cipher: Aes256Gcm,
51}
52
53impl FieldEncryption {
54    /// Creates a new FieldEncryption instance using a key from environment variables
55    ///
56    /// # Environment Variables
57    /// - `STORAGE_ENCRYPTION_KEY`: Base64-encoded 32-byte encryption key
58    /// ```
59    pub fn new() -> Result<Self, EncryptionError> {
60        let key = Self::load_key_from_env()?;
61        let cipher = Aes256Gcm::new(&key);
62        Ok(Self { cipher })
63    }
64
65    /// Creates a new FieldEncryption instance with a provided key (for testing)
66    pub fn new_with_key(key: &[u8; 32]) -> Result<Self, EncryptionError> {
67        let key = Key::<Aes256Gcm>::from_slice(key);
68        let cipher = Aes256Gcm::new(key);
69        Ok(Self { cipher })
70    }
71
72    /// Loads encryption key from environment variables
73    fn load_key_from_env() -> Result<Key<Aes256Gcm>, EncryptionError> {
74        let key = env::var("STORAGE_ENCRYPTION_KEY")
75            .map(|v| SecretString::new(&v))
76            .map_err(|_| {
77                EncryptionError::MissingKey("STORAGE_ENCRYPTION_KEY must be set".to_string())
78            })?;
79
80        key.as_str(|key_b64| {
81            let mut key_bytes = base64_decode(key_b64)
82                .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
83            if key_bytes.len() != 32 {
84                key_bytes.zeroize(); // Explicit cleanup on error path
85                return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
86            }
87
88            Ok(*Key::<Aes256Gcm>::from_slice(&key_bytes))
89        })
90    }
91
92    /// Encrypts plaintext data and returns an EncryptedData structure
93    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
94        // Generate random 12-byte nonce for GCM
95        let mut nonce_bytes = [0u8; 12];
96        OsRng.fill_bytes(&mut nonce_bytes);
97        let nonce = Nonce::from_slice(&nonce_bytes);
98
99        // Encrypt the data
100        let ciphertext = self
101            .cipher
102            .encrypt(nonce, plaintext)
103            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
104
105        Ok(EncryptedData {
106            nonce: base64_encode(&nonce_bytes),
107            ciphertext: base64_encode(&ciphertext),
108            version: 1,
109        })
110    }
111
112    /// Decrypts an EncryptedData structure and returns the plaintext
113    pub fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
114        if encrypted_data.version != 1 {
115            return Err(EncryptionError::InvalidFormat(format!(
116                "Unsupported encryption version: {}",
117                encrypted_data.version
118            )));
119        }
120
121        // Decode nonce and ciphertext
122        let nonce_bytes = base64_decode(&encrypted_data.nonce)
123            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {}", e)))?;
124
125        let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
126            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {}", e)))?;
127
128        if nonce_bytes.len() != 12 {
129            return Err(EncryptionError::InvalidFormat(format!(
130                "Invalid nonce length: expected 12, got {}",
131                nonce_bytes.len()
132            )));
133        }
134
135        let nonce = Nonce::from_slice(&nonce_bytes);
136
137        // Decrypt the data
138        let plaintext = self
139            .cipher
140            .decrypt(nonce, ciphertext_bytes.as_ref())
141            .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
142
143        Ok(plaintext)
144    }
145
146    /// Encrypts a string and returns base64-encoded encrypted data (opaque format)
147    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
148        let encrypted_data = self.encrypt(plaintext.as_bytes())?;
149        let json_data = serde_json::to_string(&encrypted_data).map_err(|e| {
150            EncryptionError::EncryptionFailed(format!("Serialization failed: {}", e))
151        })?;
152
153        // Base64 encode the entire JSON to make it opaque
154        Ok(base64_encode(json_data.as_bytes()))
155    }
156
157    /// Decrypts a base64-encoded encrypted string
158    pub fn decrypt_string(&self, encrypted_base64: &str) -> Result<String, EncryptionError> {
159        // Decode from base64 to get the JSON
160        let json_bytes = base64_decode(encrypted_base64)
161            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {}", e)))?;
162
163        let encrypted_json = String::from_utf8(json_bytes).map_err(|e| {
164            EncryptionError::InvalidFormat(format!("Invalid UTF-8 in decoded data: {}", e))
165        })?;
166
167        let encrypted_data: EncryptedData = serde_json::from_str(&encrypted_json).map_err(|e| {
168            EncryptionError::InvalidFormat(format!("Invalid JSON structure: {}", e))
169        })?;
170
171        let plaintext_bytes = self.decrypt(&encrypted_data)?;
172        String::from_utf8(plaintext_bytes).map_err(|e| {
173            EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {}", e))
174        })
175    }
176
177    /// Utility function to generate a new encryption key for setup
178    pub fn generate_key() -> String {
179        let mut key = [0u8; 32];
180        OsRng.fill_bytes(&mut key);
181        let key_b64 = base64_encode(&key);
182
183        // Zero out the key from memory
184        let mut key_zeroize = key;
185        key_zeroize.zeroize();
186
187        key_b64
188    }
189
190    /// Checks if encryption is properly configured
191    pub fn is_configured() -> bool {
192        env::var("STORAGE_ENCRYPTION_KEY").is_ok()
193    }
194}
195
196/// Global encryption instance (lazy-initialized)
197static ENCRYPTION_INSTANCE: std::sync::OnceLock<Result<FieldEncryption, EncryptionError>> =
198    std::sync::OnceLock::new();
199
200/// Gets the global encryption instance
201pub fn get_encryption() -> Result<&'static FieldEncryption, &'static EncryptionError> {
202    ENCRYPTION_INSTANCE
203        .get_or_init(FieldEncryption::new)
204        .as_ref()
205}
206
207/// Encrypts sensitive data if encryption is configured, otherwise returns base64-encoded plaintext
208pub fn encrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
209    if FieldEncryption::is_configured() {
210        match get_encryption() {
211            Ok(encryption) => encryption.encrypt_string(data),
212            Err(e) => Err(e.clone()),
213        }
214    } else {
215        // For development/testing when encryption is not configured,
216        // base64-encode the JSON string for consistency
217        let json_data = serde_json::to_string(data).map_err(|e| {
218            EncryptionError::EncryptionFailed(format!("JSON encoding failed: {}", e))
219        })?;
220        Ok(base64_encode(json_data.as_bytes()))
221    }
222}
223
224/// Decrypts sensitive data from base64 format
225pub fn decrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
226    // Always try to decode base64 first
227    let json_bytes = base64_decode(data)
228        .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {}", e)))?;
229
230    let json_str = String::from_utf8(json_bytes)
231        .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
232
233    // Try to parse as encrypted data first (if encryption is configured)
234    if FieldEncryption::is_configured() {
235        if let Ok(encryption) = get_encryption() {
236            // Check if this looks like encrypted data by trying to parse as EncryptedData
237            if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
238                // This is encrypted data, decrypt it
239                let plaintext_bytes = encryption.decrypt(&encrypted_data)?;
240                return String::from_utf8(plaintext_bytes).map_err(|e| {
241                    EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {}", e))
242                });
243            }
244        }
245    }
246
247    // If we get here, either encryption is not configured, or this is fallback data
248    // Try to parse as JSON string (fallback format)
249    serde_json::from_str(&json_str)
250        .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {}", e)))
251}
252
253/// Utility function to generate a new encryption key
254pub fn generate_encryption_key() -> String {
255    FieldEncryption::generate_key()
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable the module under test reads its key from.
    const KEY_VAR: &str = "STORAGE_ENCRYPTION_KEY";

    /// Serializes the tests that read or modify `KEY_VAR`. The environment is
    /// process-global and the test harness runs tests on multiple threads, so
    /// unsynchronized access makes those tests race with each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and restores the original value of `KEY_VAR` when dropped,
    /// including when the owning test panics.
    struct EnvKeyGuard {
        previous: Option<String>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvKeyGuard {
        /// Takes the lock and snapshots the current value of `KEY_VAR`.
        fn acquire() -> Self {
            // A poisoned lock only means another env test failed; the guard it held
            // has already restored the variable, so it is safe to continue.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                previous: env::var(KEY_VAR).ok(),
                _lock: lock,
            }
        }
    }

    impl Drop for EnvKeyGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so no other env test sees a half-restored
            // environment.
            match self.previous.take() {
                Some(value) => env::set_var(KEY_VAR, value),
                None => env::remove_var(KEY_VAR),
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_data() {
        let key = [0u8; 32]; // Test key
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = b"This is a secret message!";
        let encrypted = encryption.encrypt(plaintext).unwrap();
        let decrypted = encryption.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let key = [1u8; 32]; // Different test key
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "Sensitive API key: sk-1234567890abcdef";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_different_keys_produce_different_results() {
        let key1 = [1u8; 32];
        let key2 = [2u8; 32];
        let encryption1 = FieldEncryption::new_with_key(&key1).unwrap();
        let encryption2 = FieldEncryption::new_with_key(&key2).unwrap();

        let plaintext = "secret";
        let encrypted1 = encryption1.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption2.encrypt_string(plaintext).unwrap();

        assert_ne!(encrypted1, encrypted2);

        // Each should decrypt with their own key
        assert_eq!(encryption1.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption2.decrypt_string(&encrypted2).unwrap(), plaintext);

        // But not with the other key
        assert!(encryption1.decrypt_string(&encrypted2).is_err());
        assert!(encryption2.decrypt_string(&encrypted1).is_err());
    }

    #[test]
    fn test_nonce_uniqueness() {
        let key = [3u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "same message";
        let encrypted1 = encryption.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption.encrypt_string(plaintext).unwrap();

        // Same plaintext should produce different ciphertext due to random nonces
        assert_ne!(encrypted1, encrypted2);

        // Both should decrypt to the same plaintext
        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let key = [4u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        // Test with invalid base64
        assert!(encryption.decrypt_string("invalid base64!").is_err());

        // Test with valid base64 but invalid JSON inside
        assert!(encryption
            .decrypt_string(&base64_encode(b"not json"))
            .is_err());

        // Test with valid base64 but wrong JSON structure inside
        let invalid_json_b64 = base64_encode(b"{\"wrong\": \"structure\"}");
        assert!(encryption.decrypt_string(&invalid_json_b64).is_err());

        // Test with a structurally valid envelope whose nonce and ciphertext are bogus
        assert!(encryption
            .decrypt_string(&base64_encode(
                b"{\"nonce\":\"test\",\"ciphertext\":\"test\",\"version\":1}"
            ))
            .is_err());
    }

    #[test]
    fn test_generate_key() {
        let key1 = FieldEncryption::generate_key();
        let key2 = FieldEncryption::generate_key();

        // Keys should be different
        assert_ne!(key1, key2);

        // Keys should be valid base64
        assert!(base64_decode(&key1).is_ok());
        assert!(base64_decode(&key2).is_ok());

        // Decoded keys should be 32 bytes
        assert_eq!(base64_decode(&key1).unwrap().len(), 32);
        assert_eq!(base64_decode(&key2).unwrap().len(), 32);
    }

    #[test]
    fn test_env_key_loading() {
        // Exclusive access to the environment; the original value is restored on drop.
        let _env = EnvKeyGuard::acquire();

        // Test base64 key
        let test_key = FieldEncryption::generate_key();
        env::set_var(KEY_VAR, &test_key);

        let encryption = FieldEncryption::new().unwrap();
        let plaintext = "test message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        // Test missing key
        env::remove_var(KEY_VAR);
        assert!(FieldEncryption::new().is_err());
    }

    #[test]
    fn test_high_level_encryption_functions() {
        // The high-level functions consult the environment, so keep it stable while
        // this test runs.
        let _env = EnvKeyGuard::acquire();

        let plaintext = "sensitive data";

        // Test that the high-level encrypt/decrypt functions work together
        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // All outputs should now be base64-encoded (whether encrypted or fallback)
        assert!(base64_decode(&encoded).is_ok());

        // Just verify it works - don't make assumptions about internal format
        // since global encryption state may vary between test runs
    }

    #[test]
    fn test_fallback_when_encryption_disabled() {
        // Temporarily clear the encryption key to test fallback; the guard restores
        // the original value even if an assertion below fails.
        let _env = EnvKeyGuard::acquire();
        env::remove_var(KEY_VAR);

        let plaintext = "fallback test";

        // Should use fallback mode (base64-encoded JSON)
        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // Should be base64-encoded JSON
        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_core_encryption_methods() {
        let key = [9u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();
        let plaintext = "core encryption test";

        // Test core encryption methods directly
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        // Should be base64-encoded
        assert!(base64_decode(&encrypted).is_ok());
        // Should not contain readable structure
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("{"));
    }

    #[test]
    fn test_base64_encoding_hides_structure() {
        let key = [7u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "secret message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();

        // Should be valid base64
        assert!(base64_decode(&encrypted).is_ok());

        // Should not contain readable JSON structure
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("version"));
        assert!(!encrypted.contains("{"));
        assert!(!encrypted.contains("}"));

        // Should decrypt correctly
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }
}