Fix clippy tests, eliminate globals, remove unused function return variant. (#4303)

This commit is contained in:
Mees Delzenne 2024-07-05 11:34:43 +02:00 committed by GitHub
parent 07a88383fa
commit 5706c9b368
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 189 additions and 142 deletions

View file

@ -457,6 +457,22 @@ pub async fn asynchronous(
)
}
fn get_execution_context<'a>(
ctx: &'a Context<'_>,
doc: Option<&'a CursorDoc<'_>>,
) -> Option<(&'a QueryExecutor, &'a CursorDoc<'a>, &'a Thing)> {
if let Some(doc) = doc {
if let Some(thg) = doc.rid {
if let Some(pla) = ctx.get_query_planner() {
if let Some(exe) = pla.get_query_executor(&thg.tb) {
return Some((exe, doc, thg));
}
}
}
}
None
}
#[cfg(test)]
mod tests {
#[cfg(all(feature = "scripting", feature = "kv-mem"))]
@ -537,19 +553,3 @@ mod tests {
}
}
}
fn get_execution_context<'a>(
ctx: &'a Context<'_>,
doc: Option<&'a CursorDoc<'_>>,
) -> Option<(&'a QueryExecutor, &'a CursorDoc<'a>, &'a Thing)> {
if let Some(doc) = doc {
if let Some(thg) = doc.rid {
if let Some(pla) = ctx.get_query_planner() {
if let Some(exe) = pla.get_query_executor(&thg.tb) {
return Some((exe, doc, thg));
}
}
}
}
None
}

View file

@ -70,6 +70,6 @@ where
D: ModuleDef,
{
let (m, promise) = Module::evaluate_def::<D, _>(ctx.clone(), name)?;
promise.finish()?;
promise.finish::<()>()?;
m.get::<_, js::Value>("default")
}

View file

