use std::{ collections::HashMap, fs, path::{Path, PathBuf}, sync::RwLock, }; use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore}; /// A disk-backed key store implementing `OpenMlsKeyStore`. /// /// In-memory when `path` is `None`; otherwise flushes the entire map to disk on /// every store/delete so HPKE init keys survive process restarts. #[derive(Debug)] pub struct DiskKeyStore { path: Option, values: RwLock, Vec>>, } #[derive(thiserror::Error, Debug, PartialEq, Eq)] pub enum DiskKeyStoreError { #[error("serialization error")] Serialization, #[error("io error: {0}")] Io(String), } impl DiskKeyStore { /// In-memory keystore (no persistence). pub fn ephemeral() -> Self { Self { path: None, values: RwLock::new(HashMap::new()), } } /// Persistent keystore backed by `path`. Creates an empty store if missing. pub fn persistent(path: impl AsRef) -> Result { let path = path.as_ref().to_path_buf(); let values = if path.exists() { let bytes = fs::read(&path).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?; if bytes.is_empty() { HashMap::new() } else { bincode::deserialize(&bytes).map_err(|_| DiskKeyStoreError::Serialization)? } } else { HashMap::new() }; Ok(Self { path: Some(path), values: RwLock::new(values), }) } fn flush(&self) -> Result<(), DiskKeyStoreError> { let Some(path) = &self.path else { return Ok(()); }; let values = self.values.read().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; let bytes = bincode::serialize(&*values).map_err(|_| DiskKeyStoreError::Serialization)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?; } fs::write(path, bytes).map_err(|e| DiskKeyStoreError::Io(e.to_string())) } } impl Default for DiskKeyStore { fn default() -> Self { Self::ephemeral() } } impl OpenMlsKeyStore for DiskKeyStore { type Error = DiskKeyStoreError; fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> { let value = serde_json::to_vec(v).map_err(|_| DiskKeyStoreError::Serialization)?; let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; values.insert(k.to_vec(), value); drop(values); self.flush() } fn read(&self, k: &[u8]) -> Option { let values = match self.values.read() { Ok(v) => v, Err(_) => return None, }; values .get(k) .and_then(|bytes| serde_json::from_slice(bytes).ok()) } fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; values.remove(k); drop(values); self.flush() } }