Fix clippy tests, eliminate globals, remove unused function return variant. (#4303)
This commit is contained in:
parent 07a88383fa
commit 5706c9b368
28 changed files with 189 additions and 142 deletions
@@ -457,6 +457,22 @@ pub async fn asynchronous(
)
}

+ fn get_execution_context<'a>(
+ ctx: &'a Context<'_>,
+ doc: Option<&'a CursorDoc<'_>>,
+ ) -> Option<(&'a QueryExecutor, &'a CursorDoc<'a>, &'a Thing)> {
+ if let Some(doc) = doc {
+ if let Some(thg) = doc.rid {
+ if let Some(pla) = ctx.get_query_planner() {
+ if let Some(exe) = pla.get_query_executor(&thg.tb) {
+ return Some((exe, doc, thg));
+ }
+ }
+ }
+ }
+ None
+ }
+
#[cfg(test)]
mod tests {
#[cfg(all(feature = "scripting", feature = "kv-mem"))]

@@ -537,19 +553,3 @@ mod tests {
}
}
}
-
- fn get_execution_context<'a>(
- ctx: &'a Context<'_>,
- doc: Option<&'a CursorDoc<'_>>,
- ) -> Option<(&'a QueryExecutor, &'a CursorDoc<'a>, &'a Thing)> {
- if let Some(doc) = doc {
- if let Some(thg) = doc.rid {
- if let Some(pla) = ctx.get_query_planner() {
- if let Some(exe) = pla.get_query_executor(&thg.tb) {
- return Some((exe, doc, thg));
- }
- }
- }
- }
- None
- }
@@ -70,6 +70,6 @@ where
D: ModuleDef,
{
let (m, promise) = Module::evaluate_def::<D, _>(ctx.clone(), name)?;
- promise.finish()?;
+ promise.finish::<()>()?;
m.get::<_, js::Value>("default")
}
@@ -15,11 +15,7 @@ use jsonwebtoken::{encode, EncodingKey, Header};
use std::sync::Arc;
use uuid::Uuid;

- pub async fn signin(
- kvs: &Datastore,
- session: &mut Session,
- vars: Object,
- ) -> Result<Option<String>, Error> {
+ pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
// Parse the specified variables
let ns = vars.get("NS").or_else(|| vars.get("ns"));
let db = vars.get("DB").or_else(|| vars.get("db"));

@@ -104,7 +100,7 @@ pub async fn db_access(
db: String,
ac: String,
vars: Object,
- ) -> Result<Option<String>, Error> {
+ ) -> Result<String, Error> {
// Create a new readonly transaction
let mut tx = kvs.transaction(Read, Optimistic).await?;
// Fetch the specified access method from storage

@@ -203,7 +199,7 @@ pub async fn db_access(
// Check the authentication token
match enc {
// The auth token was created successfully
- Ok(tk) => Ok(Some(tk)),
+ Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}

@@ -234,7 +230,7 @@ pub async fn db_user(
db: String,
user: String,
pass: String,
- ) -> Result<Option<String>, Error> {
+ ) -> Result<String, Error> {
match verify_db_creds(kvs, &ns, &db, &user, &pass).await {
Ok(u) => {
// Create the authentication key

@@ -264,7 +260,7 @@ pub async fn db_user(
// Check the authentication token
match enc {
// The auth token was created successfully
- Ok(tk) => Ok(Some(tk)),
+ Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}

@@ -278,7 +274,7 @@ pub async fn ns_user(
ns: String,
user: String,
pass: String,
- ) -> Result<Option<String>, Error> {
+ ) -> Result<String, Error> {
match verify_ns_creds(kvs, &ns, &user, &pass).await {
Ok(u) => {
// Create the authentication key

@@ -306,7 +302,7 @@ pub async fn ns_user(
// Check the authentication token
match enc {
// The auth token was created successfully
- Ok(tk) => Ok(Some(tk)),
+ Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}

@@ -320,7 +316,7 @@ pub async fn root_user(
session: &mut Session,
user: String,
pass: String,
- ) -> Result<Option<String>, Error> {
+ ) -> Result<String, Error> {
match verify_root_creds(kvs, &user, &pass).await {
Ok(u) => {
// Create the authentication key

@@ -346,7 +342,7 @@ pub async fn root_user(
// Check the authentication token
match enc {
// The auth token was created successfully
- Ok(tk) => Ok(Some(tk)),
+ Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}

@@ -606,7 +602,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
);

// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Check that token can be verified with the defined algorithm
let val = Validation::new(Algorithm::RS256);
// Check that token can be verified with the defined public key

@@ -724,7 +720,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

@@ -803,7 +799,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

@@ -922,7 +918,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

@@ -992,7 +988,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

@@ -1094,7 +1090,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

@@ -1159,7 +1155,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
- if let Ok(Some(tk)) = res {
+ if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
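The signin changes above all follow one pattern from the commit title: the `Option` wrapper in `Result<Option<String>, Error>` was never used to return `None` on success, so the functions now return `Result<String, Error>` and callers stop matching on a nested `Option`. A minimal sketch of the pattern with hypothetical names (not the crate's actual API):

#[derive(Debug)]
struct Error;

// Before: the Ok branch only ever produced Some(token).
fn make_token_before() -> Result<Option<String>, Error> {
    Ok(Some("token".to_string()))
}

// After: the unused None variant is dropped from the signature.
fn make_token_after() -> Result<String, Error> {
    Ok("token".to_string())
}

fn main() -> Result<(), Error> {
    // Old call sites had to unwrap two layers.
    if let Ok(Some(tk)) = make_token_before() {
        println!("before: {tk}");
    }
    // New call sites handle a plain Result.
    let tk = make_token_after()?;
    println!("after: {tk}");
    Ok(())
}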
@@ -137,24 +137,24 @@ mod tests {
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
// We should not have the element yet
- assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
+ assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// The first insertion returns true
- assert_eq!(dyn_set.insert(sample), true);
- assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
+ assert!(dyn_set.insert(sample));
+ assert!(dyn_set.contains(&sample), "{capacity} - {sample}");
// The second insertion returns false
- assert_eq!(dyn_set.insert(sample), false);
- assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
+ assert!(!dyn_set.insert(sample));
+ assert!(dyn_set.contains(&sample), "{capacity} - {sample}");
// We update the control structure
control.insert(sample);
}
// Test removals
for sample in 0..capacity {
// The first removal returns true
- assert_eq!(dyn_set.remove(&sample), true);
- assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
+ assert!(dyn_set.remove(&sample));
+ assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// The second removal returns false
- assert_eq!(dyn_set.remove(&sample), false);
- assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
+ assert!(!dyn_set.remove(&sample));
+ assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// We update the control structure
control.remove(&sample);
// The control structure and the dyn_set should be identical
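The assertion rewrites here are the fix for clippy's `bool_assert_comparison` lint: comparing a boolean expression against `true` or `false` inside `assert_eq!` becomes a plain `assert!`, negated where needed. A standalone illustration of the same pattern:

use std::collections::HashSet;

fn main() {
    let mut set = HashSet::new();
    // Instead of `assert_eq!(set.insert(1), true)`:
    assert!(set.insert(1));
    // Instead of `assert_eq!(set.insert(1), false)`:
    assert!(!set.insert(1));
    // Instead of `assert_eq!(set.contains(&1), true)`:
    assert!(set.contains(&1));
    // Instead of `assert_eq!(set.contains(&2), false)`:
    assert!(!set.contains(&2));
}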
@@ -315,7 +315,7 @@ mod tests {
) -> HashSet<SharedVector> {
let mut set = HashSet::new();
for (_, obj) in collection.to_vec_ref() {
- let obj: SharedVector = obj.clone().into();
+ let obj: SharedVector = obj.clone();
h.insert(obj.clone());
set.insert(obj);
h.check_hnsw_properties(set.len());

@@ -367,7 +367,7 @@ mod tests {
fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
let mut h = HnswFlavor::new(p);
insert_collection_hnsw(&mut h, collection);
- find_collection_hnsw(&h, &collection);
+ find_collection_hnsw(&h, collection);
}

fn new_params(

@@ -447,7 +447,7 @@ mod tests {
) -> HashMap<SharedVector, HashSet<DocId>> {
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::new();
for (doc_id, obj) in collection.to_vec_ref() {
- let obj: SharedVector = obj.clone().into();
+ let obj: SharedVector = obj.clone();
h.insert(obj.clone(), *doc_id);
match map.entry(obj) {
Entry::Occupied(mut e) => {

@@ -506,7 +506,7 @@ mod tests {
mut map: HashMap<SharedVector, HashSet<DocId>>,
) {
for (doc_id, obj) in collection.to_vec_ref() {
- let obj: SharedVector = obj.clone().into();
+ let obj: SharedVector = obj.clone();
h.remove(obj.clone(), *doc_id);
if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
let set = e.get_mut();
@@ -732,7 +732,7 @@ pub(super) mod tests {
distance: &Distance,
) -> Self {
let mut rng = get_seed_rnd();
- let gen = RandomItemGenerator::new(&distance, dimension);
+ let gen = RandomItemGenerator::new(distance, dimension);
if unique {
TestCollection::new_unique(collection_size, vt, dimension, &gen, &mut rng)
} else {

@@ -766,7 +766,7 @@ pub(super) mod tests {
}
let mut coll = TestCollection::Unique(Vec::with_capacity(vector_set.len()));
for (i, v) in vector_set.into_iter().enumerate() {
- coll.add(i as DocId, v.into());
+ coll.add(i as DocId, v);
}
coll
}

@@ -781,7 +781,7 @@ pub(super) mod tests {
let mut coll = TestCollection::NonUnique(Vec::with_capacity(collection_size));
// Prepare data set
for doc_id in 0..collection_size {
- coll.add(doc_id as DocId, new_random_vec(rng, vector_type, dimension, gen).into());
+ coll.add(doc_id as DocId, new_random_vec(rng, vector_type, dimension, gen));
}
coll
}

@@ -1732,6 +1732,7 @@ mod tests {
Ok(())
}

+ #[allow(clippy::too_many_arguments)]
async fn test_mtree_collection(
stk: &mut Stk,
capacities: &[u16],
@@ -214,7 +214,7 @@ mod test_check_lqs_and_send_notifications {
use crate::sql::statements::{CreateStatement, DeleteStatement, LiveStatement};
use crate::sql::{Fields, Object, Strand, Table, Thing, Uuid, Value, Values};

- const SETUP: Lazy<Arc<TestSuite>> = Lazy::new(|| Arc::new(block_on(setup_test_suite_init())));
+ static SETUP: Lazy<Arc<TestSuite>> = Lazy::new(|| Arc::new(block_on(setup_test_suite_init())));

struct TestSuite {
ns: String,

@@ -392,9 +392,8 @@ mod test_check_lqs_and_send_notifications {
fn a_usable_options(sender: &Sender<Notification>) -> Options {
let mut ctx = Context::default();
ctx.add_notifications(Some(sender));
- let opt = Options::default()
+ Options::default()
.with_ns(Some(SETUP.ns.clone().into()))
- .with_db(Some(SETUP.db.clone().into()));
- opt
+ .with_db(Some(SETUP.db.clone().into()))
}
}
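The `SETUP` fixture switches from `const` to `static` because a `const` item with interior mutability is inlined into every use site, so each use would get a fresh `Lazy` and re-run the initialiser; clippy flags this, and only a `static` gives a single, once-initialised cell. A hedged, standalone sketch of the difference, using std's `OnceLock` as a stand-in for `once_cell::sync::Lazy`:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

static INIT_COUNT: AtomicUsize = AtomicUsize::new(0);

// A `static` cell is one shared location, so the closure below runs exactly once.
// Declared as `const`, the cell would be inlined into every use site and the
// initialisation would run again each time, which is what clippy warns about.
static SETUP: OnceLock<String> = OnceLock::new();

fn setup() -> &'static str {
    SETUP
        .get_or_init(|| {
            INIT_COUNT.fetch_add(1, Ordering::SeqCst);
            "expensive test fixture".to_string()
        })
        .as_str()
}

fn main() {
    setup();
    setup();
    // The fixture was initialised a single time.
    assert_eq!(INIT_COUNT.load(Ordering::SeqCst), 1);
}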
@@ -162,7 +162,7 @@ mod test {
let r = stack
.enter(|ctx| async move { parser.parse_query(ctx).await })
.finish()
- .expect(&format!("failed on {}", ident));
+ .unwrap_or_else(|_| panic!("failed on {}", ident));

assert_eq!(
r,
@@ -8,15 +8,15 @@
//!
//! There are a bunch of common patterns for which this module has some confinence functions.
//! - Whenever only one token can be next you should use the [`expected!`] macro. This macro
- //! ensures that the given token type is next and if not returns a parser error.
+ //! ensures that the given token type is next and if not returns a parser error.
//! - Whenever a limited set of tokens can be next it is common to match the token kind and then
- //! have a catch all arm which calles the macro [`unexpected!`]. This macro will raise an parse
- //! error with information about the type of token it recieves and what it expected.
+ //! have a catch all arm which calles the macro [`unexpected!`]. This macro will raise an parse
+ //! error with information about the type of token it recieves and what it expected.
//! - If a single token can be optionally next use [`Parser::eat`] this function returns a bool
- //! depending on if the given tokenkind was eaten.
+ //! depending on if the given tokenkind was eaten.
//! - If a closing delimiting token is expected use [`Parser::expect_closing_delimiter`]. This
- //! function will raise an error if the expected delimiter isn't the next token. This error will
- //! also point to which delimiter the parser expected to be closed.
+ //! function will raise an error if the expected delimiter isn't the next token. This error will
+ //! also point to which delimiter the parser expected to be closed.
//!
//! ## Far Token Peek
//!
@@ -405,7 +405,7 @@ fn parse_define_token_on_scope() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
- assert_eq!(stmt.if_not_exists, false);
+ assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);

@@ -480,7 +480,7 @@ fn parse_define_token_jwks_on_scope() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
- assert_eq!(stmt.if_not_exists, false);
+ assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);

@@ -523,7 +523,7 @@ fn parse_define_scope() {
session: Some(Duration::from_secs(1)),
}
);
- assert_eq!(stmt.if_not_exists, false);
+ assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, Some(Value::Bool(true)));

@@ -942,7 +942,7 @@ fn parse_define_access_record() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
- assert_eq!(stmt.if_not_exists, false);
+ assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);

@@ -988,7 +988,7 @@ fn parse_define_access_record() {
}
);
assert_eq!(stmt.comment, None);
- assert_eq!(stmt.if_not_exists, false);
+ assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, Some(Value::Bool(true)));
@@ -533,7 +533,7 @@ mod tests {
let r = stack
.enter(|ctx| async move { parser.parse_thing(ctx).await })
.finish()
- .expect(&format!("failed on {}", ident))
+ .unwrap_or_else(|_| panic!("failed on {}", ident))
.id;
assert_eq!(r, Id::String(ident.to_string()),);

@@ -541,7 +541,7 @@ mod tests {
let r = stack
.enter(|ctx| async move { parser.parse_query(ctx).await })
.finish()
- .expect(&format!("failed on {}", ident));
+ .unwrap_or_else(|_| panic!("failed on {}", ident));

assert_eq!(
r,
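These test edits address clippy's `expect_fun_call` lint: `.expect(&format!(...))` builds the panic message eagerly even when the value is `Ok`, whereas `.unwrap_or_else(|_| panic!(...))` only formats it on failure. A small sketch of the same change, with a hypothetical `parse_ident` standing in for the parser calls:

fn parse_ident(ident: &str) -> Result<String, String> {
    if ident.is_empty() {
        Err("empty identifier".to_string())
    } else {
        Ok(ident.to_string())
    }
}

fn main() {
    let ident = "foo";
    // Eagerly formats the panic message even when parsing succeeds:
    // let r = parse_ident(ident).expect(&format!("failed on {}", ident));
    // Lazily formats it only if the Result is an Err:
    let r = parse_ident(ident).unwrap_or_else(|_| panic!("failed on {}", ident));
    assert_eq!(r, "foo");
}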
@@ -218,7 +218,7 @@ fn hnsw() -> HnswIndex {
fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
let mut h = hnsw();
for (id, content) in samples {
- h.index_document(&id, content).unwrap();
+ h.index_document(id, content).unwrap();
}
h
}

@@ -226,7 +226,7 @@ fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
async fn insert_objects_db(session: &Session, create_index: bool, inserts: &[String]) -> Datastore {
let ds = init_datastore(session, create_index).await;
for sql in inserts {
- ds.execute(sql, session, None).await.expect(&sql);
+ ds.execute(sql, session, None).await.expect(sql);
}
ds
}

@@ -249,8 +249,8 @@ async fn knn_lookup_objects(h: &HnswIndex, samples: &[Vec<Number>]) {

async fn knn_lookup_objects_db(ds: &Datastore, session: &Session, selects: &[String]) {
for sql in selects {
- let mut res = ds.execute(sql, session, None).await.expect(&sql);
- let res = res.remove(0).result.expect(&sql);
+ let mut res = ds.execute(sql, session, None).await.expect(sql);
+ let res = res.remove(0).result.expect(sql);
if let Value::Array(a) = &res {
assert_eq!(a.len(), NN);
} else {
@@ -199,7 +199,7 @@ async fn knn_lookup_object(mt: &MTreeIndex, ctx: &Context<'_>, object: Vec<Numbe
let mut stack = TreeStack::new();
stack
.enter(|stk| async {
- let chk = MTreeConditionChecker::new(&ctx);
+ let chk = MTreeConditionChecker::new(ctx);
let r = mt.knn_search(stk, ctx, &object, knn, chk).await.unwrap();
assert_eq!(r.len(), knn);
})
@@ -612,7 +612,7 @@ mod tests {
//
let compress = |v: &Vec<u8>| {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
- encoder.write_all(&v).unwrap();
+ encoder.write_all(v).unwrap();
encoder.finish().unwrap()
};
// Generate a random vector
@@ -974,10 +974,10 @@ async fn define_statement_index_on_schemafull_without_permission() -> Result<(),
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 2);
//
- skip_ok(&mut res, 1)?;
+ skip_ok(res, 1)?;
//
let tmp = res.remove(0).result;
let s = format!("{:?}", tmp);
@@ -1881,7 +1881,12 @@ async fn function_math_deg2rad() -> Result<(), Error> {
RETURN math::deg2rad(math::rad2deg(0.7853981633974483));
"#;
Test::new(sql).await?.expect_floats(
- &[0.7853981633974483, -1.5707963267948966, 6.283185307179586, 0.7853981633974483],
+ &[
+ std::f64::consts::FRAC_PI_4,
+ -std::f64::consts::FRAC_PI_2,
+ std::f64::consts::TAU,
+ std::f64::consts::FRAC_PI_4,
+ ],
f64::EPSILON,
)?;
Ok(())

@@ -2032,7 +2037,7 @@ async fn function_math_log10() -> Result<(), Error> {
"#;
Test::new(sql)
.await?
- .expect_floats(&[0.43429738512450866, 0.3010299956639812, 0.0], f64::EPSILON)?
+ .expect_floats(&[0.43429738512450866, std::f64::consts::LOG10_2, 0.0], f64::EPSILON)?
.expect_vals(&["Math::Neg_Inf", "NaN", "true"])?;
Ok(())
}
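The math tests now name well-known constants via `std::f64::consts` instead of spelling out decimal literals, in the spirit of clippy's `approx_constant` lint, and it makes the expected values self-describing. The replaced literals appear to be exactly the shortest round-trip forms of those constants for `f64`, as this quick check illustrates:

fn main() {
    // π/4, 2π (τ), and log10(2) as f64 values.
    assert_eq!(std::f64::consts::FRAC_PI_4, 0.7853981633974483);
    assert_eq!(std::f64::consts::TAU, 6.283185307179586);
    assert_eq!(std::f64::consts::LOG10_2, 0.3010299956639812);
    // -π/2 matches the remaining literal from the deg2rad test.
    assert_eq!(-std::f64::consts::FRAC_PI_2, -1.5707963267948966);
}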
@@ -595,10 +595,10 @@ async fn select_array_group_group_by() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
- skip_ok(&mut res, 4)?;
+ skip_ok(res, 4)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(

@@ -633,10 +633,10 @@ async fn select_array_count_subquery_group_by() -> Result<(), Error> {
"#;
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
- skip_ok(&mut res, 3)?;
+ skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@@ -272,6 +272,7 @@ impl Test {
/// This method will panic if the responses list is empty, indicating that there are no more responses to retrieve.
/// The panic message will include the last position in the responses list before it was emptied.
#[allow(dead_code)]
+ #[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Result<Response, Error> {
assert!(!self.responses.is_empty(), "No response left - last position: {}", self.pos);
self.pos += 1;
@@ -1930,7 +1930,7 @@ async fn select_with_record_id_link_no_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h';
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;

@@ -1989,7 +1989,7 @@ async fn select_with_record_id_link_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
SELECT * FROM i WHERE t.name = 'h';
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;

@@ -2054,7 +2054,7 @@ async fn select_with_record_id_link_unique_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
SELECT * FROM i WHERE t.name = 'h';
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;

@@ -2118,7 +2118,7 @@ async fn select_with_record_id_link_unique_remote_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name IN ['a', 'b'] EXPLAIN;
SELECT * FROM i WHERE t.name IN ['a', 'b'];
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;

@@ -2185,7 +2185,7 @@ async fn select_with_record_id_link_full_text_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name @@ 'world' EXPLAIN;
SELECT * FROM i WHERE t.name @@ 'world';
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;

assert_eq!(res.len(), 9);
skip_ok(&mut res, 7)?;

@@ -2242,7 +2242,7 @@ async fn select_with_record_id_link_full_text_no_record_index() -> Result<(), Er
SELECT * FROM i WHERE t.name @@ 'world' EXPLAIN;
SELECT * FROM i WHERE t.name @@ 'world';
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;

assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;

@@ -2301,7 +2301,7 @@ async fn select_with_record_id_index() -> Result<(), Error> {
SELECT * FROM t WHERE a:2 IN links;
SELECT * FROM t WHERE a:2 IN links EXPLAIN;
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;

let expected = Value::parse(
r#"[

@@ -2447,7 +2447,7 @@ async fn select_with_exact_operator() -> Result<(), Error> {
SELECT * FROM t WHERE i == 2;
SELECT * FROM t WHERE i == 2 EXPLAIN;
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 4)?;

@@ -2550,7 +2550,7 @@ async fn select_with_non_boolean_expression() -> Result<(), Error> {
SELECT * FROM t WHERE v > $p3 - ( math::max([0, $p1]) + $p1 );
SELECT * FROM t WHERE v > $p3 - ( math::max([0, $p1]) + $p1 ) EXPLAIN;
";
- let mut res = dbs.execute(&sql, &ses, None).await?;
+ let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 15);
skip_ok(&mut res, 5)?;
@@ -328,10 +328,10 @@ async fn select_mtree_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
- skip_ok(&mut res, 3)?;
+ skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(

@@ -400,10 +400,10 @@ async fn select_hnsw_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
- skip_ok(&mut res, 3)?;
+ skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(

@@ -471,10 +471,10 @@ async fn select_bruteforce_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
- let mut res = &mut dbs.execute(sql, &ses, None).await?;
+ let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
- skip_ok(&mut res, 2)?;
+ skip_ok(res, 2)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@@ -24,7 +24,7 @@ use crate::cli::CF;
use crate::cnf;
use crate::err::Error;
use crate::net::signals::graceful_shutdown;
- use crate::rpc::notifications;
+ use crate::rpc::{notifications, RpcState};
use crate::telemetry::metrics::HttpMetricsLayer;
use axum::response::Redirect;
use axum::routing::get;

@@ -155,7 +155,7 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.max_age(Duration::from_secs(86400)),
);

- let axum_app = Router::new()
+ let axum_app = Router::<Arc<RpcState>, _>::new()
// Redirect until we provide a UI
.route("/", get(|| async { Redirect::temporary(cnf::APP_ENDPOINT) }))
.route("/status", get(|| async {}))

@@ -177,10 +177,16 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {

// Get a new server handler
let handle = Handle::new();
+
+ let rpc_state = Arc::new(RpcState::new());
+
// Setup the graceful shutdown handler
- let shutdown_handler = graceful_shutdown(ct.clone(), handle.clone());
+ let shutdown_handler = graceful_shutdown(rpc_state.clone(), ct.clone(), handle.clone());
+
+ let axum_app = axum_app.with_state(rpc_state.clone());
+
// Spawn a task to handle notifications
- tokio::spawn(async move { notifications(ct.clone()).await });
+ tokio::spawn(async move { notifications(rpc_state, ct.clone()).await });
// If a certificate and key are specified then setup TLS
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
// Configure certificate and private key used by https
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::ops::Deref;
+ use std::sync::Arc;

use crate::cnf;
use crate::dbs::DB;

@@ -8,7 +9,8 @@ use crate::rpc::connection::Connection;
use crate::rpc::format::HttpFormat;
use crate::rpc::post_context::PostRpcContext;
use crate::rpc::response::IntoRpcResponse;
- use crate::rpc::WEBSOCKETS;
+ use crate::rpc::RpcState;
+ use axum::extract::State;
use axum::routing::get;
use axum::routing::post;
use axum::TypedHeader;

@@ -32,12 +34,11 @@ use super::headers::ContentType;

use surrealdb::rpc::rpc_context::RpcContext;

- pub(super) fn router<S, B>() -> Router<S, B>
+ pub(super) fn router<B>() -> Router<Arc<RpcState>, B>
where
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
- S: Clone + Send + Sync + 'static,
{
Router::new().route("/rpc", get(get_handler)).route("/rpc", post(post_handler))
}

@@ -46,6 +47,7 @@ async fn get_handler(
ws: WebSocketUpgrade,
Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>,
+ State(rpc_state): State<Arc<RpcState>>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Check if there is a request id header specified
let id = match id.header_value().is_empty() {

@@ -65,7 +67,7 @@ async fn get_handler(
},
};
// Check if a connection with this id already exists
- if WEBSOCKETS.read().await.contains_key(&id) {
+ if rpc_state.web_sockets.read().await.contains_key(&id) {
return Err(Error::Request);
}
// Now let's upgrade the WebSocket connection

@@ -77,10 +79,10 @@ async fn get_handler(
// Set the maximum WebSocket message size
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
// Handle the WebSocket upgrade and process messages
- .on_upgrade(move |socket| handle_socket(socket, sess, id)))
+ .on_upgrade(move |socket| handle_socket(rpc_state, socket, sess, id)))
}

- async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
+ async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: Uuid) {
// Check if there is a WebSocket protocol specified
let format = match ws.protocol().map(HeaderValue::to_str) {
// Any selected protocol will always be a valie value

@@ -90,7 +92,7 @@ async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
};
// Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here
// Create a new connection instance
- let rpc = Connection::new(id, sess, format);
+ let rpc = Connection::new(state, id, sess, format);
// Serve the socket connection requests
Connection::serve(rpc, ws).await;
}
@@ -1,8 +1,14 @@
+ use std::sync::Arc;
+
use axum_server::Handle;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

- use crate::{err::Error, rpc, telemetry};
+ use crate::{
+ err::Error,
+ rpc::{self, RpcState},
+ telemetry,
+ };

/// Start a graceful shutdown:
/// * Signal the Axum Handle when a shutdown signal is received.

@@ -10,7 +16,11 @@ use crate::{err::Error, rpc, telemetry};
/// * Flush all telemetry data.
///
/// A second signal will force an immediate shutdown.
- pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHandle<()> {
+ pub fn graceful_shutdown(
+ state: Arc<RpcState>,
+ ct: CancellationToken,
+ http_handle: Handle,
+ ) -> JoinHandle<()> {
tokio::spawn(async move {
let result = listen().await.expect("Failed to listen to shutdown signal");
info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);

@@ -18,6 +28,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
let shutdown = {
let http_handle = http_handle.clone();
let ct = ct.clone();
+ let state_clone = state.clone();

tokio::spawn(async move {
// Stop accepting new HTTP requests and wait until all connections are closed

@@ -26,7 +37,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}

- rpc::graceful_shutdown().await;
+ rpc::graceful_shutdown(state_clone).await;

ct.cancel();

@@ -52,7 +63,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
http_handle.shutdown();

// Close all WebSocket connections immediately
- rpc::shutdown();
+ rpc::shutdown(state);

// Cancel cancellation token
ct.cancel();
@@ -27,9 +27,9 @@ struct Success {
}

impl Success {
- fn new(token: Option<String>) -> Success {
+ fn new(token: String) -> Success {
Success {
- token,
+ token: Some(token),
code: 200,
details: String::from("Authentication succeeded"),
}

@@ -71,7 +71,7 @@ async fn handler(
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Text serialization
- Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
+ Some(Accept::TextPlain) => Ok(output::text(v)),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Return nothing
@@ -126,7 +126,7 @@ where
FailureClass: fmt::Display,
{
fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) {
- span.record("error_message", &error.to_string());
+ span.record("error_message", error.to_string());
span.record("http.latency.ms", latency.as_millis());
tracing::event!(Level::ERROR, error = error.to_string(), "response failed");
}
@@ -5,7 +5,7 @@ use crate::dbs::DB;
use crate::rpc::failure::Failure;
use crate::rpc::format::WsFormat;
use crate::rpc::response::{failure, IntoRpcResponse};
- use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
+ use crate::rpc::CONN_CLOSED_ERR;
use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext;
use crate::telemetry::traces::rpc::span_for_request;

@@ -33,6 +33,8 @@ use tracing::Instrument;
use tracing::Span;
use uuid::Uuid;

+ use super::RpcState;
+
pub struct Connection {
pub(crate) id: Uuid,
pub(crate) format: Format,

@@ -41,11 +43,17 @@ pub struct Connection {
pub(crate) limiter: Arc<Semaphore>,
pub(crate) canceller: CancellationToken,
pub(crate) channels: (Sender<Message>, Receiver<Message>),
+ pub(crate) state: Arc<RpcState>,
}

impl Connection {
/// Instantiate a new RPC
- pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
+ pub fn new(
+ state: Arc<RpcState>,
+ id: Uuid,
+ mut session: Session,
+ format: Format,
+ ) -> Arc<RwLock<Connection>> {
// Enable real-time mode
session.rt = true;
// Create and store the RPC connection

@@ -57,18 +65,26 @@ impl Connection {
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
canceller: CancellationToken::new(),
channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS),
+ state,
}))
}

/// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
- // Get the WebSocket ID
- let id = rpc.read().await.id;
+ let rpc_lock = rpc.read().await;
+ // Get the WebSocket ID
+ let id = rpc_lock.id;
+ let state = rpc_lock.state.clone();

// Split the socket into sending and receiving streams
let (sender, receiver) = ws.split();
// Create an internal channel for sending and receiving
- let internal_sender = rpc.read().await.channels.0.clone();
- let internal_receiver = rpc.read().await.channels.1.clone();
+ let internal_sender = rpc_lock.channels.0.clone();
+ let internal_receiver = rpc_lock.channels.1.clone();
+
+ // drop the lock early so rpc is free to be written to.
+ std::mem::drop(rpc_lock);

trace!("WebSocket {} connected", id);

@@ -77,7 +93,7 @@ impl Connection {
}

// Add this WebSocket to the list
- WEBSOCKETS.write().await.insert(id, rpc.clone());
+ state.web_sockets.write().await.insert(id, rpc.clone());

// Spawn async tasks for the WebSocket
let mut tasks = JoinSet::new();

@@ -97,11 +113,11 @@ impl Connection {
trace!("WebSocket {} disconnected", id);

// Remove this WebSocket from the list
- WEBSOCKETS.write().await.remove(&id);
+ state.web_sockets.write().await.remove(&id);

// Remove all live queries
let mut gc = Vec::new();
- LIVE_QUERIES.write().await.retain(|key, value| {
+ state.live_queries.write().await.retain(|key, value| {
if value == &id {
trace!("Removing live query: {}", key);
gc.push(*key);

@@ -379,12 +395,12 @@ impl RpcContext for Connection {
const LQ_SUPPORT: bool = true;

async fn handle_live(&self, lqid: &Uuid) {
- LIVE_QUERIES.write().await.insert(*lqid, self.id);
+ self.state.live_queries.write().await.insert(*lqid, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}

async fn handle_kill(&self, lqid: &Uuid) {
- if let Some(id) = LIVE_QUERIES.write().await.remove(lqid) {
+ if let Some(id) = self.state.live_queries.write().await.remove(lqid) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}
@@ -8,7 +8,6 @@ use crate::dbs::DB;
use crate::rpc::connection::Connection;
use crate::rpc::response::success;
use crate::telemetry::metrics::ws::NotificationContext;
- use once_cell::sync::Lazy;
use opentelemetry::Context as TelemetryContext;
use std::collections::HashMap;
use std::sync::Arc;

@@ -25,13 +24,24 @@ type WebSockets = RwLock<HashMap<Uuid, WebSocket>>;
/// Mapping of LIVE Query ID to WebSocket ID
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;

- /// Stores the currently connected WebSockets
- pub(crate) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
- /// Stores the currently initiated LIVE queries
- pub(crate) static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
+ pub struct RpcState {
+ /// Stores the currently connected WebSockets
+ pub web_sockets: WebSockets,
+ /// Stores the currently initiated LIVE queries
+ pub live_queries: LiveQueries,
+ }
+
+ impl RpcState {
+ pub fn new() -> Self {
+ RpcState {
+ web_sockets: WebSockets::default(),
+ live_queries: LiveQueries::default(),
+ }
+ }
+ }

/// Performs notification delivery to the WebSockets
- pub(crate) async fn notifications(canceller: CancellationToken) {
+ pub(crate) async fn notifications(state: Arc<RpcState>, canceller: CancellationToken) {
// Listen to the notifications channel
if let Some(channel) = DB.get().unwrap().notifications() {
// Loop continuously

@@ -44,9 +54,9 @@ pub(crate) async fn notifications(canceller: CancellationToken) {
// Receive a notification on the channel
Ok(notification) = channel.recv() => {
// Find which WebSocket the notification belongs to
- if let Some(id) = LIVE_QUERIES.read().await.get(&notification.id) {
+ if let Some(id) = state.live_queries.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
- if let Some(rpc) = WEBSOCKETS.read().await.get(id) {
+ if let Some(rpc) = state.web_sockets.read().await.get(id) {
// Serialize the message to send
let message = success(None, notification);
// Add metrics

@@ -69,21 +79,21 @@ pub(crate) async fn notifications(canceller: CancellationToken) {
}

/// Closes all WebSocket connections, waiting for graceful shutdown
- pub(crate) async fn graceful_shutdown() {
+ pub(crate) async fn graceful_shutdown(state: Arc<RpcState>) {
// Close WebSocket connections, ensuring queued messages are processed
- for (_, rpc) in WEBSOCKETS.read().await.iter() {
+ for (_, rpc) in state.web_sockets.read().await.iter() {
rpc.read().await.canceller.cancel();
}
// Wait for all existing WebSocket connections to finish sending
- while WEBSOCKETS.read().await.len() > 0 {
+ while state.web_sockets.read().await.len() > 0 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
}

/// Forces a fast shutdown of all WebSocket connections
- pub(crate) fn shutdown() {
+ pub(crate) fn shutdown(state: Arc<RpcState>) {
// Close all WebSocket connections immediately
- if let Ok(mut writer) = WEBSOCKETS.try_write() {
+ if let Ok(mut writer) = state.web_sockets.try_write() {
writer.drain();
}
}
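The last few files carry the "eliminate globals" part of the commit: the `WEBSOCKETS` and `LIVE_QUERIES` `Lazy` statics become fields of an `RpcState` struct that is created once at startup and handed to the router, the notification task, and the shutdown handlers as an `Arc`. A simplified, synchronous sketch of that pattern (placeholder key and value types, `std::sync::RwLock` instead of the async lock the server uses):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

/// Shared RPC state passed explicitly instead of living in `static` globals.
#[derive(Default)]
struct RpcState {
    web_sockets: RwLock<HashMap<u64, String>>,
    live_queries: RwLock<HashMap<u64, u64>>,
}

fn register_connection(state: &Arc<RpcState>, id: u64, name: &str) {
    state.web_sockets.write().unwrap().insert(id, name.to_string());
}

fn shutdown(state: &Arc<RpcState>) {
    // Every user receives the state as an argument, so tests can build their own.
    state.web_sockets.write().unwrap().clear();
    state.live_queries.write().unwrap().clear();
}

fn main() {
    let state = Arc::new(RpcState::default());
    let s2 = Arc::clone(&state);
    let handle = thread::spawn(move || register_connection(&s2, 1, "client-1"));
    handle.join().unwrap();
    assert_eq!(state.web_sockets.read().unwrap().len(), 1);
    shutdown(&state);
    assert!(state.web_sockets.read().unwrap().is_empty());
}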