@ -15,11 +15,7 @@ use jsonwebtoken::{encode, EncodingKey, Header};
use std::sync::Arc;
use uuid::Uuid;
pub async fn signin(
kvs: &Datastore,
session: &mut Session,
vars: Object,
) -> Result<Option<String>, Error> {
pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
// Parse the specified variables
let ns = vars.get("NS").or_else(|| vars.get("ns"));
let db = vars.get("DB").or_else(|| vars.get("db"));
@ -104,7 +100,7 @@ pub async fn db_access(
db: String,
ac: String,
vars: Object,
) -> Result<Option<String>, Error> {
) -> Result<String, Error> {
// Create a new readonly transaction
let mut tx = kvs.transaction(Read, Optimistic).await?;
// Fetch the specified access method from storage
@ -203,7 +199,7 @@ pub async fn db_access(
// Check the authentication token
match enc {
// The auth token was created successfully
Ok(tk) => Ok(Some(tk)),
Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}
@ -234,7 +230,7 @@ pub async fn db_user(
db: String,
user: String,
pass: String,
) -> Result<Option<String>, Error> {
) -> Result<String, Error> {
match verify_db_creds(kvs, &ns, &db, &user, &pass).await {
Ok(u) => {
// Create the authentication key
@ -264,7 +260,7 @@ pub async fn db_user(
// Check the authentication token
match enc {
// The auth token was created successfully
Ok(tk) => Ok(Some(tk)),
Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}
@ -278,7 +274,7 @@ pub async fn ns_user(
ns: String,
user: String,
pass: String,
) -> Result<Option<String>, Error> {
) -> Result<String, Error> {
match verify_ns_creds(kvs, &ns, &user, &pass).await {
Ok(u) => {
// Create the authentication key
@ -306,7 +302,7 @@ pub async fn ns_user(
// Check the authentication token
match enc {
// The auth token was created successfully
Ok(tk) => Ok(Some(tk)),
Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}
@ -320,7 +316,7 @@ pub async fn root_user(
session: &mut Session,
user: String,
pass: String,
) -> Result<Option<String>, Error> {
) -> Result<String, Error> {
match verify_root_creds(kvs, &user, &pass).await {
Ok(u) => {
// Create the authentication key
@ -346,7 +342,7 @@ pub async fn root_user(
// Check the authentication token
match enc {
// The auth token was created successfully
Ok(tk) => Ok(Some(tk)),
Ok(tk) => Ok(tk),
_ => Err(Error::TokenMakingFailed),
}
}
@ -606,7 +602,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
);
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Check that token can be verified with the defined algorithm
let val = Validation::new(Algorithm::RS256);
// Check that token can be verified with the defined public key
@ -724,7 +720,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
@ -803,7 +799,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
@ -922,7 +918,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
@ -992,7 +988,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
@ -1094,7 +1090,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
assert!(!sess.au.has_role(&Role::Owner), "Auth user expected to not have Owner role");
assert_eq!(sess.exp, None, "Session expiration is expected to match defined duration");
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
@ -1159,7 +1155,7 @@ dn/RsYEONbwQSjIfMPkvxF+8HQ==
"Session expiration is expected to match the defined duration"
);
// Decode token and check that it has been issued as intended
if let Ok(Some(tk)) = res {
if let Ok(tk) = res {
// Decode token without validation
let token_data = decode::<Claims>(&tk, &DecodingKey::from_secret(&[]), &{
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);

View file

@ -137,24 +137,24 @@ mod tests {
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
// We should not have the element yet
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// The first insertion returns true
assert_eq!(dyn_set.insert(sample), true);
assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
assert!(dyn_set.insert(sample));
assert!(dyn_set.contains(&sample), "{capacity} - {sample}");
// The second insertion returns false
assert_eq!(dyn_set.insert(sample), false);
assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
assert!(!dyn_set.insert(sample));
assert!(dyn_set.contains(&sample), "{capacity} - {sample}");
// We update the control structure
control.insert(sample);
}
// Test removals
for sample in 0..capacity {
// The first removal returns true
assert_eq!(dyn_set.remove(&sample), true);
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
assert!(dyn_set.remove(&sample));
assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// The second removal returns false
assert_eq!(dyn_set.remove(&sample), false);
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
assert!(!dyn_set.remove(&sample));
assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
// We update the control structure
control.remove(&sample);
// The control structure and the dyn_set should be identical

View file

@ -315,7 +315,7 @@ mod tests {
) -> HashSet<SharedVector> {
let mut set = HashSet::new();
for (_, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
let obj: SharedVector = obj.clone();
h.insert(obj.clone());
set.insert(obj);
h.check_hnsw_properties(set.len());
@ -367,7 +367,7 @@ mod tests {
fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
let mut h = HnswFlavor::new(p);
insert_collection_hnsw(&mut h, collection);
find_collection_hnsw(&h, &collection);
find_collection_hnsw(&h, collection);
}
fn new_params(
@ -447,7 +447,7 @@ mod tests {
) -> HashMap<SharedVector, HashSet<DocId>> {
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::new();
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
let obj: SharedVector = obj.clone();
h.insert(obj.clone(), *doc_id);
match map.entry(obj) {
Entry::Occupied(mut e) => {
@ -506,7 +506,7 @@ mod tests {
mut map: HashMap<SharedVector, HashSet<DocId>>,
) {
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
let obj: SharedVector = obj.clone();
h.remove(obj.clone(), *doc_id);
if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
let set = e.get_mut();

View file

@ -732,7 +732,7 @@ pub(super) mod tests {
distance: &Distance,
) -> Self {
let mut rng = get_seed_rnd();
let gen = RandomItemGenerator::new(&distance, dimension);
let gen = RandomItemGenerator::new(distance, dimension);
if unique {
TestCollection::new_unique(collection_size, vt, dimension, &gen, &mut rng)
} else {
@ -766,7 +766,7 @@ pub(super) mod tests {
}
let mut coll = TestCollection::Unique(Vec::with_capacity(vector_set.len()));
for (i, v) in vector_set.into_iter().enumerate() {
coll.add(i as DocId, v.into());
coll.add(i as DocId, v);
}
coll
}
@ -781,7 +781,7 @@ pub(super) mod tests {
let mut coll = TestCollection::NonUnique(Vec::with_capacity(collection_size));
// Prepare data set
for doc_id in 0..collection_size {
coll.add(doc_id as DocId, new_random_vec(rng, vector_type, dimension, gen).into());
coll.add(doc_id as DocId, new_random_vec(rng, vector_type, dimension, gen));
}
coll
}

View file

@ -1732,6 +1732,7 @@ mod tests {
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn test_mtree_collection(
stk: &mut Stk,
capacities: &[u16],

View file

@ -214,7 +214,7 @@ mod test_check_lqs_and_send_notifications {
use crate::sql::statements::{CreateStatement, DeleteStatement, LiveStatement};
use crate::sql::{Fields, Object, Strand, Table, Thing, Uuid, Value, Values};
const SETUP: Lazy<Arc<TestSuite>> = Lazy::new(|| Arc::new(block_on(setup_test_suite_init())));
static SETUP: Lazy<Arc<TestSuite>> = Lazy::new(|| Arc::new(block_on(setup_test_suite_init())));
struct TestSuite {
ns: String,
@ -392,9 +392,8 @@ mod test_check_lqs_and_send_notifications {
fn a_usable_options(sender: &Sender<Notification>) -> Options {
let mut ctx = Context::default();
ctx.add_notifications(Some(sender));
let opt = Options::default()
Options::default()
.with_ns(Some(SETUP.ns.clone().into()))
.with_db(Some(SETUP.db.clone().into()));
opt
.with_db(Some(SETUP.db.clone().into()))
}
}

View file

@ -162,7 +162,7 @@ mod test {
let r = stack
.enter(|ctx| async move { parser.parse_query(ctx).await })
.finish()
.expect(&format!("failed on {}", ident));
.unwrap_or_else(|_| panic!("failed on {}", ident));
assert_eq!(
r,

View file

@ -8,15 +8,15 @@
//!
//! There are a bunch of common patterns for which this module has some confinence functions.
//! - Whenever only one token can be next you should use the [`expected!`] macro. This macro
//! ensures that the given token type is next and if not returns a parser error.
//! ensures that the given token type is next and if not returns a parser error.
//! - Whenever a limited set of tokens can be next it is common to match the token kind and then
//! have a catch all arm which calles the macro [`unexpected!`]. This macro will raise an parse
//! error with information about the type of token it recieves and what it expected.
//! have a catch all arm which calles the macro [`unexpected!`]. This macro will raise an parse
//! error with information about the type of token it recieves and what it expected.
//! - If a single token can be optionally next use [`Parser::eat`] this function returns a bool
//! depending on if the given tokenkind was eaten.
//! depending on if the given tokenkind was eaten.
//! - If a closing delimiting token is expected use [`Parser::expect_closing_delimiter`]. This
//! function will raise an error if the expected delimiter isn't the next token. This error will
//! also point to which delimiter the parser expected to be closed.
//! function will raise an error if the expected delimiter isn't the next token. This error will
//! also point to which delimiter the parser expected to be closed.
//!
//! ## Far Token Peek
//!

View file

@ -405,7 +405,7 @@ fn parse_define_token_on_scope() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
assert_eq!(stmt.if_not_exists, false);
assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);
@ -480,7 +480,7 @@ fn parse_define_token_jwks_on_scope() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
assert_eq!(stmt.if_not_exists, false);
assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);
@ -523,7 +523,7 @@ fn parse_define_scope() {
session: Some(Duration::from_secs(1)),
}
);
assert_eq!(stmt.if_not_exists, false);
assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, Some(Value::Bool(true)));
@ -942,7 +942,7 @@ fn parse_define_access_record() {
}
);
assert_eq!(stmt.comment, Some(Strand("bar".to_string())));
assert_eq!(stmt.if_not_exists, false);
assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, None);
@ -988,7 +988,7 @@ fn parse_define_access_record() {
}
);
assert_eq!(stmt.comment, None);
assert_eq!(stmt.if_not_exists, false);
assert!(!stmt.if_not_exists);
match stmt.kind {
AccessType::Record(ac) => {
assert_eq!(ac.signup, Some(Value::Bool(true)));

View file

@ -533,7 +533,7 @@ mod tests {
let r = stack
.enter(|ctx| async move { parser.parse_thing(ctx).await })
.finish()
.expect(&format!("failed on {}", ident))
.unwrap_or_else(|_| panic!("failed on {}", ident))
.id;
assert_eq!(r, Id::String(ident.to_string()),);
@ -541,7 +541,7 @@ mod tests {
let r = stack
.enter(|ctx| async move { parser.parse_query(ctx).await })
.finish()
.expect(&format!("failed on {}", ident));
.unwrap_or_else(|_| panic!("failed on {}", ident));
assert_eq!(
r,

View file

@ -218,7 +218,7 @@ fn hnsw() -> HnswIndex {
fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
let mut h = hnsw();
for (id, content) in samples {
h.index_document(&id, content).unwrap();
h.index_document(id, content).unwrap();
}
h
}
@ -226,7 +226,7 @@ fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
async fn insert_objects_db(session: &Session, create_index: bool, inserts: &[String]) -> Datastore {
let ds = init_datastore(session, create_index).await;
for sql in inserts {
ds.execute(sql, session, None).await.expect(&sql);
ds.execute(sql, session, None).await.expect(sql);
}
ds
}
@ -249,8 +249,8 @@ async fn knn_lookup_objects(h: &HnswIndex, samples: &[Vec<Number>]) {
async fn knn_lookup_objects_db(ds: &Datastore, session: &Session, selects: &[String]) {
for sql in selects {
let mut res = ds.execute(sql, session, None).await.expect(&sql);
let res = res.remove(0).result.expect(&sql);
let mut res = ds.execute(sql, session, None).await.expect(sql);
let res = res.remove(0).result.expect(sql);
if let Value::Array(a) = &res {
assert_eq!(a.len(), NN);
} else {

View file

@ -199,7 +199,7 @@ async fn knn_lookup_object(mt: &MTreeIndex, ctx: &Context<'_>, object: Vec<Numbe
let mut stack = TreeStack::new();
stack
.enter(|stk| async {
let chk = MTreeConditionChecker::new(&ctx);
let chk = MTreeConditionChecker::new(ctx);
let r = mt.knn_search(stk, ctx, &object, knn, chk).await.unwrap();
assert_eq!(r.len(), knn);
})

View file

@ -612,7 +612,7 @@ mod tests {
//
let compress = |v: &Vec<u8>| {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&v).unwrap();
encoder.write_all(v).unwrap();
encoder.finish().unwrap()
};
// Generate a random vector

View file

@ -974,10 +974,10 @@ async fn define_statement_index_on_schemafull_without_permission() -> Result<(),
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 2);
//
skip_ok(&mut res, 1)?;
skip_ok(res, 1)?;
//
let tmp = res.remove(0).result;
let s = format!("{:?}", tmp);

View file

@ -1881,7 +1881,12 @@ async fn function_math_deg2rad() -> Result<(), Error> {
RETURN math::deg2rad(math::rad2deg(0.7853981633974483));
"#;
Test::new(sql).await?.expect_floats(
&[0.7853981633974483, -1.5707963267948966, 6.283185307179586, 0.7853981633974483],
&[
std::f64::consts::FRAC_PI_4,
-std::f64::consts::FRAC_PI_2,
std::f64::consts::TAU,
std::f64::consts::FRAC_PI_4,
],
f64::EPSILON,
)?;
Ok(())
@ -2032,7 +2037,7 @@ async fn function_math_log10() -> Result<(), Error> {
"#;
Test::new(sql)
.await?
.expect_floats(&[0.43429738512450866, 0.3010299956639812, 0.0], f64::EPSILON)?
.expect_floats(&[0.43429738512450866, std::f64::consts::LOG10_2, 0.0], f64::EPSILON)?
.expect_vals(&["Math::Neg_Inf", "NaN", "true"])?;
Ok(())
}

View file

@ -595,10 +595,10 @@ async fn select_array_group_group_by() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 4)?;
skip_ok(res, 4)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -633,10 +633,10 @@ async fn select_array_count_subquery_group_by() -> Result<(), Error> {
"#;
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(

View file

@ -272,6 +272,7 @@ impl Test {
/// This method will panic if the responses list is empty, indicating that there are no more responses to retrieve.
/// The panic message will include the last position in the responses list before it was emptied.
#[allow(dead_code)]
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Result<Response, Error> {
assert!(!self.responses.is_empty(), "No response left - last position: {}", self.pos);
self.pos += 1;

View file

@ -1930,7 +1930,7 @@ async fn select_with_record_id_link_no_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h';
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;
@ -1989,7 +1989,7 @@ async fn select_with_record_id_link_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
SELECT * FROM i WHERE t.name = 'h';
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
@ -2054,7 +2054,7 @@ async fn select_with_record_id_link_unique_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name = 'h' EXPLAIN;
SELECT * FROM i WHERE t.name = 'h';
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
@ -2118,7 +2118,7 @@ async fn select_with_record_id_link_unique_remote_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name IN ['a', 'b'] EXPLAIN;
SELECT * FROM i WHERE t.name IN ['a', 'b'];
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
@ -2185,7 +2185,7 @@ async fn select_with_record_id_link_full_text_index() -> Result<(), Error> {
SELECT * FROM i WHERE t.name @@ 'world' EXPLAIN;
SELECT * FROM i WHERE t.name @@ 'world';
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 9);
skip_ok(&mut res, 7)?;
@ -2242,7 +2242,7 @@ async fn select_with_record_id_link_full_text_no_record_index() -> Result<(), Er
SELECT * FROM i WHERE t.name @@ 'world' EXPLAIN;
SELECT * FROM i WHERE t.name @@ 'world';
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;
@ -2301,7 +2301,7 @@ async fn select_with_record_id_index() -> Result<(), Error> {
SELECT * FROM t WHERE a:2 IN links;
SELECT * FROM t WHERE a:2 IN links EXPLAIN;
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
let expected = Value::parse(
r#"[
@ -2447,7 +2447,7 @@ async fn select_with_exact_operator() -> Result<(), Error> {
SELECT * FROM t WHERE i == 2;
SELECT * FROM t WHERE i == 2 EXPLAIN;
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 4)?;
@ -2550,7 +2550,7 @@ async fn select_with_non_boolean_expression() -> Result<(), Error> {
SELECT * FROM t WHERE v > $p3 - ( math::max([0, $p1]) + $p1 );
SELECT * FROM t WHERE v > $p3 - ( math::max([0, $p1]) + $p1 ) EXPLAIN;
";
let mut res = dbs.execute(&sql, &ses, None).await?;
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 15);
skip_ok(&mut res, 5)?;

View file

@ -328,10 +328,10 @@ async fn select_mtree_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -400,10 +400,10 @@ async fn select_hnsw_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -471,10 +471,10 @@ async fn select_bruteforce_knn_with_condition() -> Result<(), Error> {
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
skip_ok(&mut res, 2)?;
skip_ok(res, 2)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(

View file

@ -24,7 +24,7 @@ use crate::cli::CF;
use crate::cnf;
use crate::err::Error;
use crate::net::signals::graceful_shutdown;
use crate::rpc::notifications;
use crate::rpc::{notifications, RpcState};
use crate::telemetry::metrics::HttpMetricsLayer;
use axum::response::Redirect;
use axum::routing::get;
@ -155,7 +155,7 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.max_age(Duration::from_secs(86400)),
);
let axum_app = Router::new()
let axum_app = Router::<Arc<RpcState>, _>::new()
// Redirect until we provide a UI
.route("/", get(|| async { Redirect::temporary(cnf::APP_ENDPOINT) }))
.route("/status", get(|| async {}))
@ -177,10 +177,16 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
// Get a new server handler
let handle = Handle::new();
let rpc_state = Arc::new(RpcState::new());
// Setup the graceful shutdown handler
let shutdown_handler = graceful_shutdown(ct.clone(), handle.clone());
let shutdown_handler = graceful_shutdown(rpc_state.clone(), ct.clone(), handle.clone());
let axum_app = axum_app.with_state(rpc_state.clone());
// Spawn a task to handle notifications
tokio::spawn(async move { notifications(ct.clone()).await });
tokio::spawn(async move { notifications(rpc_state, ct.clone()).await });
// If a certificate and key are specified then setup TLS
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
// Configure certificate and private key used by https

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::Arc;
use crate::cnf;
use crate::dbs::DB;
@ -8,7 +9,8 @@ use crate::rpc::connection::Connection;
use crate::rpc::format::HttpFormat;
use crate::rpc::post_context::PostRpcContext;
use crate::rpc::response::IntoRpcResponse;
use crate::rpc::WEBSOCKETS;
use crate::rpc::RpcState;
use axum::extract::State;
use axum::routing::get;
use axum::routing::post;
use axum::TypedHeader;
@ -32,12 +34,11 @@ use super::headers::ContentType;
use surrealdb::rpc::rpc_context::RpcContext;
pub(super) fn router<S, B>() -> Router<S, B>
pub(super) fn router<B>() -> Router<Arc<RpcState>, B>
where
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
Router::new().route("/rpc", get(get_handler)).route("/rpc", post(post_handler))
}
@ -46,6 +47,7 @@ async fn get_handler(
ws: WebSocketUpgrade,
Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>,
State(rpc_state): State<Arc<RpcState>>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Check if there is a request id header specified
let id = match id.header_value().is_empty() {
@ -65,7 +67,7 @@ async fn get_handler(
},
};
// Check if a connection with this id already exists
if WEBSOCKETS.read().await.contains_key(&id) {
if rpc_state.web_sockets.read().await.contains_key(&id) {
return Err(Error::Request);
}
// Now let's upgrade the WebSocket connection
@ -77,10 +79,10 @@ async fn get_handler(
// Set the maximum WebSocket message size
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
// Handle the WebSocket upgrade and process messages
.on_upgrade(move |socket| handle_socket(socket, sess, id)))
.on_upgrade(move |socket| handle_socket(rpc_state, socket, sess, id)))
}
async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: Uuid) {
// Check if there is a WebSocket protocol specified
let format = match ws.protocol().map(HeaderValue::to_str) {
// Any selected protocol will always be a valie value
@ -90,7 +92,7 @@ async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
};
// Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here
// Create a new connection instance
let rpc = Connection::new(id, sess, format);
let rpc = Connection::new(state, id, sess, format);
// Serve the socket connection requests
Connection::serve(rpc, ws).await;
}

View file

@ -1,8 +1,14 @@
use std::sync::Arc;
use axum_server::Handle;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::{err::Error, rpc, telemetry};
use crate::{
err::Error,
rpc::{self, RpcState},
telemetry,
};
/// Start a graceful shutdown:
/// * Signal the Axum Handle when a shutdown signal is received.
@ -10,7 +16,11 @@ use crate::{err::Error, rpc, telemetry};
/// * Flush all telemetry data.
///
/// A second signal will force an immediate shutdown.
pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHandle<()> {
pub fn graceful_shutdown(
state: Arc<RpcState>,
ct: CancellationToken,
http_handle: Handle,
) -> JoinHandle<()> {
tokio::spawn(async move {
let result = listen().await.expect("Failed to listen to shutdown signal");
info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);
@ -18,6 +28,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
let shutdown = {
let http_handle = http_handle.clone();
let ct = ct.clone();
let state_clone = state.clone();
tokio::spawn(async move {
// Stop accepting new HTTP requests and wait until all connections are closed
@ -26,7 +37,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
rpc::graceful_shutdown().await;
rpc::graceful_shutdown(state_clone).await;
ct.cancel();
@ -52,7 +63,7 @@ pub fn graceful_shutdown(ct: CancellationToken, http_handle: Handle) -> JoinHand
http_handle.shutdown();
// Close all WebSocket connections immediately
rpc::shutdown();
rpc::shutdown(state);
// Cancel cancellation token
ct.cancel();

View file

@ -27,9 +27,9 @@ struct Success {
}
impl Success {
fn new(token: Option<String>) -> Success {
fn new(token: String) -> Success {
Success {
token,
token: Some(token),
code: 200,
details: String::from("Authentication succeeded"),
}
@ -71,7 +71,7 @@ async fn handler(
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Text serialization
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
Some(Accept::TextPlain) => Ok(output::text(v)),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Return nothing

View file

@ -126,7 +126,7 @@ where
FailureClass: fmt::Display,
{
fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) {
span.record("error_message", &error.to_string());
span.record("error_message", error.to_string());
span.record("http.latency.ms", latency.as_millis());
tracing::event!(Level::ERROR, error = error.to_string(), "response failed");
}

View file

@ -5,7 +5,7 @@ use crate::dbs::DB;
use crate::rpc::failure::Failure;
use crate::rpc::format::WsFormat;
use crate::rpc::response::{failure, IntoRpcResponse};
use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
use crate::rpc::CONN_CLOSED_ERR;
use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext;
use crate::telemetry::traces::rpc::span_for_request;
@ -33,6 +33,8 @@ use tracing::Instrument;
use tracing::Span;
use uuid::Uuid;
use super::RpcState;
pub struct Connection {
pub(crate) id: Uuid,
pub(crate) format: Format,
@ -41,11 +43,17 @@ pub struct Connection {
pub(crate) limiter: Arc<Semaphore>,
pub(crate) canceller: CancellationToken,
pub(crate) channels: (Sender<Message>, Receiver<Message>),
pub(crate) state: Arc<RpcState>,
}
impl Connection {
/// Instantiate a new RPC
pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
pub fn new(
state: Arc<RpcState>,
id: Uuid,
mut session: Session,
format: Format,
) -> Arc<RwLock<Connection>> {
// Enable real-time mode
session.rt = true;
// Create and store the RPC connection
@ -57,18 +65,26 @@ impl Connection {
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
canceller: CancellationToken::new(),
channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS),
state,
}))
}
/// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
// Get the WebSocket ID
let id = rpc.read().await.id;
let rpc_lock = rpc.read().await;
// Get the WebSocket ID
let id = rpc_lock.id;
let state = rpc_lock.state.clone();
// Split the socket into sending and receiving streams
let (sender, receiver) = ws.split();
// Create an internal channel for sending and receiving
let internal_sender = rpc.read().await.channels.0.clone();
let internal_receiver = rpc.read().await.channels.1.clone();
let internal_sender = rpc_lock.channels.0.clone();
let internal_receiver = rpc_lock.channels.1.clone();
// drop the lock early so rpc is free to be written to.
std::mem::drop(rpc_lock);
trace!("WebSocket {} connected", id);
@ -77,7 +93,7 @@ impl Connection {
}
// Add this WebSocket to the list
WEBSOCKETS.write().await.insert(id, rpc.clone());
state.web_sockets.write().await.insert(id, rpc.clone());
// Spawn async tasks for the WebSocket
let mut tasks = JoinSet::new();
@ -97,11 +113,11 @@ impl Connection {
trace!("WebSocket {} disconnected", id);
// Remove this WebSocket from the list
WEBSOCKETS.write().await.remove(&id);
state.web_sockets.write().await.remove(&id);
// Remove all live queries
let mut gc = Vec::new();
LIVE_QUERIES.write().await.retain(|key, value| {
state.live_queries.write().await.retain(|key, value| {
if value == &id {
trace!("Removing live query: {}", key);
gc.push(*key);
@ -379,12 +395,12 @@ impl RpcContext for Connection {
const LQ_SUPPORT: bool = true;
async fn handle_live(&self, lqid: &Uuid) {
LIVE_QUERIES.write().await.insert(*lqid, self.id);
self.state.live_queries.write().await.insert(*lqid, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}
async fn handle_kill(&self, lqid: &Uuid) {
if let Some(id) = LIVE_QUERIES.write().await.remove(lqid) {
if let Some(id) = self.state.live_queries.write().await.remove(lqid) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}

View file

@ -8,7 +8,6 @@ use crate::dbs::DB;
use crate::rpc::connection::Connection;
use crate::rpc::response::success;
use crate::telemetry::metrics::ws::NotificationContext;
use once_cell::sync::Lazy;
use opentelemetry::Context as TelemetryContext;
use std::collections::HashMap;
use std::sync::Arc;
@ -25,13 +24,24 @@ type WebSockets = RwLock<HashMap<Uuid, WebSocket>>;
/// Mapping of LIVE Query ID to WebSocket ID
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
/// Stores the currently connected WebSockets
pub(crate) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
/// Stores the currently initiated LIVE queries
pub(crate) static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
pub struct RpcState {
/// Stores the currently connected WebSockets
pub web_sockets: WebSockets,
/// Stores the currently initiated LIVE queries
pub live_queries: LiveQueries,
}
impl RpcState {
pub fn new() -> Self {
RpcState {
web_sockets: WebSockets::default(),
live_queries: LiveQueries::default(),
}
}
}
/// Performs notification delivery to the WebSockets
pub(crate) async fn notifications(canceller: CancellationToken) {
pub(crate) async fn notifications(state: Arc<RpcState>, canceller: CancellationToken) {
// Listen to the notifications channel
if let Some(channel) = DB.get().unwrap().notifications() {
// Loop continuously
@ -44,9 +54,9 @@ pub(crate) async fn notifications(canceller: CancellationToken) {
// Receive a notification on the channel
Ok(notification) = channel.recv() => {
// Find which WebSocket the notification belongs to
if let Some(id) = LIVE_QUERIES.read().await.get(&notification.id) {
if let Some(id) = state.live_queries.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(rpc) = WEBSOCKETS.read().await.get(id) {
if let Some(rpc) = state.web_sockets.read().await.get(id) {
// Serialize the message to send
let message = success(None, notification);
// Add metrics
@ -69,21 +79,21 @@ pub(crate) async fn notifications(canceller: CancellationToken) {
}
/// Closes all WebSocket connections, waiting for graceful shutdown
pub(crate) async fn graceful_shutdown() {
pub(crate) async fn graceful_shutdown(state: Arc<RpcState>) {
// Close WebSocket connections, ensuring queued messages are processed
for (_, rpc) in WEBSOCKETS.read().await.iter() {
for (_, rpc) in state.web_sockets.read().await.iter() {
rpc.read().await.canceller.cancel();
}
// Wait for all existing WebSocket connections to finish sending
while WEBSOCKETS.read().await.len() > 0 {
while state.web_sockets.read().await.len() > 0 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
/// Forces a fast shutdown of all WebSocket connections
pub(crate) fn shutdown() {
pub(crate) fn shutdown(state: Arc<RpcState>) {
// Close all WebSocket connections immediately
if let Ok(mut writer) = WEBSOCKETS.try_write() {
if let Ok(mut writer) = state.web_sockets.try_write() {
writer.drain();
}
}