From 21975548f21701de83f3c680be3944b1a149e7e7 Mon Sep 17 00:00:00 2001 From: Gerard Guillemas Martos Date: Tue, 12 Mar 2024 11:34:35 +0100 Subject: [PATCH] Move JWKS cache storage to memory (#3649) --- core/src/iam/jwks.rs | 86 ++++++++++++++++++++++---------------------- core/src/kvs/ds.rs | 12 +++++++ core/src/lib.rs | 2 +- 3 files changed, 57 insertions(+), 43 deletions(-) diff --git a/core/src/iam/jwks.rs b/core/src/iam/jwks.rs index 0fd12803..661088c5 100644 --- a/core/src/iam/jwks.rs +++ b/core/src/iam/jwks.rs @@ -8,7 +8,17 @@ use once_cell::sync::Lazy; use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; +use std::collections::HashMap; use std::str::FromStr; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub(crate) type JwksCache = HashMap; +#[derive(Clone, Serialize, Deserialize)] +pub(crate) struct JwksCacheEntry { + jwks: JwkSet, + time: DateTime, +} #[cfg(test)] static CACHE_EXPIRATION: Lazy = Lazy::new(|| Duration::seconds(1)); @@ -66,19 +76,21 @@ pub(super) async fn config( kid: &str, url: &str, ) -> Result<(DecodingKey, Validation), Error> { + // Retrieve JWKS cache + let cache = kvs.jwks_cache(); // Attempt to fetch relevant JWK object either from local cache or remote location - let jwk = match fetch_jwks_from_cache(url).await { - Ok(jwks_cache) => { + let jwk = match fetch_jwks_from_cache(cache, url).await { + Some(jwks) => { trace!("Successfully fetched JWKS object from local cache"); // Check that the cached JWKS object has not expired yet - if Utc::now().signed_duration_since(jwks_cache.time) < *CACHE_EXPIRATION { + if Utc::now().signed_duration_since(jwks.time) < *CACHE_EXPIRATION { // Attempt to find JWK in JWKS object from local cache - match jwks_cache.jwks.find(kid) { + match jwks.jwks.find(kid) { Some(jwk) => jwk.to_owned(), _ => { trace!("Could not find valid JWK object with key identifier '{kid}' in cached JWKS object"); // Check that the cached JWKS object has not been recently updated - if Utc::now().signed_duration_since(jwks_cache.time) < *CACHE_COOLDOWN { + if Utc::now().signed_duration_since(jwks.time) < *CACHE_COOLDOWN { debug!("Refused to refresh cache before cooldown period is over"); return Err(Error::InvalidAuth); // Return opaque error } @@ -90,7 +102,7 @@ pub(super) async fn config( find_jwk_from_url(kvs, url, kid).await? } } - Err(_) => { + None => { trace!("Could not fetch JWKS object from local cache"); find_jwk_from_url(kvs, url, kid).await? } @@ -145,8 +157,10 @@ async fn find_jwk_from_url(kvs: &Datastore, url: &str, kid: &str) -> Result { trace!("Successfully fetched JWKS object from remote location"); // Attempt to find JWK in JWKS by the key identifier @@ -199,7 +213,7 @@ fn check_capabilities_url(kvs: &Datastore, url: &str) -> Result<(), Error> { } // Attempts to fetch a JWKS object from a remote location and stores it in the cache if successful -async fn fetch_jwks_from_url(url: &str) -> Result { +async fn fetch_jwks_from_url(cache: &Arc>, url: &str) -> Result { let client = Client::new(); #[cfg(not(target_arch = "wasm32"))] let res = client.get(url).timeout((*REMOTE_TIMEOUT).to_std().unwrap()).send().await?; @@ -214,11 +228,9 @@ async fn fetch_jwks_from_url(url: &str) -> Result { match serde_json::from_slice::(&jwks) { Ok(jwks) => { // If successful, cache the JWKS object by its URL - match store_jwks_in_cache(jwks.clone(), url).await { - Ok(_) => trace!("Successfully stored JWKS object in local cache"), - Err(err) => { - warn!("Failed to store JWKS object in local cache: '{}'", err); - } + match store_jwks_in_cache(cache, jwks.clone(), url).await { + None => trace!("Successfully added JWKS object to local cache"), + Some(_) => trace!("Successfully updated JWKS object in local cache"), }; Ok(jwks) @@ -230,50 +242,40 @@ async fn fetch_jwks_from_url(url: &str) -> Result { } } -#[derive(Serialize, Deserialize)] -struct JwksCache { - jwks: JwkSet, - time: DateTime, -} - // Attempts to fetch a JWKS object from the local cache -async fn fetch_jwks_from_cache(url: &str) -> Result { - let path = cache_path_from_url(url); - let bytes = crate::obs::get(&path).await?; +async fn fetch_jwks_from_cache( + cache: &Arc>, + url: &str, +) -> Option { + let path = cache_key_from_url(url); + let cache = cache.read().await; - match serde_json::from_slice::(&bytes) { - Ok(jwks_cache) => Ok(jwks_cache), - Err(err) => { - warn!("Failed to parse malformed JWKS object: '{}'", err); - Err(Error::InvalidAuth) // Return opaque error - } - } + cache.get(&path).cloned() } // Attempts to store a JWKS object in the local cache -async fn store_jwks_in_cache(jwks: JwkSet, url: &str) -> Result<(), Error> { - let jwks_cache = JwksCache { +async fn store_jwks_in_cache( + cache: &Arc>, + jwks: JwkSet, + url: &str, +) -> Option { + let entry = JwksCacheEntry { jwks, time: Utc::now(), }; - let path = cache_path_from_url(url); + let path = cache_key_from_url(url); + let mut cache = cache.write().await; - match serde_json::to_vec(&jwks_cache) { - Ok(data) => crate::obs::put(&path, data).await, - Err(err) => { - warn!("Failed to cache malformed JWKS object: '{}'", err); - Err(Error::InvalidAuth) // Return opaque error - } - } + cache.insert(path, entry) } -// Generates a unique cache path for a given URL string -fn cache_path_from_url(url: &str) -> String { +// Generates a unique cache key for a given URL string +fn cache_key_from_url(url: &str) -> String { let mut hasher = Sha256::new(); hasher.update(url); let result = hasher.finalize(); - format!("jwks/{:x}.json", result) + format!("{:x}", result) } #[cfg(test)] diff --git a/core/src/kvs/ds.rs b/core/src/kvs/ds.rs index 8cc7fbcb..d4c748eb 100644 --- a/core/src/kvs/ds.rs +++ b/core/src/kvs/ds.rs @@ -26,6 +26,8 @@ use crate::dbs::{ use crate::doc::Document; use crate::err::Error; use crate::fflags::FFLAGS; +#[cfg(feature = "jwks")] +use crate::iam::jwks::JwksCache; use crate::iam::{Action, Auth, Error as IamError, Resource, Role}; use crate::idx::trees::store::IndexStores; use crate::key::root::hb::Hb; @@ -87,6 +89,9 @@ pub struct Datastore { clock: Arc, // The index store cache index_stores: IndexStores, + #[cfg(feature = "jwks")] + // The JWKS object cache + jwks_cache: Arc>, } /// We always want to be circulating the live query information @@ -359,6 +364,8 @@ impl Datastore { index_stores: IndexStores::default(), local_live_queries: Arc::new(RwLock::new(BTreeMap::new())), cf_watermarks: Arc::new(RwLock::new(BTreeMap::new())), + #[cfg(feature = "jwks")] + jwks_cache: Arc::new(RwLock::new(JwksCache::new())), }) } @@ -438,6 +445,11 @@ impl Datastore { self.capabilities.allows_network_target(net_target) } + #[cfg(feature = "jwks")] + pub(crate) fn jwks_cache(&self) -> &Arc> { + &self.jwks_cache + } + /// Setup the initial credentials /// Trigger the `unreachable definition` compilation error, probably due to this issue: /// https://github.com/rust-lang/rust/issues/111370 diff --git a/core/src/lib.rs b/core/src/lib.rs index 95bd5fe2..e3c83d8d 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -34,7 +34,7 @@ pub mod idx; pub mod key; #[doc(hidden)] pub mod kvs; -#[cfg(any(feature = "ml", feature = "ml2", feature = "jwks"))] +#[cfg(any(feature = "ml", feature = "ml2"))] #[doc(hidden)] pub mod obs; pub mod options;