diff --git a/src/protocol/libp2p/kademlia/message.rs b/src/protocol/libp2p/kademlia/message.rs index de8665f5..e6f72a9b 100644 --- a/src/protocol/libp2p/kademlia/message.rs +++ b/src/protocol/libp2p/kademlia/message.rs @@ -247,7 +247,7 @@ impl KademliaMessage { 4 => { let peers = message .closer_peers - .iter() + .into_iter() .filter_map(|peer| KademliaPeer::try_from(peer).ok()) .take(replication_factor) .collect(); @@ -285,7 +285,7 @@ impl KademliaMessage { record, peers: message .closer_peers - .iter() + .into_iter() .filter_map(|peer| KademliaPeer::try_from(peer).ok()) .take(replication_factor) .collect(), @@ -296,7 +296,7 @@ impl KademliaMessage { let key = (!message.key.is_empty()).then_some(message.key.into())?; let providers = message .provider_peers - .iter() + .into_iter() .filter_map(|peer| KademliaPeer::try_from(peer).ok()) .take(replication_factor) .collect(); @@ -308,13 +308,13 @@ impl KademliaMessage { let key = (!message.key.is_empty()).then_some(message.key.into()); let peers = message .closer_peers - .iter() + .into_iter() .filter_map(|peer| KademliaPeer::try_from(peer).ok()) .take(replication_factor) .collect(); let providers = message .provider_peers - .iter() + .into_iter() .filter_map(|peer| KademliaPeer::try_from(peer).ok()) .take(replication_factor) .collect(); diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index fcfb6a88..5a92e5a1 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -408,7 +408,7 @@ impl Kademlia { .await; for info in peers { - let addresses = info.addresses(); + let addresses: Vec = info.addresses().cloned().collect(); self.service.add_known_address(&info.peer, addresses.clone().into_iter()); if std::matches!(self.update_mode, RoutingTableUpdateMode::Automatic) { @@ -544,7 +544,7 @@ impl Kademlia { match (providers.len(), providers.pop()) { (1, Some(provider)) => { - let addresses = provider.addresses(); + let addresses: Vec = provider.addresses().cloned().collect(); if provider.peer == peer { self.store.put_provider( @@ -797,7 +797,7 @@ impl Kademlia { query_id: query, peers: peers .into_iter() - .map(|info| (info.peer, info.addresses())) + .map(|info| (info.peer, info.addresses().cloned().collect())) .collect(), }) .await; @@ -1431,7 +1431,10 @@ mod tests { // Check peer addresses. match kademlia.routing_table.entry(Key::from(peer)) { KBucketEntry::Occupied(entry) => { - assert_eq!(entry.addresses(), vec![address_a.clone()]); + assert_eq!( + entry.addresses().cloned().collect::>(), + vec![address_a.clone()] + ); } _ => panic!("Peer not found in routing table"), }; @@ -1450,7 +1453,7 @@ mod tests { match kademlia.routing_table.entry(Key::from(peer)) { KBucketEntry::Occupied(entry) => { assert_eq!( - entry.addresses(), + entry.addresses().cloned().collect::>(), vec![address_b.clone(), address_a.clone()] ); } @@ -1469,7 +1472,7 @@ mod tests { match kademlia.routing_table.entry(Key::from(peer)) { KBucketEntry::Occupied(entry) => { assert_eq!( - entry.addresses(), + entry.addresses().cloned().collect::>(), vec![address_b.clone(), address_a.clone()] ); } @@ -1483,7 +1486,7 @@ mod tests { match kademlia.routing_table.entry(Key::from(peer)) { KBucketEntry::Occupied(entry) => { assert_eq!( - entry.addresses(), + entry.addresses().cloned().collect::>(), vec![address_a.clone(), address_b.clone()] ); } diff --git a/src/protocol/libp2p/kademlia/query/get_providers.rs b/src/protocol/libp2p/kademlia/query/get_providers.rs index d6874bd1..3c31ae80 100644 --- a/src/protocol/libp2p/kademlia/query/get_providers.rs +++ b/src/protocol/libp2p/kademlia/query/get_providers.rs @@ -117,7 +117,10 @@ impl GetProvidersContext { // Merge addresses of different provider records of the same peer. let mut providers = HashMap::>::new(); found_providers.into_iter().for_each(|provider| { - providers.entry(provider.peer).or_default().extend(provider.addresses()) + providers + .entry(provider.peer) + .or_default() + .extend(provider.addresses().cloned()) }); // Convert into `Vec` diff --git a/src/protocol/libp2p/kademlia/record.rs b/src/protocol/libp2p/kademlia/record.rs index 322553d4..643d5679 100644 --- a/src/protocol/libp2p/kademlia/record.rs +++ b/src/protocol/libp2p/kademlia/record.rs @@ -23,7 +23,7 @@ use crate::{ protocol::libp2p::kademlia::types::{ ConnectionType, Distance, KademliaPeer, Key as KademliaKey, }, - transport::manager::address::{AddressRecord, AddressStore}, + transport::manager::address::AddressStoreBuckets, Multiaddr, PeerId, }; @@ -170,15 +170,10 @@ pub struct ContentProvider { impl From for KademliaPeer { fn from(provider: ContentProvider) -> Self { - let mut address_store = AddressStore::new(); - for address in provider.addresses.iter() { - address_store.insert(AddressRecord::from_raw_multiaddr(address.clone())); - } - Self { key: KademliaKey::from(provider.peer), peer: provider.peer, - address_store, + address_store: AddressStoreBuckets::from_unknown(provider.addresses), connection: ConnectionType::NotConnected, } } diff --git a/src/protocol/libp2p/kademlia/routing_table.rs b/src/protocol/libp2p/kademlia/routing_table.rs index e012318e..dd317dd4 100644 --- a/src/protocol/libp2p/kademlia/routing_table.rs +++ b/src/protocol/libp2p/kademlia/routing_table.rs @@ -400,7 +400,7 @@ mod tests { KBucketEntry::Occupied(entry) => { assert_eq!(entry.key, key); assert_eq!(entry.peer, peer); - assert_eq!(entry.addresses(), addresses); + assert_eq!(entry.addresses().cloned().collect::>(), addresses); assert_eq!(entry.connection, ConnectionType::Connected); } state => panic!("invalid state for `KBucketEntry`: {state:?}"), @@ -418,7 +418,7 @@ mod tests { KBucketEntry::Occupied(entry) => { assert_eq!(entry.key, key); assert_eq!(entry.peer, peer); - assert_eq!(entry.addresses(), addresses); + assert_eq!(entry.addresses().cloned().collect::>(), addresses); assert_eq!(entry.connection, ConnectionType::NotConnected); } state => panic!("invalid state for `KBucketEntry`: {state:?}"), @@ -508,7 +508,7 @@ mod tests { KBucketEntry::Occupied(entry) => { assert_eq!(entry.key, key); assert_eq!(entry.peer, peer); - assert_eq!(entry.addresses(), addresses); + assert_eq!(entry.addresses().cloned().collect::>(), addresses); assert_eq!(entry.connection, ConnectionType::CanConnect); } state => panic!("invalid state for `KBucketEntry`: {state:?}"), diff --git a/src/protocol/libp2p/kademlia/types.rs b/src/protocol/libp2p/kademlia/types.rs index c71bb878..b4949f2c 100644 --- a/src/protocol/libp2p/kademlia/types.rs +++ b/src/protocol/libp2p/kademlia/types.rs @@ -26,7 +26,7 @@ use crate::{ protocol::libp2p::kademlia::schema, - transport::manager::address::{AddressRecord, AddressStore}, + transport::manager::address::{AddressRecord, AddressStoreBuckets}, PeerId, }; @@ -254,7 +254,7 @@ pub struct KademliaPeer { pub(super) peer: PeerId, /// Known addresses of peer. - pub(super) address_store: AddressStore, + pub(super) address_store: AddressStoreBuckets, /// Connection type. pub(super) connection: ConnectionType, @@ -263,15 +263,9 @@ pub struct KademliaPeer { impl KademliaPeer { /// Create new [`KademliaPeer`]. pub fn new(peer: PeerId, addresses: Vec, connection: ConnectionType) -> Self { - let mut address_store = AddressStore::new(); - - for address in addresses.into_iter() { - address_store.insert(AddressRecord::from_raw_multiaddr(address)); - } - Self { peer, - address_store, + address_store: AddressStoreBuckets::from_unknown(addresses), connection, key: Key::from(peer), } @@ -285,29 +279,22 @@ impl KademliaPeer { } /// Returns the addresses of the peer. - pub fn addresses(&self) -> Vec { + pub fn addresses(&self) -> impl Iterator { self.address_store.addresses(MAX_ADDRESSES) } } -impl TryFrom<&schema::kademlia::Peer> for KademliaPeer { +impl TryFrom for KademliaPeer { type Error = (); - fn try_from(record: &schema::kademlia::Peer) -> Result { + fn try_from(record: schema::kademlia::Peer) -> Result { let peer = PeerId::from_bytes(&record.id).map_err(|_| ())?; - - let mut address_store = AddressStore::new(); - for address in record.addrs.iter() { - let Ok(address) = Multiaddr::try_from(address.clone()) else { - continue; - }; - address_store.insert(AddressRecord::from_raw_multiaddr(address)); - } + let addresses = record.addrs.into_iter().filter_map(|addr| Multiaddr::try_from(addr).ok()); Ok(KademliaPeer { key: Key::from(peer), peer, - address_store, + address_store: AddressStoreBuckets::from_unknown(addresses), connection: ConnectionType::try_from(record.connection)?, }) } @@ -320,7 +307,6 @@ impl From<&KademliaPeer> for schema::kademlia::Peer { addrs: peer .address_store .addresses(MAX_ADDRESSES) - .iter() .map(|address| address.to_vec()) .collect(), connection: peer.connection.into(), diff --git a/src/transport/manager/address.rs b/src/transport/manager/address.rs index 2c3cb5d0..f78bbf57 100644 --- a/src/transport/manager/address.rs +++ b/src/transport/manager/address.rs @@ -23,11 +23,20 @@ use crate::{error::DialError, PeerId}; use multiaddr::{Multiaddr, Protocol}; use multihash::Multihash; -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::{hash_map::Entry, HashMap, HashSet}; /// Maximum number of addresses tracked for a peer. const MAX_ADDRESSES: usize = 64; +/// Maximum number of addresses tracked for a peer in the success bucket. +const MAX_SUCCESS_ADDRESSES: usize = 32; + +/// Maximum number of addresses tracked for a peer in the unknown bucket. +const MAX_UNKNOWN_ADDRESSES: usize = 16; + +/// Maximum number of addresses tracked for a peer in the failure bucket. +const MAX_FAILURE_ADDRESSES: usize = 16; + /// Scores for address records. pub mod scores { /// Score indicating that the connection was successfully established. @@ -261,6 +270,105 @@ impl AddressStore { } } +/// Buckets for storing addresses based on dial results. +/// +/// This is a more optimized version of [`AddressStore`] that separates addresses +/// based on their dial results (success, unknown, failure). +/// +/// It allows for more efficient management of addresses based on their dial outcomes, +/// reducing the need for sorting and filtering during address selection. +#[derive(Debug, Clone, Default)] +pub struct AddressStoreBuckets { + /// Addresses with successful dials. + pub success: HashSet, + + /// Addresses not yet dialed. + pub unknown: HashSet, + + /// Addresses with dial failures. + pub failure: HashSet, +} + +impl AddressStoreBuckets { + /// Create new [`AddressStoreBuckets`]. + pub fn new() -> Self { + Self { + success: HashSet::with_capacity(MAX_SUCCESS_ADDRESSES), + unknown: HashSet::with_capacity(MAX_UNKNOWN_ADDRESSES), + failure: HashSet::with_capacity(MAX_FAILURE_ADDRESSES), + } + } + + /// Create [`AddressStoreBuckets`] from a set of unknown addresses. + /// + /// If the addresses exceed the maximum capacity, they will be truncated. + pub fn from_unknown(addresses: impl IntoIterator) -> Self { + let mut store = Self::new(); + for address in addresses.into_iter().take(MAX_UNKNOWN_ADDRESSES) { + store.unknown.insert(address); + } + store + } + + /// Insert an address record into the appropriate bucket based on its score. + pub fn insert(&mut self, record: AddressRecord) { + let AddressRecord { score, address } = record; + + match score { + score if score > 0 => { + // Moves directly to the success bucket. + self.unknown.remove(&address); + self.failure.remove(&address); + + Self::ensure_space(&mut self.success); + self.success.insert(address); + } + 0 => { + // Moves to the unknown bucket. + self.success.remove(&address); + self.failure.remove(&address); + + Self::ensure_space(&mut self.unknown); + self.unknown.insert(address); + } + _ => { + // Moves to the failure bucket. + self.success.remove(&address); + self.unknown.remove(&address); + + Self::ensure_space(&mut self.failure); + self.failure.insert(address); + } + } + } + + /// Ensure that there is space in the bucket. + fn ensure_space(bucket: &mut HashSet) { + if bucket.len() < bucket.capacity() { + return; + } + + // Remove the first element to ensure space. + if let Some(first) = bucket.iter().next().cloned() { + bucket.remove(&first); + } + } + + /// Check if the store is empty. + pub fn is_empty(&self) -> bool { + self.success.is_empty() && self.unknown.is_empty() && self.failure.is_empty() + } + + /// Return the available addresses from all buckets. + pub fn addresses(&self, limit: usize) -> impl Iterator { + self.success + .iter() + .chain(self.unknown.iter()) + .chain(self.failure.iter()) + .take(limit) + } +} + #[cfg(test)] mod tests { use std::{