Ensure live queries are killed correctly (#2676)

This commit is contained in:
Tobie Morgan Hitchcock 2023-09-12 10:38:28 +01:00 committed by GitHub
parent 28368d83c9
commit 248829cf8a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 22 deletions

View file

@ -802,7 +802,6 @@ impl Transaction {
nxt = Some(k.clone()); nxt = Some(k.clone());
} }
// Delete // Delete
trace!("Found getr {:?} {:?}", crate::key::debug::sprint_key(&k), v);
out.push((k, v)); out.push((k, v));
// Count // Count
num -= 1; num -= 1;

View file

@ -86,29 +86,13 @@ impl Connection {
ws_id, ws_id,
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()), WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
); );
let mut live_queries_to_gc = Vec::new();
// Remove all live queries
LIVE_QUERIES.write().await.retain(|key, value| {
if value == &ws_id {
trace!("Removing live query: {}", key);
live_queries_to_gc.push(*key);
return false;
}
true
});
// Garbage collect Live Query
if let Err(e) =
DB.get().unwrap().garbage_collect_dead_session(live_queries_to_gc.as_slice()).await
{
error!("Failed to garbage collect dead sessions: {:?}", e);
}
// Spawn async tasks for the WebSocket
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();
tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone())); tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone()));
tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone())); tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone()));
tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone())); tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone()));
tasks.spawn(Self::lq_notifications(rpc.clone())); tasks.spawn(Self::notifications(rpc.clone()));
// Wait until all tasks finish // Wait until all tasks finish
while let Some(res) = tasks.join_next().await { while let Some(res) = tasks.join_next().await {
@ -117,10 +101,26 @@ impl Connection {
} }
} }
trace!("WebSocket {} disconnected", ws_id);
// Remove this WebSocket from the list // Remove this WebSocket from the list
WEBSOCKETS.write().await.remove(&ws_id); WEBSOCKETS.write().await.remove(&ws_id);
trace!("WebSocket {} disconnected", ws_id); // Remove all live queries
let mut gc = Vec::new();
LIVE_QUERIES.write().await.retain(|key, value| {
if value == &ws_id {
trace!("Removing live query: {}", key);
gc.push(*key);
return false;
}
true
});
// Garbage collect queries
if let Err(e) = DB.get().unwrap().garbage_collect_dead_session(gc.as_slice()).await {
error!("Failed to garbage collect dead sessions: {:?}", e);
}
if let Err(err) = telemetry::metrics::ws::on_disconnect() { if let Err(err) = telemetry::metrics::ws::on_disconnect() {
error!("Error running metrics::ws::on_disconnect hook: {}", err); error!("Error running metrics::ws::on_disconnect hook: {}", err);
@ -241,7 +241,7 @@ impl Connection {
} }
/// Send live query notifications to the client /// Send live query notifications to the client
async fn lq_notifications(rpc: Arc<RwLock<Connection>>) { async fn notifications(rpc: Arc<RwLock<Connection>>) {
if let Some(channel) = DB.get().unwrap().notifications() { if let Some(channel) = DB.get().unwrap().notifications() {
let cancel_token = rpc.read().await.graceful_shutdown.clone(); let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop { loop {

View file

@ -15,8 +15,8 @@ use uuid::Uuid;
static CONN_CLOSED_ERR: &str = "Connection closed normally"; static CONN_CLOSED_ERR: &str = "Connection closed normally";
// Mapping of WebSocketID to WebSocket
pub struct WebSocketRef(Sender<Message>, CancellationToken); pub struct WebSocketRef(Sender<Message>, CancellationToken);
// Mapping of WebSocketID to WebSocket
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>; type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
// Mapping of LiveQueryID to WebSocketID // Mapping of LiveQueryID to WebSocketID
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>; type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;