Revised patch. Changes relative to the submitted version:
  * the announce flag is published with `Ordering::AcqRel` (was `Acquire`,
    which leaves the store half of the swap relaxed);
  * `store_value` takes the routing table lock once instead of twice;
  * the pre-announce wait is hoisted out of the resolver retry loop;
  * the commented-out `StoreValue::new` is dropped instead of being added.
The `index` lines still carry the blob ids of the submitted version.

diff --git a/network/src/dht/background_tasks.rs b/network/src/dht/background_tasks.rs
index 89b2827400..ec577efdf6 100644
--- a/network/src/dht/background_tasks.rs
+++ b/network/src/dht/background_tasks.rs
@@ -1,5 +1,6 @@
 use std::collections::hash_map;
 use std::sync::Arc;
+use std::sync::atomic::Ordering;
 
 use anyhow::Result;
 use futures_util::StreamExt;
@@ -150,7 +151,15 @@ impl DhtInner {
         value.signature = &signature;
 
         self.store_value(network, &ValueRef::Peer(value), true)
-            .await
+            .await?;
+
+        // Publish the flag with release semantics so that a waiter whose
+        // `Acquire` load sees `true` also sees the effects of the store above.
+        if !self.local_info_pre_announced.swap(true, Ordering::AcqRel) {
+            self.local_info_announced_notify.notify_waiters();
+        }
+
+        Ok(())
     }
 
     #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))]
diff --git a/network/src/dht/mod.rs b/network/src/dht/mod.rs
index 6a7773761f..4a62045a75 100644
--- a/network/src/dht/mod.rs
+++ b/network/src/dht/mod.rs
@@ -1,12 +1,14 @@
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{Arc, Mutex};
 
+use ahash::HashMapExt;
 use anyhow::Result;
 use bytes::{Buf, Bytes};
 use rand::RngCore;
 use tl_proto::TlRead;
 use tokio::sync::{Notify, broadcast};
-use tycho_util::realloc_box_enum;
 use tycho_util::time::now_sec;
