2024-02-02 22:10:47 +00:00
use crate::dbs::capabilities::NetTarget;
2024-01-09 17:17:48 +00:00
use crate::err::Error;
use crate::kvs::Datastore;
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::jwk::{Jwk, JwkSet, KeyOperations, PublicKeyUse};
use jsonwebtoken::{DecodingKey, Validation};
use once_cell::sync::Lazy;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
2024-03-12 10:34:35 +00:00
use std::collections::HashMap;
2024-01-09 17:17:48 +00:00
use std::str::FromStr;
2024-03-12 10:34:35 +00:00
use std::sync::Arc;
use tokio::sync::RwLock;
pub(crate) type JwksCache = HashMap<String, JwksCacheEntry>;
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct JwksCacheEntry {
jwks: JwkSet,
time: DateTime<Utc>,
2024-01-09 17:17:48 +00:00
static CACHE_EXPIRATION: Lazy<chrono::Duration> = Lazy::new(|| Duration::seconds(1));
static CACHE_EXPIRATION: Lazy<chrono::Duration> =
Lazy::new(|| match std::env::var("SURREAL_JWKS_CACHE_EXPIRATION_SECONDS") {
Ok(seconds_str) => {
let seconds = seconds_str.parse::<u64>().expect(
"Expected a valid number of seconds for SURREAL_JWKS_CACHE_EXPIRATION_SECONDS",
Duration::seconds(seconds as i64)
Err(_) => {
Duration::seconds(43200) // Set default cache expiration of 12 hours
static CACHE_COOLDOWN: Lazy<chrono::Duration> = Lazy::new(|| Duration::seconds(300));
static CACHE_COOLDOWN: Lazy<chrono::Duration> =
Lazy::new(|| match std::env::var("SURREAL_JWKS_CACHE_COOLDOWN_SECONDS") {
Ok(seconds_str) => {
let seconds = seconds_str.parse::<u64>().expect(
"Expected a valid number of seconds for SURREAL_JWKS_CACHE_COOLDOWN_SECONDS",
Duration::seconds(seconds as i64)
Err(_) => {
Duration::seconds(300) // Set default cache refresh cooldown of 5 minutes
2024-02-02 22:10:47 +00:00
#[cfg(not(target_arch = "wasm32"))]
2024-01-09 17:17:48 +00:00
static REMOTE_TIMEOUT: Lazy<chrono::Duration> =
Lazy::new(|| match std::env::var("SURREAL_JWKS_REMOTE_TIMEOUT_MILLISECONDS") {
Ok(milliseconds_str) => {
let milliseconds = milliseconds_str
.expect("Expected a valid number of milliseconds for SURREAL_JWKS_REMOTE_TIMEOUT_MILLISECONDS");
Duration::milliseconds(milliseconds as i64)
Err(_) => {
Duration::milliseconds(1000) // Set default remote timeout to 1 second
// Generates a verification configuration from a JWKS object hosted in a remote location
// Performs local caching of all JWKS objects to prevent unnecessary network requests
// Implements checks to prevent denial of service and unauthorized network requests
// Validates the JWK objects found in the JWKS object according to RFC 7517
// Source: https://datatracker.ietf.org/doc/html/rfc7517
pub(super) async fn config(
kvs: &Datastore,
kid: &str,
url: &str,
) -> Result<(DecodingKey, Validation), Error> {
2024-03-12 10:34:35 +00:00
// Retrieve JWKS cache
let cache = kvs.jwks_cache();
2024-01-09 17:17:48 +00:00
// Attempt to fetch relevant JWK object either from local cache or remote location
2024-03-12 10:34:35 +00:00
let jwk = match fetch_jwks_from_cache(cache, url).await {
Some(jwks) => {
2024-01-09 17:17:48 +00:00
trace!("Successfully fetched JWKS object from local cache");
// Check that the cached JWKS object has not expired yet
2024-03-12 10:34:35 +00:00
if Utc::now().signed_duration_since(jwks.time) < *CACHE_EXPIRATION {
2024-01-09 17:17:48 +00:00
// Attempt to find JWK in JWKS object from local cache
2024-03-12 10:34:35 +00:00
match jwks.jwks.find(kid) {
2024-01-09 17:17:48 +00:00
Some(jwk) => jwk.to_owned(),
_ => {
trace!("Could not find valid JWK object with key identifier '{kid}' in cached JWKS object");
// Check that the cached JWKS object has not been recently updated
2024-03-12 10:34:35 +00:00
if Utc::now().signed_duration_since(jwks.time) < *CACHE_COOLDOWN {
2024-01-09 17:17:48 +00:00
debug!("Refused to refresh cache before cooldown period is over");
return Err(Error::InvalidAuth); // Return opaque error
find_jwk_from_url(kvs, url, kid).await?
} else {
trace!("Fetched JWKS object from local cache has expired");
find_jwk_from_url(kvs, url, kid).await?
2024-03-12 10:34:35 +00:00
None => {
2024-01-09 17:17:48 +00:00
trace!("Could not fetch JWKS object from local cache");
find_jwk_from_url(kvs, url, kid).await?
// Check if algorithm specified is supported
let alg = match jwk.common.algorithm {
Some(alg) => alg,
_ => {
warn!("Invalid value for parameter 'alg' in JWK object: '{:?}'", jwk.common.algorithm);
return Err(Error::InvalidAuth); // Return opaque error
// Check if the key use (if specified) is intended to be used for signing
// Source: https://datatracker.ietf.org/doc/html/rfc7517#section-4.2
match &jwk.common.public_key_use {
Some(PublicKeyUse::Signature) => (),
Some(key_use) => {
warn!("Invalid value for parameter 'use' in JWK object: '{:?}'", key_use);
return Err(Error::InvalidAuth); // Return opaque error
None => (),
// Check if the key operations (if specified) include verification
// Source: https://datatracker.ietf.org/doc/html/rfc7517#section-4.3
if let Some(ops) = &jwk.common.key_operations {
if !ops.iter().any(|op| *op == KeyOperations::Verify) {
"Invalid values for parameter 'key_ops' in JWK object: '{:?}'",
return Err(Error::InvalidAuth); // Return opaque error
// Return verification configuration if a decoding key can be retrieved from the JWK object
match DecodingKey::from_jwk(&jwk) {
Ok(dec) => Ok((dec, Validation::new(alg))),
Err(err) => {
warn!("Failed to retrieve decoding key from JWK object: '{}'", err);
Err(Error::InvalidAuth) // Return opaque error
// Checks if network access to a remote location is allowed by the datastore capabilities
// Attempts to find a relevant JWK object inside a JWKS object fetched from the remote location
async fn find_jwk_from_url(kvs: &Datastore, url: &str, kid: &str) -> Result<Jwk, Error> {
// Check that the datastore capabilities allow connections to the URL host
if let Err(err) = check_capabilities_url(kvs, url) {
warn!("Network access to JWKS location is not allowed: '{}'", err);
return Err(Error::InvalidAuth); // Return opaque error
2024-03-12 10:34:35 +00:00
// Retrieve JWKS cache
let cache = kvs.jwks_cache();
2024-01-09 17:17:48 +00:00
// Attempt to fetch JWKS object from remote location
2024-03-12 10:34:35 +00:00
match fetch_jwks_from_url(cache, url).await {
2024-01-09 17:17:48 +00:00
Ok(jwks) => {
trace!("Successfully fetched JWKS object from remote location");
// Attempt to find JWK in JWKS by the key identifier
match jwks.find(kid) {
Some(jwk) => Ok(jwk.to_owned()),
_ => {
debug!("Failed to find JWK object with key identifier '{kid}' in remote JWKS object");
Err(Error::InvalidAuth) // Return opaque error
Err(err) => {
warn!("Failed to fetch JWKS object from remote location: '{}'", err);
Err(Error::InvalidAuth) // Return opaque error
// Returns an error if network access to the address from a given URL string is not allowed
fn check_capabilities_url(kvs: &Datastore, url: &str) -> Result<(), Error> {
let url_parsed = match Url::parse(url) {
Ok(url) => url,
Err(_) => {
return Err(Error::InvalidUrl(url.to_string()));
let addr = match url_parsed.host_str() {
Some(host) => {
if let Some(port) = url_parsed.port() {
} else {
None => {
return Err(Error::InvalidUrl(url.to_string()));
let target = match NetTarget::from_str(&addr) {
Ok(host) => host,
Err(_) => {
return Err(Error::InvalidUrl(url.to_string()));
if !kvs.allows_network_target(&target) {
return Err(Error::InvalidUrl(url.to_string()));
// Attempts to fetch a JWKS object from a remote location and stores it in the cache if successful
2024-03-12 10:34:35 +00:00
async fn fetch_jwks_from_url(cache: &Arc<RwLock<JwksCache>>, url: &str) -> Result<JwkSet, Error> {
2024-01-09 17:17:48 +00:00
let client = Client::new();
#[cfg(not(target_arch = "wasm32"))]
let res = client.get(url).timeout((*REMOTE_TIMEOUT).to_std().unwrap()).send().await?;
#[cfg(target_arch = "wasm32")]
let res = client.get(url).send().await?;
if !res.status().is_success() {
warn!("Unsuccessful HTTP status code received when fetching JWKS object from remote location: '{:?}'", res.status());
return Err(Error::InvalidAuth); // Return opaque error
let jwks = res.bytes().await?;
match serde_json::from_slice::<JwkSet>(&jwks) {
Ok(jwks) => {
// If successful, cache the JWKS object by its URL
2024-03-12 10:34:35 +00:00
match store_jwks_in_cache(cache, jwks.clone(), url).await {
None => trace!("Successfully added JWKS object to local cache"),
Some(_) => trace!("Successfully updated JWKS object in local cache"),
2024-01-09 17:17:48 +00:00
Err(err) => {
warn!("Failed to parse malformed JWKS object: '{}'", err);
Err(Error::InvalidAuth) // Return opaque error
// Attempts to fetch a JWKS object from the local cache
2024-03-12 10:34:35 +00:00
async fn fetch_jwks_from_cache(
cache: &Arc<RwLock<JwksCache>>,
url: &str,
) -> Option<JwksCacheEntry> {
let path = cache_key_from_url(url);
let cache = cache.read().await;
2024-01-09 17:17:48 +00:00
2024-03-12 10:34:35 +00:00
2024-01-09 17:17:48 +00:00
// Attempts to store a JWKS object in the local cache
2024-03-12 10:34:35 +00:00
async fn store_jwks_in_cache(
cache: &Arc<RwLock<JwksCache>>,
jwks: JwkSet,
url: &str,
) -> Option<JwksCacheEntry> {
let entry = JwksCacheEntry {
2024-01-09 17:17:48 +00:00
time: Utc::now(),
2024-03-12 10:34:35 +00:00
let path = cache_key_from_url(url);
let mut cache = cache.write().await;
2024-01-09 17:17:48 +00:00
2024-03-12 10:34:35 +00:00
cache.insert(path, entry)
2024-01-09 17:17:48 +00:00
2024-03-12 10:34:35 +00:00
// Generates a unique cache key for a given URL string
fn cache_key_from_url(url: &str) -> String {
2024-01-09 17:17:48 +00:00
let mut hasher = Sha256::new();
let result = hasher.finalize();
2024-03-12 10:34:35 +00:00
format!("{:x}", result)
2024-01-09 17:17:48 +00:00
// Unit tests for JWKS retrieval, caching and validation.
// Each test starts a local `wiremock` MockServer and calls `config` against a URL on it.
// NOTE(review): this view of the module appears to be missing lines (e.g. a `#[cfg(test)]`
// attribute, test attributes, closing delimiters and mock mounting calls); the comments in
// this module describe only what is visible and should be checked against the full file.
mod tests {
use super::*;
2024-02-02 22:10:47 +00:00
use crate::dbs::capabilities::{Capabilities, NetTarget, Targets};
2024-01-09 17:17:48 +00:00
use rand::{distributions::Alphanumeric, Rng};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Use unique path to prevent accidental cache reuse
// NOTE(review): only the RNG construction is visible below; presumably the path is built by
// sampling `Alphanumeric` characters (imported above) — confirm.
fn random_path() -> String {
let rng = rand::thread_rng();
// Shared JWKS fixture with two RSA keys intended for signatures (`use` = signature,
// `alg` = RS256), identified as `test_1` and `test_2`. Tests clone it and mutate the first
// key (`keys[0]`) to exercise the validation checks in `config`.
// NOTE(review): the `JwkSet { .. }` wrapper, the certificate chain contents and the closing
// delimiters are not visible in this view.
static DEFAULT_JWKS: Lazy<JwkSet> = Lazy::new(|| {
keys: vec![Jwk{
// First key: `test_1`
common: jsonwebtoken::jwk::CommonParameters {
public_key_use: Some(jsonwebtoken::jwk::PublicKeyUse::Signature),
key_operations: None,
algorithm: Some(jsonwebtoken::Algorithm::RS256),
key_id: Some("test_1".to_string()),
x509_url: None,
x509_chain: Some(vec![
x509_sha1_fingerprint: None,
x509_sha256_fingerprint: None,
algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
n: "2nsSvrRnuw6OLJCqltkiRAGV07-35isdPwyTrrWQ3PwxEZc-lDbquQ7Z9Fkx5Y-ldVzBbTHEsbmhDYjBubUlS4dhstvpYD93963Sw6Q6gQjow_T4xWqsaeuj4PpcajPjI_ybbDwLa7bIXEBz7AC3UAgxY0khiERfq2quWIaeK0MLJ7bBcpyGF7hZy1SUehQ187-yBrM9Dsi2qKxQX981JFsctEnJLaabvoWUMQsMucTJXBRp5X_bGJ70XjgB85DNWTVqw7XwEfe_piM5DcvjVcR86bYMw-Qs46a3IzvIDs54X9--frM35IHLNrpwVbfsg4qgmya_GTPF4NSVab0xaQ".to_string(),
e: "AQAB".to_string(),
// Second key: `test_2`
common: jsonwebtoken::jwk::CommonParameters {
public_key_use: Some(jsonwebtoken::jwk::PublicKeyUse::Signature),
key_operations: None,
algorithm: Some(jsonwebtoken::Algorithm::RS256),
key_id: Some("test_2".to_string()),
x509_url: None,
x509_chain: Some(vec![
x509_sha1_fingerprint: None,
x509_sha256_fingerprint: None,
algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
n: "vtOCOuiM_JI87mQ8E6ICCJLSC5KliR9vQC0s1XJV4A17m-CMmgFAN7u8AabrxId-3zjUZAE-nkpanENM76WIQJUCdt1H1gfC5lY4a49FVXA2q1WZLwDvlgb-ZNYZXi2vaH50uONXeO9XSG9dEnBUVKGVRL34GqB68UGgXrPGLkAcjH-TW0KDXLZ-FKXNhQfESIVGDHRGG0l-LPK_1AegtJjdEUjhA4CQ-1jA3kLVfr2cQc8rRD5b486R5XvC4xBlZNFFP7Fm5if4khhAJC-JnnYWgmytPM4Q7mOWatRr08wQmmfQDDrw53IseNA-yKnwHYlJ6ChU_UtNzS0OipUapQ".to_string(),
e: "AQAB".to_string(),
// Golden path: the first lookup (`test_1`) fetches the JWKS object from the mock server and
// caches it; the second lookup (`test_2`, same URL) is expected to succeed from the cache.
async fn test_golden_path() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
// Get first token configuration from remote location
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_ok(), "Failed to validate token the first time: {:?}", res.err());
// Drop server to force usage of the local cache
// Get second token configuration from local cache
let res = config(&ds, "test_2", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_ok(), "Failed to validate token the second time: {:?}", res.err());
// With `Capabilities::default()` the mock server is treated as an unallowed remote location,
// so `config` is expected to fail before any key is returned.
async fn test_capabilities_default() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(Capabilities::default());
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
// Get token configuration from unallowed remote location
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_err(), "Unexpected success validating token from unallowed remote location");
// Network access is allowed only for a target on a port other than the mock server's, so
// `config` is expected to fail.
// NOTE(review): the `NetTarget` literal reads `""` in this view; per its trailing comment it
// should name the server host with a different port — confirm in the full file.
async fn test_capabilities_specific_port() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
[NetTarget::from_str("").unwrap()].into(), // Different port from server
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
// Get token configuration from unallowed remote location
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_err(), "Unexpected success validating token from unallowed remote location");
// After the cached JWKS object expires, the same key must be fetched again from the remote
// location; per the closing comment, the mock server expects exactly two requests.
// NOTE(review): `std::thread::sleep` blocks the async runtime's thread for the whole wait;
// `tokio::time::sleep(..).await` would be the non-blocking equivalent.
async fn test_cache_expiration() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
// Get token configuration from remote location
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_ok(), "Failed to validate token the first time: {:?}", res.err());
// Wait for cache to expire
std::thread::sleep((*CACHE_EXPIRATION + Duration::seconds(1)).to_std().unwrap());
// Get same token configuration again after cache has expired
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_ok(), "Failed to validate token the second time: {:?}", res.err());
// The server will panic if it does not receive exactly two expected requests
// Two lookups for an unknown key identifier within the cooldown period: the first finds no
// cached entry and fetches from the remote location; the second finds the fresh cached entry
// without the key and must be refused without another request.
async fn test_cache_cooldown() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
// Use token with invalid key identifier claim to force cache refresh
let res = config(&ds, "invalid", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_err(), "Unexpected success validating token with invalid key identifier");
// Use token with invalid key identifier claim to force cache refresh again before cooldown
let res = config(&ds, "invalid", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_err(), "Unexpected success validating token with invalid key identifier");
// The server will panic if it receives more than the single expected request
// The mock server answers only once (`up_to_n_times(1)`). After the cached JWKS object has
// expired, the refresh fails and `config` must return an error rather than reuse stale keys.
// NOTE(review): `std::thread::sleep` blocks the async runtime's thread; consider
// `tokio::time::sleep(..).await`.
async fn test_cache_expiration_remote_down() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
.up_to_n_times(1) // Only respond the first time
let url = mock_server.uri();
// Get token configuration from remote location
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
assert!(res.is_ok(), "Failed to validate token the first time: {:?}", res.err());
// Wait for cache to expire
std::thread::sleep((*CACHE_EXPIRATION + Duration::seconds(1)).to_std().unwrap());
// Get same token configuration again after cache has expired
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token with an expired cache and remote down"
// A key without an `alg` parameter must be rejected by `config`.
async fn test_unsupported_algorithm() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let mut jwks = DEFAULT_JWKS.clone();
jwks.keys[0].common.algorithm = None;
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token with key using unsupported algorithm"
// The `use` parameter is optional: a key that omits it must still be accepted.
async fn test_no_key_use() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let mut jwks = DEFAULT_JWKS.clone();
jwks.keys[0].common.public_key_use = None;
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Failed to validate token with key that does not specify use: {:?}",
// A key whose `use` parameter is encryption must be rejected for signature verification.
async fn test_key_use_enc() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let mut jwks = DEFAULT_JWKS.clone();
jwks.keys[0].common.public_key_use = Some(jsonwebtoken::jwk::PublicKeyUse::Encryption);
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token with key that only supports encryption"
// A key whose `key_ops` parameter does not include `verify` must be rejected.
async fn test_key_ops_encrypt_only() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let mut jwks = DEFAULT_JWKS.clone();
jwks.keys[0].common.key_operations = Some(vec![jsonwebtoken::jwk::KeyOperations::Encrypt]);
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200).set_body_json(jwks);
let url = mock_server.uri();
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token with key that only supports encryption"
// A remote location answering with HTTP 500 must make `config` fail.
async fn test_remote_down() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(500);
let url = mock_server.uri();
// Get token configuration from remote location responding with Internal Server Error
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token configuration with unavailable remote location"
2024-02-02 22:10:47 +00:00
// The mock server delays its response by REMOTE_TIMEOUT plus 10 seconds. `config` must fail,
// and the elapsed time must stay under REMOTE_TIMEOUT plus 1 second, showing the request was
// aborted at the timeout. Only built on non-WebAssembly targets, where REMOTE_TIMEOUT exists.
#[cfg(not(target_arch = "wasm32"))]
2024-01-09 17:17:48 +00:00
async fn test_remote_timeout() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
let jwks = DEFAULT_JWKS.clone();
let jwks_path = format!("{}/jwks.json", random_path());
let mock_server = MockServer::start().await;
let response = ResponseTemplate::new(200)
.set_delay((*REMOTE_TIMEOUT + Duration::seconds(10)).to_std().unwrap());
let url = mock_server.uri();
let start_time = Utc::now();
// Get token configuration from remote location responding very slowly
let res = config(&ds, "test_1", &format!("{}/{}", &url, &jwks_path)).await;
"Unexpected success validating token configuration with unavailable remote location"
Utc::now() - start_time < *REMOTE_TIMEOUT + Duration::seconds(1),
"Remote request was not aborted immediately after timeout"