use std::{ fs, path::{Path, PathBuf}, }; use openmls_memory_storage::MemoryStorage; use openmls_traits::storage::{traits, StorageProvider, CURRENT_VERSION}; /// A disk-backed storage provider implementing `StorageProvider`. /// /// Wraps `openmls_memory_storage::MemoryStorage` and flushes to disk on every /// write so that HPKE init keys and group state survive process restarts. /// /// # Serialization /// /// Uses bincode for the outer `HashMap, Vec>` container when /// persisting to disk. The inner values use serde_json (matching /// `MemoryStorage`'s serialization format). /// /// # Persistence security /// /// When `path` is set, file permissions are restricted to owner-only (0o600) /// on Unix platforms, since the store may contain HPKE private keys. #[derive(Debug)] pub struct DiskKeyStore { path: Option, storage: MemoryStorage, } #[derive(thiserror::Error, Debug)] pub enum DiskKeyStoreError { #[error("serialization error")] Serialization, #[error("io error: {0}")] Io(String), #[error("memory storage error: {0}")] MemoryStorage(#[from] openmls_memory_storage::MemoryStorageError), } impl DiskKeyStore { /// In-memory keystore (no persistence). pub fn ephemeral() -> Self { Self { path: None, storage: MemoryStorage::default(), } } /// Persistent keystore backed by `path`. Creates an empty store if missing. pub fn persistent(path: impl AsRef) -> Result { let path = path.as_ref().to_path_buf(); let storage = if path.exists() { let bytes = fs::read(&path).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?; if bytes.is_empty() { MemoryStorage::default() } else { let map: std::collections::HashMap, Vec> = bincode::deserialize(&bytes) .map_err(|_| DiskKeyStoreError::Serialization)?; let storage = MemoryStorage::default(); let mut values = storage.values.write() .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; *values = map; drop(values); storage } } else { MemoryStorage::default() }; let store = Self { path: Some(path), storage, }; // Set restrictive file permissions on the keystore file. store.set_file_permissions()?; Ok(store) } fn flush(&self) -> Result<(), DiskKeyStoreError> { let Some(path) = &self.path else { return Ok(()); }; let values = self.storage.values.read() .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; let bytes = bincode::serialize(&*values) .map_err(|_| DiskKeyStoreError::Serialization)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?; } fs::write(path, &bytes).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?; self.set_file_permissions()?; Ok(()) } /// Serialize the backing storage to bytes (bincode). /// /// This captures all key material *and* MLS group state held by the /// `StorageProvider`, allowing the caller to persist it in a database /// column instead of (or in addition to) on-disk files. pub fn to_bytes(&self) -> Result, DiskKeyStoreError> { let values = self.storage.values.read() .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; bincode::serialize(&*values).map_err(|_| DiskKeyStoreError::Serialization) } /// Restore a `DiskKeyStore` from bytes previously produced by [`to_bytes`]. pub fn from_bytes(bytes: &[u8]) -> Result { let map: std::collections::HashMap, Vec> = bincode::deserialize(bytes).map_err(|_| DiskKeyStoreError::Serialization)?; let storage = MemoryStorage::default(); let mut values = storage.values.write() .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?; *values = map; drop(values); Ok(Self { path: None, storage, }) } /// Restrict file permissions to owner-only (0o600) on Unix. #[cfg(unix)] fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> { use std::os::unix::fs::PermissionsExt; if let Some(path) = &self.path { if path.exists() { let perms = std::fs::Permissions::from_mode(0o600); fs::set_permissions(path, perms) .map_err(|e| DiskKeyStoreError::Io(format!("set permissions: {e}")))?; } } Ok(()) } #[cfg(not(unix))] fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> { Ok(()) } } impl Default for DiskKeyStore { fn default() -> Self { Self::ephemeral() } } /// Delegate all `StorageProvider` methods to the inner `MemoryStorage`, /// flushing to disk after every write/delete operation. /// /// The flush errors are mapped to `DiskKeyStoreError` via the /// `MemoryStorageError` conversion. If a flush fails, the in-memory state /// is still updated (matching the old DiskKeyStore behavior). impl StorageProvider for DiskKeyStore { type Error = DiskKeyStoreError; fn write_mls_join_config< GroupId: traits::GroupId, MlsGroupJoinConfig: traits::MlsGroupJoinConfig, >( &self, group_id: &GroupId, config: &MlsGroupJoinConfig, ) -> Result<(), Self::Error> { self.storage.write_mls_join_config(group_id, config)?; self.flush() } fn append_own_leaf_node< GroupId: traits::GroupId, LeafNode: traits::LeafNode, >( &self, group_id: &GroupId, leaf_node: &LeafNode, ) -> Result<(), Self::Error> { self.storage.append_own_leaf_node(group_id, leaf_node)?; self.flush() } fn queue_proposal< GroupId: traits::GroupId, ProposalRef: traits::ProposalRef, QueuedProposal: traits::QueuedProposal, >( &self, group_id: &GroupId, proposal_ref: &ProposalRef, proposal: &QueuedProposal, ) -> Result<(), Self::Error> { self.storage.queue_proposal(group_id, proposal_ref, proposal)?; self.flush() } fn write_tree< GroupId: traits::GroupId, TreeSync: traits::TreeSync, >( &self, group_id: &GroupId, tree: &TreeSync, ) -> Result<(), Self::Error> { self.storage.write_tree(group_id, tree)?; self.flush() } fn write_interim_transcript_hash< GroupId: traits::GroupId, InterimTranscriptHash: traits::InterimTranscriptHash, >( &self, group_id: &GroupId, interim_transcript_hash: &InterimTranscriptHash, ) -> Result<(), Self::Error> { self.storage.write_interim_transcript_hash(group_id, interim_transcript_hash)?; self.flush() } fn write_context< GroupId: traits::GroupId, GroupContext: traits::GroupContext, >( &self, group_id: &GroupId, group_context: &GroupContext, ) -> Result<(), Self::Error> { self.storage.write_context(group_id, group_context)?; self.flush() } fn write_confirmation_tag< GroupId: traits::GroupId, ConfirmationTag: traits::ConfirmationTag, >( &self, group_id: &GroupId, confirmation_tag: &ConfirmationTag, ) -> Result<(), Self::Error> { self.storage.write_confirmation_tag(group_id, confirmation_tag)?; self.flush() } fn write_group_state< GroupState: traits::GroupState, GroupId: traits::GroupId, >( &self, group_id: &GroupId, group_state: &GroupState, ) -> Result<(), Self::Error> { self.storage.write_group_state(group_id, group_state)?; self.flush() } fn write_message_secrets< GroupId: traits::GroupId, MessageSecrets: traits::MessageSecrets, >( &self, group_id: &GroupId, message_secrets: &MessageSecrets, ) -> Result<(), Self::Error> { self.storage.write_message_secrets(group_id, message_secrets)?; self.flush() } fn write_resumption_psk_store< GroupId: traits::GroupId, ResumptionPskStore: traits::ResumptionPskStore, >( &self, group_id: &GroupId, resumption_psk_store: &ResumptionPskStore, ) -> Result<(), Self::Error> { self.storage.write_resumption_psk_store(group_id, resumption_psk_store)?; self.flush() } fn write_own_leaf_index< GroupId: traits::GroupId, LeafNodeIndex: traits::LeafNodeIndex, >( &self, group_id: &GroupId, own_leaf_index: &LeafNodeIndex, ) -> Result<(), Self::Error> { self.storage.write_own_leaf_index(group_id, own_leaf_index)?; self.flush() } fn write_group_epoch_secrets< GroupId: traits::GroupId, GroupEpochSecrets: traits::GroupEpochSecrets, >( &self, group_id: &GroupId, group_epoch_secrets: &GroupEpochSecrets, ) -> Result<(), Self::Error> { self.storage.write_group_epoch_secrets(group_id, group_epoch_secrets)?; self.flush() } fn write_signature_key_pair< SignaturePublicKey: traits::SignaturePublicKey, SignatureKeyPair: traits::SignatureKeyPair, >( &self, public_key: &SignaturePublicKey, signature_key_pair: &SignatureKeyPair, ) -> Result<(), Self::Error> { self.storage.write_signature_key_pair(public_key, signature_key_pair)?; self.flush() } fn write_encryption_key_pair< EncryptionKey: traits::EncryptionKey, HpkeKeyPair: traits::HpkeKeyPair, >( &self, public_key: &EncryptionKey, key_pair: &HpkeKeyPair, ) -> Result<(), Self::Error> { self.storage.write_encryption_key_pair(public_key, key_pair)?; self.flush() } fn write_encryption_epoch_key_pairs< GroupId: traits::GroupId, EpochKey: traits::EpochKey, HpkeKeyPair: traits::HpkeKeyPair, >( &self, group_id: &GroupId, epoch: &EpochKey, leaf_index: u32, key_pairs: &[HpkeKeyPair], ) -> Result<(), Self::Error> { self.storage.write_encryption_epoch_key_pairs(group_id, epoch, leaf_index, key_pairs)?; self.flush() } fn write_key_package< HashReference: traits::HashReference, KeyPackage: traits::KeyPackage, >( &self, hash_ref: &HashReference, key_package: &KeyPackage, ) -> Result<(), Self::Error> { self.storage.write_key_package(hash_ref, key_package)?; self.flush() } fn write_psk< PskId: traits::PskId, PskBundle: traits::PskBundle, >( &self, psk_id: &PskId, psk: &PskBundle, ) -> Result<(), Self::Error> { self.storage.write_psk(psk_id, psk)?; self.flush() } // --- getters (no flush needed) --- fn mls_group_join_config< GroupId: traits::GroupId, MlsGroupJoinConfig: traits::MlsGroupJoinConfig, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.mls_group_join_config(group_id)?) } fn own_leaf_nodes< GroupId: traits::GroupId, LeafNode: traits::LeafNode, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.own_leaf_nodes(group_id)?) } fn queued_proposal_refs< GroupId: traits::GroupId, ProposalRef: traits::ProposalRef, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.queued_proposal_refs(group_id)?) } fn queued_proposals< GroupId: traits::GroupId, ProposalRef: traits::ProposalRef, QueuedProposal: traits::QueuedProposal, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.queued_proposals(group_id)?) } fn tree< GroupId: traits::GroupId, TreeSync: traits::TreeSync, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.tree(group_id)?) } fn group_context< GroupId: traits::GroupId, GroupContext: traits::GroupContext, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.group_context(group_id)?) } fn interim_transcript_hash< GroupId: traits::GroupId, InterimTranscriptHash: traits::InterimTranscriptHash, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.interim_transcript_hash(group_id)?) } fn confirmation_tag< GroupId: traits::GroupId, ConfirmationTag: traits::ConfirmationTag, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.confirmation_tag(group_id)?) } fn group_state< GroupState: traits::GroupState, GroupId: traits::GroupId, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.group_state(group_id)?) } fn message_secrets< GroupId: traits::GroupId, MessageSecrets: traits::MessageSecrets, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.message_secrets(group_id)?) } fn resumption_psk_store< GroupId: traits::GroupId, ResumptionPskStore: traits::ResumptionPskStore, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.resumption_psk_store(group_id)?) } fn own_leaf_index< GroupId: traits::GroupId, LeafNodeIndex: traits::LeafNodeIndex, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.own_leaf_index(group_id)?) } fn group_epoch_secrets< GroupId: traits::GroupId, GroupEpochSecrets: traits::GroupEpochSecrets, >( &self, group_id: &GroupId, ) -> Result, Self::Error> { Ok(self.storage.group_epoch_secrets(group_id)?) } fn signature_key_pair< SignaturePublicKey: traits::SignaturePublicKey, SignatureKeyPair: traits::SignatureKeyPair, >( &self, public_key: &SignaturePublicKey, ) -> Result, Self::Error> { Ok(self.storage.signature_key_pair(public_key)?) } fn encryption_key_pair< HpkeKeyPair: traits::HpkeKeyPair, EncryptionKey: traits::EncryptionKey, >( &self, public_key: &EncryptionKey, ) -> Result, Self::Error> { Ok(self.storage.encryption_key_pair(public_key)?) } fn encryption_epoch_key_pairs< GroupId: traits::GroupId, EpochKey: traits::EpochKey, HpkeKeyPair: traits::HpkeKeyPair, >( &self, group_id: &GroupId, epoch: &EpochKey, leaf_index: u32, ) -> Result, Self::Error> { Ok(self.storage.encryption_epoch_key_pairs(group_id, epoch, leaf_index)?) } fn key_package< KeyPackageRef: traits::HashReference, KeyPackage: traits::KeyPackage, >( &self, hash_ref: &KeyPackageRef, ) -> Result, Self::Error> { Ok(self.storage.key_package(hash_ref)?) } fn psk< PskBundle: traits::PskBundle, PskId: traits::PskId, >( &self, psk_id: &PskId, ) -> Result, Self::Error> { Ok(self.storage.psk(psk_id)?) } // --- deleters (flush needed) --- fn remove_proposal< GroupId: traits::GroupId, ProposalRef: traits::ProposalRef, >( &self, group_id: &GroupId, proposal_ref: &ProposalRef, ) -> Result<(), Self::Error> { self.storage.remove_proposal(group_id, proposal_ref)?; self.flush() } fn delete_own_leaf_nodes>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_own_leaf_nodes(group_id)?; self.flush() } fn delete_group_config>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_group_config(group_id)?; self.flush() } fn delete_tree>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_tree(group_id)?; self.flush() } fn delete_confirmation_tag>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_confirmation_tag(group_id)?; self.flush() } fn delete_group_state>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_group_state(group_id)?; self.flush() } fn delete_context>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_context(group_id)?; self.flush() } fn delete_interim_transcript_hash>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_interim_transcript_hash(group_id)?; self.flush() } fn delete_message_secrets>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_message_secrets(group_id)?; self.flush() } fn delete_all_resumption_psk_secrets>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_all_resumption_psk_secrets(group_id)?; self.flush() } fn delete_own_leaf_index>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_own_leaf_index(group_id)?; self.flush() } fn delete_group_epoch_secrets>( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.delete_group_epoch_secrets(group_id)?; self.flush() } fn clear_proposal_queue< GroupId: traits::GroupId, ProposalRef: traits::ProposalRef, >( &self, group_id: &GroupId, ) -> Result<(), Self::Error> { self.storage.clear_proposal_queue::(group_id)?; self.flush() } fn delete_signature_key_pair< SignaturePublicKey: traits::SignaturePublicKey, >( &self, public_key: &SignaturePublicKey, ) -> Result<(), Self::Error> { self.storage.delete_signature_key_pair(public_key)?; self.flush() } fn delete_encryption_key_pair>( &self, public_key: &EncryptionKey, ) -> Result<(), Self::Error> { self.storage.delete_encryption_key_pair(public_key)?; self.flush() } fn delete_encryption_epoch_key_pairs< GroupId: traits::GroupId, EpochKey: traits::EpochKey, >( &self, group_id: &GroupId, epoch: &EpochKey, leaf_index: u32, ) -> Result<(), Self::Error> { self.storage.delete_encryption_epoch_key_pairs(group_id, epoch, leaf_index)?; self.flush() } fn delete_key_package>( &self, hash_ref: &KeyPackageRef, ) -> Result<(), Self::Error> { self.storage.delete_key_package(hash_ref)?; self.flush() } fn delete_psk>( &self, psk_id: &PskKey, ) -> Result<(), Self::Error> { self.storage.delete_psk(psk_id)?; self.flush() } }