+use tycho_util::{FastHashMap, realloc_box_enum};
 
 pub use self::config::DhtConfig;
 pub use self::peer_resolver::{
@@ -344,6 +346,9 @@ impl DhtServiceBuilder {
             announced_peers,
             find_value_queries: Default::default(),
             peer_added: Arc::new(Default::default()),
+
+            local_info_pre_announced: AtomicBool::new(false),
+            local_info_announced_notify: Arc::new(Default::default()),
         });
 
         let background_tasks = DhtServiceBackgroundTasks {
@@ -414,6 +419,11 @@ impl DhtService {
     pub fn peer_added(&self) -> &Arc<Notify> {
         &self.0.peer_added
     }
+
+    /// Waits for the first successful announcement of the local peer info.
+    pub async fn wait_for_pre_announce(&self) {
+        self.0.wait_for_local_info_announced().await;
+    }
 }
 
 impl Service for DhtService {
@@ -541,9 +551,28 @@ struct DhtInner {
     announced_peers: broadcast::Sender<Arc<PeerInfo>>,
     find_value_queries: QueryCache<Option<Box<Value>>>,
     peer_added: Arc<Notify>,
+
+    /// Set once the local peer info has been announced successfully.
+    local_info_pre_announced: AtomicBool,
+    /// Wakes tasks blocked in `wait_for_local_info_announced`.
+    local_info_announced_notify: Arc<Notify>,
 }
 
 impl DhtInner {
+    /// Waits for the first successful announcement of the local peer info.
+    ///
+    /// Returns immediately if the announcement has already happened.
+    async fn wait_for_local_info_announced(&self) {
+        // Create the `Notified` future before checking the flag so that a
+        // `notify_waiters` call racing with the check is not missed.
+        let notified = self.local_info_announced_notify.notified();
+        if self.local_info_pre_announced.load(Ordering::Acquire) {
+            return;
+        }
+
+        notified.await;
+    }
+
     async fn find_value(
         &self,
         network: &Network,
@@ -574,27 +603,58 @@ impl DhtInner {
     ) -> Result<()> {
         self.storage.insert(DhtValueSource::Local, value)?;
 
-        let local_peer_info = if with_peer_info {
-            let mut node_info = self.local_peer_info.lock().unwrap();
+        let local_info = if with_peer_info {
+            let mut info = self.local_peer_info.lock().unwrap();
             Some(
-                node_info
-                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
+                info.get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
                     .clone(),
             )
         } else {
             None
         };
 
-        let query = StoreValue::new(
-            network.clone(),
-            &self.routing_table.lock().unwrap(),
-            value,
-            self.config.max_k,
-            local_peer_info.as_ref(),
-        );
+        let key_hash = match value {
+            ValueRef::Peer(value) => tl_proto::hash(&value.key),
+            ValueRef::Merged(value) => tl_proto::hash(&value.key),
+        };
+
+        // Take the routing table lock once so that the local candidates and the
+        // lookup query are built from the same table state. The guard is dropped
+        // at the end of this block, before any `.await`.
+        let (local_peers, query) = {
+            let table = self.routing_table.lock().unwrap();
+            let local_peers = table.closest(&key_hash, self.config.max_k * 2);
+            let query = Query::new(
+                network.clone(),
+                &table,
+                &key_hash,
+                self.config.max_k,
+                DhtQueryMode::Closest,
+            );
+            (local_peers, query)
+        };
+        let lookup_peers = query.find_peers(Some(3)).await;
+
+        // Merge both sources, deduplicating by peer id.
+        let mut candidates = FastHashMap::new();
+        for peer in local_peers {
+            candidates.insert(peer.id, peer);
+        }
+        for (_, peer) in lookup_peers {
+            // Ensure the peer is known so store requests can connect.
+            let _ = network.known_peers().insert(peer.clone(), false);
+            candidates.insert(peer.id, peer);
+        }
+
+        // Keep only the `max_k` candidates closest to the key.
+        let mut chosen = candidates.into_values().collect::<Vec<_>>();
+        chosen.sort_by_key(|peer| xor_distance(&peer.id, PeerId::wrap(&key_hash)));
+        chosen.truncate(self.config.max_k);
+
+        StoreValue::new_with_peers(network.clone(), chosen, value, local_info.as_ref())
+            .run()
+            .await;
 
-        // NOTE: expression is intentionally split to drop the routing table guard
-        query.run().await;
         Ok(())
     }
 
diff --git a/network/src/dht/peer_resolver.rs b/network/src/dht/peer_resolver.rs
index 438248e3b6..fcf65cda73 100644
--- a/network/src/dht/peer_resolver.rs
+++ b/network/src/dht/peer_resolver.rs
@@ -285,10 +285,15 @@ impl PeerResolverInner {
 
         // "Fast" path
         let mut attempts = 0usize;
+
+        // Wait once, up front, for the first local announcement; nothing in the
+        // retry loop changes the outcome of this wait.
+        self.dht_service.wait_for_pre_announce().await;
+
         loop {
             attempts += 1;
             let is_stale = attempts > self.config.fast_retry_count as usize;
 
             // NOTE: Acquire network ref only during the operation.
             {
                 let network = self.weak_network.upgrade()?;
diff --git a/network/src/dht/query.rs b/network/src/dht/query.rs
index b68adfa9c8..0cbeb5bc78 100644
--- a/network/src/dht/query.rs
+++ b/network/src/dht/query.rs
@@ -397,36 +397,34 @@ pub struct StoreValue {
 }
 
 impl StoreValue<()> {
-    pub fn new(
+    /// Prepares one store request per entry in `peers`.
+    ///
+    /// When `local_info` is set, it is serialized together with the value.
+    /// All requests share the same body and a semaphore with 10 permits.
+    pub fn new_with_peers(
         network: Network,
-        routing_table: &HandlesRoutingTable,
+        peers: Vec<Arc<PeerInfo>>,
         value: &ValueRef<'_>,
-        max_k: usize,
-        local_peer_info: Option<&PeerInfo>,
+        local_info: Option<&PeerInfo>,
     ) -> StoreValue<impl Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send + use<>> {
-        let key_hash = match value {
-            ValueRef::Peer(value) => tl_proto::hash(&value.key),
-            ValueRef::Merged(value) => tl_proto::hash(&value.key),
-        };
-
-        let request_body = Bytes::from(match local_peer_info {
-            Some(peer_info) => tl_proto::serialize((
-                rpc::WithPeerInfo::wrap(peer_info),
-                rpc::StoreRef::wrap(value),
-            )),
+        let request_body = Bytes::from(match local_info {
+            Some(info) => {
+                tl_proto::serialize((rpc::WithPeerInfo::wrap(info), rpc::StoreRef::wrap(value)))
+            }
             None => tl_proto::serialize(rpc::StoreRef::wrap(value)),
         });
 
         let semaphore = Arc::new(Semaphore::new(10));
         let futures = futures_util::stream::FuturesUnordered::new();
-        routing_table.visit_closest(&key_hash, max_k, |node| {
+
+        for peer in peers {
             futures.push(Self::visit(
                 network.clone(),
-                node.load_peer_info(),
+                peer,
                 request_body.clone(),
                 semaphore.clone(),
             ));
-        });
+        }
 
         StoreValue { futures }
     }
diff --git a/network/tests/private_overlay.rs b/network/tests/private_overlay.rs
index b1d7bd818e..1a05e94e81 100644
--- a/network/tests/private_overlay.rs
+++ b/network/tests/private_overlay.rs
@@ -85,7 +85,7 @@ fn make_network(node_count: usize) -> Vec {
 async fn private_overlays_accessible() -> Result<()> {
     tycho_util::test::init_logger("private_overlays_accessible", "debug");
 
-    let nodes = make_network(20);
+    let nodes = make_network(30);
 
     for node in &nodes {
         let resolved = FuturesUnordered::new();