surrealpatch/sdk/src/api/engine/remote/ws/wasm.rs
2024-09-16 15:49:55 +00:00

549 lines
14 KiB
Rust

use super::{HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH};
use crate::api::conn::DbResponse;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::conn::{Command, Connection, RequestData};
use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::Response;
use crate::api::engine::remote::{deserialize, serialize};
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint;
use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::Data;
use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::{Action, Notification};
use channel::{Receiver, Sender};
use futures::stream::{SplitSink, SplitStream};
use futures::FutureExt;
use futures::SinkExt;
use futures::StreamExt;
use pharos::Channel;
use pharos::Events;
use pharos::Observable;
use pharos::ObserveConfig;
use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use surrealdb_core::sql::Value as CoreValue;
use tokio::sync::watch;
use trice::Instant;
use wasm_bindgen_futures::spawn_local;
use wasmtimer::tokio as time;
use wasmtimer::tokio::MissedTickBehavior;
use ws_stream_wasm::WsMessage as Message;
use ws_stream_wasm::WsMeta;
use ws_stream_wasm::{WsEvent, WsStream};
type MessageStream = SplitStream<WsStream>;
type MessageSink = SplitSink<WsStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream>;
// `Client` also implements the public `crate::api::Connection` trait. The impl body is empty,
// so it relies entirely on the trait's default items (if any) — presumably a marker trait;
// NOTE(review): confirm against the trait definition.
impl crate::api::Connection for Client {}
impl Connection for Client {
	/// Opens a WebSocket connection to `address` and returns a `Surreal` client bound to it.
	///
	/// The RPC path is appended to the endpoint URL, then the router task is spawned on the
	/// local executor. This call waits until the router reports whether the initial connection
	/// succeeded. A `capacity` of `0` gives an unbounded request channel; any other value
	/// bounds it to that many queued requests.
	fn connect(
		mut address: Endpoint,
		capacity: usize,
	) -> BoxFuture<'static, Result<Surreal<Self>>> {
		Box::pin(async move {
			address.url = address.url.join(PATH)?;
			let (route_tx, route_rx) = if capacity == 0 {
				channel::unbounded()
			} else {
				channel::bounded(capacity)
			};
			// One-shot style channel used by the router to report the connection outcome.
			let (conn_tx, conn_rx) = channel::bounded(1);
			spawn_local(run_router(address, capacity, conn_tx, route_rx));
			// Outer `?`: the router hung up; inner `?`: the router failed to connect.
			conn_rx.recv().await??;
			let features = HashSet::from([ExtraFeatures::LiveQueries]);
			let router = Router {
				features,
				sender: route_tx,
				last_id: AtomicI64::new(0),
			};
			let waiter = watch::channel(Some(WaitFor::Connection));
			Ok(Surreal::new_from_router_waiter(
				Arc::new(OnceLock::with_value(router)),
				Arc::new(waiter),
			))
		})
	}
}
async fn router_handle_request(
Route {
request,
response,
}: Route,
state: &mut RouterState,
_endpoint: &Endpoint,
) -> HandleResult {
let RequestData {
id,
command,
} = request;
let entry = state.pending_requests.entry(id);
// We probably shouldn't be sending duplicate id requests.
let Entry::Vacant(entry) = entry else {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
return HandleResult::Ok;
};
let mut effect = RequestEffect::None;
match command {
Command::Set {
ref key,
ref value,
} => {
effect = RequestEffect::Set {
key: key.clone(),
value: value.clone(),
};
}
Command::Unset {
ref key,
} => {
effect = RequestEffect::Clear {
key: key.clone(),
};
}
Command::Insert {
..
} => {
effect = RequestEffect::Insert;
}
Command::SubscribeLive {
ref uuid,
ref notification_sender,
} => {
state.live_queries.insert(*uuid, notification_sender.clone());
if response.send(Ok(DbResponse::Other(CoreValue::None))).await.is_err() {
trace!("Receiver dropped");
}
// There is nothing to send to the server here
return HandleResult::Ok;
}
Command::Kill {
ref uuid,
} => {
state.live_queries.remove(uuid);
}
Command::Use {
..
} => {
state.replay.insert(ReplayMethod::Use, command.clone());
}
Command::Signup {
..
} => {
state.replay.insert(ReplayMethod::Signup, command.clone());
}
Command::Signin {
..
} => {
state.replay.insert(ReplayMethod::Signin, command.clone());
}
Command::Invalidate {
..
} => {
state.replay.insert(ReplayMethod::Invalidate, command.clone());
}
Command::Authenticate {
..
} => {
state.replay.insert(ReplayMethod::Authenticate, command.clone());
}
_ => {}
}
let message = {
let Some(req) = command.into_router_request(Some(id)) else {
let _ = response.send(Err(Error::BackupsNotSupported.into())).await;
return HandleResult::Ok;
};
trace!("Request {:?}", req);
let payload = serialize(&req, true).unwrap();
Message::Binary(payload)
};
match state.sink.send(message).await {
Ok(..) => {
state.last_activity = Instant::now();
entry.insert(PendingRequest {
effect,
response_channel: response,
});
}
Err(error) => {
let error = Error::Ws(error.to_string());
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
return HandleResult::Disconnected;
}
}
HandleResult::Ok
}
/// Handles a single message received from the server.
///
/// Responses that carry an `id` are matched against `state.pending_requests`. The request's
/// `RequestEffect` is applied first: `Set`/`Clear` update `state.vars`, and `Insert` unwraps a
/// single-element array result. The result is then forwarded to the waiting caller.
///
/// Responses without an `id` may be live-query notifications; they are forwarded to the
/// registered subscriber. If that subscriber's receiver has gone away, the live query is dropped
/// locally and a `Kill` request is sent to the server.
///
/// Messages that fail to decode are reported to the matching pending request when their `id`
/// can still be recovered.
///
/// Returns `HandleResult::Disconnected` only when sending the `Kill` request fails.
async fn router_handle_response(
	response: Message,
	state: &mut RouterState,
	_endpoint: &Endpoint,
) -> HandleResult {
	match Response::try_from(&response) {
		Ok(option) => {
			// We are only interested in responses that are not empty
			if let Some(response) = option {
				trace!("{response:?}");
				match response.id {
					// If `id` is set this is a normal response
					Some(id) => {
						if let Ok(id) = id.coerce_to_i64() {
							// We can only route responses with IDs
							if let Some(pending) = state.pending_requests.remove(&id) {
								match pending.effect {
									RequestEffect::None => {}
									RequestEffect::Insert => {
										// For insert, we need to flatten single responses in an array
										if let Ok(Data::Other(CoreValue::Array(value))) =
											response.result
										{
											let data = if value.len() == 1 {
												value.into_iter().next().unwrap()
											} else {
												CoreValue::Array(value)
											};
											let _ = pending
												.response_channel
												.send(DbResponse::from_server_result(Ok(
													Data::Other(data),
												)))
												.await;
											return HandleResult::Ok;
										}
									}
									RequestEffect::Set {
										key,
										value,
									} => {
										state.vars.insert(key, value);
									}
									RequestEffect::Clear {
										key,
									} => {
										state.vars.shift_remove(&key);
									}
								}
								let _res = pending
									.response_channel
									.send(DbResponse::from_server_result(response.result))
									.await;
							} else {
								warn!("got response for request with id '{id}', which was not in pending requests")
							}
						}
					}
					// If `id` is not set, this may be a live query notification
					None => match response.result {
						Ok(Data::Live(notification)) => {
							let live_query_id = notification.id;
							// Check if this live query is registered
							if let Some(sender) = state.live_queries.get(&live_query_id) {
								// Send the notification back to the caller or kill live query if the receiver is already dropped
								let notification = Notification {
									query_id: notification.id.0,
									action: Action::from_core(notification.action),
									data: notification.result,
								};
								if sender.send(notification).await.is_err() {
									state.live_queries.remove(&live_query_id);
									let request = Command::Kill {
										uuid: live_query_id.0,
									}
									.into_router_request(None);
									// `into_router_request` returns an `Option`. Serialising the
									// `Option` itself would put an option tag in front of the
									// payload, which the server cannot decode, so only the
									// wrapped request is sent.
									if let Some(request) = request {
										let kill =
											Message::Binary(serialize(&request, true).unwrap());
										if let Err(error) = state.sink.send(kill).await {
											trace!(
												"failed to send kill query to the server; {error:?}"
											);
											return HandleResult::Disconnected;
										}
									}
								}
							}
						}
						Ok(..) => { /* Ignored responses like pings */ }
						Err(error) => error!("{error:?}"),
					},
				}
			}
		}
		Err(error) => {
			#[derive(Deserialize)]
			#[revisioned(revision = 1)]
			struct Response {
				id: Option<CoreValue>,
			}
			// Let's try to find out the ID of the response that failed to deserialise
			if let Message::Binary(binary) = response {
				if let Ok(Response {
					id,
				}) = deserialize(&mut &binary[..], true)
				{
					// Return an error if an ID was returned
					if let Some(Ok(id)) = id.map(CoreValue::coerce_to_i64) {
						if let Some(req) = state.pending_requests.remove(&id) {
							let _res = req.response_channel.send(Err(error)).await;
						} else {
							warn!("got response for request with id '{id}', which was not in pending requests")
						}
					}
				} else {
					// Unfortunately, we don't know which response failed to deserialize
					warn!("Failed to deserialise message; {error:?}");
				}
			}
		}
	}
	HandleResult::Ok
}
async fn router_reconnect(
state: &mut RouterState,
events: &mut Events<WsEvent>,
endpoint: &Endpoint,
capacity: usize,
) {
loop {
trace!("Reconnecting...");
let connect = WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await;
match connect {
Ok((mut meta, stream)) => {
let (new_sink, new_stream) = stream.split();
state.sink = new_sink;
state.stream = new_stream;
*events = {
let result = match capacity {
0 => meta.observe(ObserveConfig::default()).await,
capacity => meta.observe(Channel::Bounded(capacity).into()).await,
};
match result {
Ok(events) => events,
Err(error) => {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
};
for (_, message) in &state.replay {
let message = message.clone().into_router_request(None);
let message = serialize(&message, true).unwrap();
if let Err(error) = state.sink.send(Message::Binary(message)).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
for (key, value) in &state.vars {
let request = Command::Set {
key: key.as_str().into(),
value: value.clone(),
}
.into_router_request(None);
trace!("Request {:?}", request);
let serialize = serialize(&request, false).unwrap();
if let Err(error) = state.sink.send(Message::Binary(serialize)).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
trace!("Reconnected successfully");
break;
}
Err(error) => {
trace!("Failed to reconnect; {error}");
time::sleep(Duration::from_secs(1)).await;
}
}
}
}
/// Background task that owns the WebSocket connection for one `Client`.
///
/// It opens the socket at `endpoint.url`, requesting the `super::REVISION_HEADER` protocol, and
/// subscribes to the socket's events. The outcome of that setup is reported on `conn_tx`.
/// After that it multiplexes the following sources until `route_rx` is closed:
/// - requests from `route_rx`, handled by `router_handle_request`;
/// - messages from the socket, handled by `router_handle_response`;
/// - WebSocket events (errors and closure);
/// - a keep-alive ping, sent on each `PING_INTERVAL` tick when `state.last_activity` is at
///   least `PING_INTERVAL` old.
///
/// When `capacity` is non-zero the event subscription uses a bounded channel of that size;
/// when it is `0`, the default observer configuration is used.
pub(crate) async fn run_router(
	endpoint: Endpoint,
	capacity: usize,
	conn_tx: Sender<Result<()>>,
	route_rx: Receiver<Route>,
) {
	let connect = WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await;
	let (mut ws, socket) = match connect {
		Ok(pair) => pair,
		Err(error) => {
			// The initial connection failed: report it and stop the router.
			let _ = conn_tx.send(Err(error.into())).await;
			return;
		}
	};
	let mut events = {
		let result = match capacity {
			0 => ws.observe(ObserveConfig::default()).await,
			capacity => ws.observe(Channel::Bounded(capacity).into()).await,
		};
		match result {
			Ok(events) => events,
			Err(error) => {
				let _ = conn_tx.send(Err(error.into())).await;
				return;
			}
		}
	};
	// Setup succeeded; let the waiting caller continue.
	let _ = conn_tx.send(Ok(())).await;
	// Pre-serialised `{ "method": "ping" }` request, reused for every keep-alive.
	let ping = {
		let mut request = BTreeMap::new();
		request.insert("method".to_owned(), "ping".into());
		let value = CoreValue::from(request);
		let value = serialize(&value, true).unwrap();
		Message::Binary(value)
	};
	let (socket_sink, socket_stream) = socket.split();
	let mut state = RouterState::new(socket_sink, socket_stream);
	// Each iteration of this outer loop serves one connection. The inner loop is left with
	// `break` after a reconnect attempt (or on `WsEvent::Error`, see below), and the next
	// iteration starts with fresh per-connection state.
	'router: loop {
		let mut interval = time::interval(PING_INTERVAL);
		// don't bombard the server with pings if we miss some ticks
		interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
		let mut pinger = IntervalStream::new(interval);
		state.last_activity = Instant::now();
		// Live-query registrations and requests still awaiting a response are discarded.
		state.live_queries.clear();
		state.pending_requests.clear();
		loop {
			futures::select! {
				route = route_rx.recv().fuse() => {
					// The route channel is closed: close the socket and stop the router.
					// NOTE(review): after `router_reconnect` has run, `ws` still refers to the
					// first connection's `WsMeta`, so this closes that socket rather than the
					// current one — confirm whether that is intended.
					let Ok(route) = route else {
						match ws.close().await {
							Ok(..) => trace!("Connection closed successfully"),
							Err(error) => {
								warn!("Failed to close database connection; {error}")
							}
						}
						break 'router;
					};
					match router_handle_request(route, &mut state,&endpoint).await {
						HandleResult::Ok => {},
						HandleResult::Disconnected => {
							router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
							break
						}
					}
				}
				message = state.stream.next().fuse() => {
					let Some(message) = message else {
						// socket disconnected,
						router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
						break
					};
					// Any inbound message counts as activity and postpones the next ping.
					state.last_activity = Instant::now();
					match router_handle_response(message, &mut state,&endpoint).await {
						HandleResult::Ok => {},
						HandleResult::Disconnected => {
							router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
							break
						}
					}
				}
				event = events.next().fuse() => {
					// NOTE(review): if the event stream has ended, this branch presumably keeps
					// yielding `None` and the loop just continues — confirm it cannot spin.
					let Some(event) = event else {
						continue;
					};
					match event {
						WsEvent::Error => {
							// NOTE(review): this restarts the outer loop (clearing state)
							// without calling `router_reconnect`; presumably a following
							// `Closed` event triggers the reconnect — confirm.
							trace!("connection errored");
							break;
						}
						WsEvent::WsErr(error) => {
							trace!("{error}");
						}
						WsEvent::Closed(..) => {
							trace!("connection closed");
							router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
							break;
						}
						_ => {}
					}
				}
				_ = pinger.next().fuse() => {
					// Only ping when the connection has been idle for a full interval.
					if state.last_activity.elapsed() >= PING_INTERVAL {
						trace!("Pinging the server");
						if let Err(error) = state.sink.send(ping.clone()).await {
							trace!("failed to ping the server; {error:?}");
							router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
							break;
						}
					}
				}
			}
		}
	}
}
impl Response {
fn try_from(message: &Message) -> Result<Option<Self>> {
match message {
Message::Text(text) => {
trace!("Received an unexpected text message; {text}");
Ok(None)
}
Message::Binary(binary) => {
deserialize(&mut &binary[..], true).map(Some).map_err(|error| {
Error::ResponseFromBinary {
binary: binary.clone(),
error: bincode::ErrorKind::Custom(error.to_string()).into(),
}
.into()
})
}
}
}
}