[metrics] HTTP Layer + move to Axum (#2237)
This commit is contained in:
parent
eef9b755cb
commit
53702c247a
60 changed files with 3873 additions and 1430 deletions
24
.github/workflows/ci.yml
vendored
24
.github/workflows/ci.yml
vendored
|
@ -118,7 +118,29 @@ jobs:
|
|||
sudo apt-get -y install protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Run cargo test
|
||||
run: cargo test --locked --no-default-features --features storage-mem --workspace --test cli
|
||||
run: cargo test --locked --no-default-features --features storage-mem --workspace --test cli_integration
|
||||
|
||||
http-server:
|
||||
name: Test HTTP server
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
||||
- name: Install stable toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup cache
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get -y update
|
||||
sudo apt-get -y install protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Run cargo test
|
||||
run: cargo test --locked --no-default-features --features storage-mem --workspace --test http_integration
|
||||
|
||||
test:
|
||||
name: Test workspace
|
||||
|
|
301
Cargo.lock
generated
301
Cargo.lock
generated
|
@ -405,20 +405,6 @@ dependencies = [
|
|||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-compression"
|
||||
version = "0.3.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a"
|
||||
dependencies = [
|
||||
"brotli",
|
||||
"flate2",
|
||||
"futures-core",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-executor"
|
||||
version = "1.5.1"
|
||||
|
@ -543,9 +529,11 @@ checksum = "a6a1de45611fdb535bfde7b7de4fd54f4fd2b17b1737c0a59b69bf9b92074b8c"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64 0.21.2",
|
||||
"bitflags 1.3.2",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
|
@ -560,11 +548,25 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.19.0",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-client-ip"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df8e81eacc93f36480825da5f46a33b5fb2246ed024eacc9e8933425b80c5807"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"forwarded-header-value",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -582,6 +584,7 @@ dependencies = [
|
|||
"rustversion",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -595,6 +598,64 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-extra"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "febf23ab04509bd7672e6abe76bd8277af31b679e89fa5ffc6087dc289a448a3"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_html_form",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2bb524613be645939e280b7279f7b017f98cf7f5ef084ec374df373530e73277"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.26",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-server"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"rustls 0.21.5",
|
||||
"rustls-pemfile",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.68"
|
||||
|
@ -616,6 +677,12 @@ version = "0.13.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.2"
|
||||
|
@ -1349,6 +1416,12 @@ dependencies = [
|
|||
"parking_lot_core 0.9.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308"
|
||||
|
||||
[[package]]
|
||||
name = "debugid"
|
||||
version = "0.8.0"
|
||||
|
@ -1654,6 +1727,16 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "forwarded-header-value"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
|
||||
dependencies = [
|
||||
"nonempty",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foundationdb"
|
||||
version = "0.8.0"
|
||||
|
@ -2170,6 +2253,12 @@ dependencies = [
|
|||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-range-header"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.8.0"
|
||||
|
@ -2741,24 +2830,6 @@ dependencies = [
|
|||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multer"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-util",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"memchr",
|
||||
"mime",
|
||||
"spin 0.9.8",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.8.3"
|
||||
|
@ -2844,6 +2915,12 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonempty"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
|
@ -2982,9 +3059,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "opentelemetry"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e"
|
||||
checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f"
|
||||
dependencies = [
|
||||
"opentelemetry_api",
|
||||
"opentelemetry_sdk",
|
||||
|
@ -2992,9 +3069,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "opentelemetry-otlp"
|
||||
version = "0.11.0"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde"
|
||||
checksum = "8af72d59a4484654ea8eb183fea5ae4eb6a41d7ac3e3bae5f4d2a282a3a7d3ca"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"futures 0.3.28",
|
||||
|
@ -3010,39 +3087,38 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "opentelemetry-proto"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28"
|
||||
checksum = "045f8eea8c0fa19f7d48e7bc3128a39c2e5c533d5c61298c548dfefc1064474c"
|
||||
dependencies = [
|
||||
"futures 0.3.28",
|
||||
"futures-util",
|
||||
"opentelemetry",
|
||||
"prost 0.11.9",
|
||||
"tonic",
|
||||
"tonic-build",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry_api"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22"
|
||||
checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"indexmap 1.9.3",
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry_sdk"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113"
|
||||
checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"crossbeam-channel",
|
||||
|
@ -4241,12 +4317,6 @@ dependencies = [
|
|||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scoped-tls"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
|
@ -4358,6 +4428,19 @@ dependencies = [
|
|||
"syn 2.0.26",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_html_form"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7d9b8af723e4199801a643c622daa7aae8cf4a772dc2b3efcb3a95add6cb91a"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"indexmap 2.0.0",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.103"
|
||||
|
@ -4607,12 +4690,18 @@ version = "1.0.0-beta.9"
|
|||
dependencies = [
|
||||
"argon2",
|
||||
"assert_fs",
|
||||
"axum",
|
||||
"axum-client-ip",
|
||||
"axum-extra",
|
||||
"axum-server",
|
||||
"base64 0.21.2",
|
||||
"bytes",
|
||||
"clap 4.3.12",
|
||||
"futures 0.3.28",
|
||||
"futures-util",
|
||||
"glob",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"ipnet",
|
||||
"nix",
|
||||
|
@ -4620,6 +4709,7 @@ dependencies = [
|
|||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-proto",
|
||||
"pin-project-lite",
|
||||
"rand 0.8.5",
|
||||
"rcgen",
|
||||
"reqwest",
|
||||
|
@ -4637,12 +4727,13 @@ dependencies = [
|
|||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tonic",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4715,7 +4806,7 @@ dependencies = [
|
|||
"thiserror",
|
||||
"time 0.3.23",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-tungstenite 0.18.0",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
@ -5141,11 +5232,23 @@ dependencies = [
|
|||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tungstenite",
|
||||
"tungstenite 0.18.0",
|
||||
"webpki",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.19.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.8"
|
||||
|
@ -5219,19 +5322,6 @@ dependencies = [
|
|||
"tracing-futures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic-build"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4"
|
||||
dependencies = [
|
||||
"prettyplease 0.1.25",
|
||||
"proc-macro2",
|
||||
"prost-build 0.11.9",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.4.13"
|
||||
|
@ -5252,6 +5342,29 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c"
|
||||
dependencies = [
|
||||
"base64 0.20.0",
|
||||
"bitflags 2.3.3",
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-range-header",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.2"
|
||||
|
@ -5321,9 +5434,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tracing-opentelemetry"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de"
|
||||
checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
|
@ -5396,6 +5509,25 @@ dependencies = [
|
|||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
|
@ -5544,39 +5676,6 @@ dependencies = [
|
|||
"try-lock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "warp"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
"hyper",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"multer",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
"rustls-pemfile",
|
||||
"scoped-tls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
|
|
18
Cargo.toml
18
Cargo.toml
|
@ -34,17 +34,24 @@ strip = false
|
|||
|
||||
[dependencies]
|
||||
argon2 = "0.5.1"
|
||||
axum = { version = "0.6.18", features = ["tracing", "ws", "headers"] }
|
||||
axum-client-ip = "0.4.1"
|
||||
axum-extra = { version = "0.7.4", features = ["typed-routing"] }
|
||||
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
||||
base64 = "0.21.2"
|
||||
bytes = "1.4.0"
|
||||
clap = { version = "4.3.12", features = ["env", "derive", "wrap_help", "unicode"] }
|
||||
futures = "0.3.28"
|
||||
futures-util = "0.3.28"
|
||||
glob = "0.3.1"
|
||||
http = "0.2.9"
|
||||
http-body = "0.4.5"
|
||||
hyper = "0.14.27"
|
||||
ipnet = "2.8.0"
|
||||
once_cell = "1.18.0"
|
||||
opentelemetry = { version = "0.18", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
opentelemetry = { version = "0.19", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
|
||||
pin-project-lite = "0.2.9"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.18", features = ["blocking"] }
|
||||
rustyline = { version = "11.0.0", features = ["derive"] }
|
||||
|
@ -57,19 +64,20 @@ tempfile = "3.6.0"
|
|||
thiserror = "1.0.43"
|
||||
tokio = { version = "1.29.1", features = ["macros", "signal"] }
|
||||
tokio-util = { version = "0.7.8", features = ["io"] }
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.4.1", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||
urlencoding = "2.1.2"
|
||||
uuid = { version = "1.4.0", features = ["serde", "js", "v4", "v7"] }
|
||||
warp = { version = "0.3.5", features = ["compression", "tls", "websocket"] }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
nix = "0.26.2"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_fs = "1.0.13"
|
||||
opentelemetry-proto = { version = "0.1.0", features = ["gen-tonic", "traces", "build-server"] }
|
||||
opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
|
||||
rcgen = "0.10.0"
|
||||
serial_test = "2.0.0"
|
||||
temp-env = "0.3.4"
|
||||
|
|
69
dev/docker/compose.yaml
Normal file
69
dev/docker/compose.yaml
Normal file
|
@ -0,0 +1,69 @@
|
|||
---
|
||||
version: "3"
|
||||
services:
|
||||
grafana:
|
||||
image: "grafana/grafana-oss:latest"
|
||||
expose:
|
||||
- "3000"
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- "grafana:/var/lib/grafana"
|
||||
- "./grafana.ini:/etc/grafana/grafana.ini"
|
||||
- "./grafana-datasource.yaml:/etc/grafana/provisioning/datasources/grafana-datasource.yaml"
|
||||
- "./grafana-dashboards.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboards.yaml"
|
||||
- "./dashboards:/dashboards"
|
||||
healthcheck:
|
||||
test:
|
||||
- CMD-SHELL
|
||||
- bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/3001; exit $$?;'
|
||||
interval: 1s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
prometheus:
|
||||
image: "prom/prometheus:latest"
|
||||
command:
|
||||
- "--config.file=/etc/prometheus/prometheus.yaml"
|
||||
- "--storage.tsdb.path=/prometheus"
|
||||
- "--web.console.libraries=/usr/share/prometheus/console_libraries"
|
||||
- "--web.console.templates=/usr/share/prometheus/consoles"
|
||||
- "--web.listen-address=0.0.0.0:9090"
|
||||
- "--web.enable-remote-write-receiver"
|
||||
- "--storage.tsdb.retention.time=1d"
|
||||
expose:
|
||||
- "9090"
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- "prometheus:/prometheus"
|
||||
- "./prometheus.yaml:/etc/prometheus/prometheus.yaml"
|
||||
|
||||
tempo:
|
||||
image: grafana/tempo:latest
|
||||
command: [ "-config.file=/etc/tempo.yaml" ]
|
||||
volumes:
|
||||
- ./tempo.yaml:/etc/tempo.yaml
|
||||
- tempo:/tmp/tempo
|
||||
ports:
|
||||
- "3200" # tempo
|
||||
- "4317" # otlp grpc
|
||||
|
||||
|
||||
otel-collector:
|
||||
image: "otel/opentelemetry-collector-contrib"
|
||||
command:
|
||||
- "--config=/etc/otel-collector.yaml"
|
||||
expose:
|
||||
- "4317"
|
||||
ports:
|
||||
- "4317:4317" # otlp grpc
|
||||
- "9090" # for prometheus
|
||||
volumes: ["./otel-collector.yaml:/etc/otel-collector.yaml"]
|
||||
|
||||
volumes:
|
||||
grafana:
|
||||
external: false
|
||||
prometheus:
|
||||
external: false
|
||||
tempo:
|
||||
external: false
|
15
dev/docker/grafana-dashboards.yaml
Normal file
15
dev/docker/grafana-dashboards.yaml
Normal file
|
@ -0,0 +1,15 @@
|
|||
apiVersion: 1
|
||||
|
||||
providers:
|
||||
- name: 'surrealdb-grafana'
|
||||
orgId: 1
|
||||
folder: ''
|
||||
folderUid: ''
|
||||
type: file
|
||||
disableDeletion: false
|
||||
updateIntervalSeconds: 1
|
||||
allowUiUpdates: true
|
||||
options:
|
||||
path: /dashboards
|
||||
foldersFromFilesStructure: false
|
||||
|
30
dev/docker/grafana-datasource.yaml
Normal file
30
dev/docker/grafana-datasource.yaml
Normal file
|
@ -0,0 +1,30 @@
|
|||
apiVersion: 1
|
||||
deleteDatasources:
|
||||
- name: Prometheus
|
||||
- name: Tempo
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
access: proxy
|
||||
url: http://prometheus:9090
|
||||
withCredentials: false
|
||||
isDefault: true
|
||||
tlsAuth: false
|
||||
tlsAuthWithCACert: false
|
||||
version: 1
|
||||
editable: true
|
||||
- name: Tempo
|
||||
type: tempo
|
||||
access: proxy
|
||||
orgId: 1
|
||||
url: http://tempo:3200
|
||||
basicAuth: false
|
||||
isDefault: false
|
||||
version: 1
|
||||
editable: false
|
||||
apiVersion: 1
|
||||
uid: tempo
|
||||
jsonData:
|
||||
httpMethod: GET
|
||||
serviceMap:
|
||||
datasourceUid: prometheus
|
6
dev/docker/grafana.ini
Normal file
6
dev/docker/grafana.ini
Normal file
|
@ -0,0 +1,6 @@
|
|||
[server]
|
||||
http_addr = 0.0.0.0
|
||||
http_port = 3000
|
||||
|
||||
[users]
|
||||
default_theme = light
|
36
dev/docker/otel-collector.yaml
Normal file
36
dev/docker/otel-collector.yaml
Normal file
|
@ -0,0 +1,36 @@
|
|||
receivers:
|
||||
otlp:
|
||||
protocols:
|
||||
grpc:
|
||||
|
||||
exporters:
|
||||
otlp:
|
||||
endpoint: 'tempo:4317'
|
||||
tls:
|
||||
insecure: true
|
||||
|
||||
prometheus:
|
||||
endpoint: ':9090'
|
||||
send_timestamps: true
|
||||
metric_expiration: 60m
|
||||
resource_to_telemetry_conversion:
|
||||
enabled: true
|
||||
|
||||
logging: # add to a pipeline for debugging
|
||||
loglevel: debug
|
||||
|
||||
# processors:
|
||||
# batch:
|
||||
# timeout: 1s
|
||||
# span:
|
||||
# name:
|
||||
# from_attributes: ["name"]
|
||||
|
||||
service:
|
||||
pipelines:
|
||||
traces:
|
||||
receivers: [otlp]
|
||||
exporters: [otlp, logging]
|
||||
metrics:
|
||||
receivers: [otlp]
|
||||
exporters: [prometheus]
|
17
dev/docker/prometheus.yaml
Normal file
17
dev/docker/prometheus.yaml
Normal file
|
@ -0,0 +1,17 @@
|
|||
global:
|
||||
scrape_interval: 5s
|
||||
evaluation_interval: 10s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: prometheus
|
||||
static_configs:
|
||||
- targets: ["prometheus:9500"]
|
||||
|
||||
- job_name: 'tempo'
|
||||
static_configs:
|
||||
- targets: ["tempo:3200"]
|
||||
|
||||
- job_name: otel-collector
|
||||
static_configs:
|
||||
# Scrap the SurrealDB metrics sent to OpenTelemetry collector
|
||||
- targets: ["otel-collector:9090"]
|
23
dev/docker/tempo.yaml
Normal file
23
dev/docker/tempo.yaml
Normal file
|
@ -0,0 +1,23 @@
|
|||
server:
|
||||
http_listen_port: 3200
|
||||
|
||||
distributor:
|
||||
receivers:
|
||||
otlp:
|
||||
protocols:
|
||||
grpc:
|
||||
|
||||
ingester:
|
||||
max_block_duration: 5m # cut the headblock when this much time passes. this is being set for dev purposes and should probably be left alone normally
|
||||
|
||||
compactor:
|
||||
compaction:
|
||||
block_retention: 1h # overall Tempo trace retention. set for dev purposes
|
||||
|
||||
storage:
|
||||
trace:
|
||||
backend: local
|
||||
wal:
|
||||
path: /tmp/tempo/wal
|
||||
local:
|
||||
path: /tmp/tempo/blocks
|
16
doc/TELEMETRY.md
Normal file
16
doc/TELEMETRY.md
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Telemetry
|
||||
|
||||
SurrealDB leverages the tracing and opentelemetry libraries to instrument the code.
|
||||
|
||||
Both metrics and traces are pushed to an OTEL compatible receiver.
|
||||
|
||||
For local development, you can start the observability stack defined in `dev/docker`. It spins up an instance of Opentelemetry collector, Grafana, Prometheus and Tempo:
|
||||
|
||||
```
|
||||
$ docker-compose -f dev/docker/compose.yaml up -d
|
||||
$ SURREAL_TRACING_TRACER=otlp OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317" surreal start
|
||||
```
|
||||
|
||||
Now you can use the SurrealDB server and see the telemetry data opening this URL in the browser: http://localhost:3000
|
||||
|
||||
To login into Grafana, use the default user `admin` and password `admin`.
|
|
@ -37,7 +37,7 @@ pub async fn init(
|
|||
}: BackupCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
|
||||
// Process the source->destination response
|
||||
let into_local = into.ends_with(".db");
|
||||
|
|
|
@ -38,7 +38,7 @@ pub async fn init(
|
|||
}: ExportCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
|
||||
let root = Root {
|
||||
username: &username,
|
||||
|
|
|
@ -36,7 +36,7 @@ pub async fn init(
|
|||
}: ImportCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
|
||||
let root = Root {
|
||||
username: &username,
|
||||
|
|
|
@ -17,7 +17,7 @@ pub async fn init(
|
|||
}: IsReadyCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
// Connect to the database engine
|
||||
connect(endpoint).await?;
|
||||
println!("OK");
|
||||
|
|
|
@ -49,7 +49,7 @@ pub async fn init(
|
|||
}: SqlCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("warn").init();
|
||||
crate::telemetry::builder().with_log_level("warn").init();
|
||||
|
||||
let root = Root {
|
||||
username: &username,
|
||||
|
|
|
@ -101,7 +101,7 @@ pub async fn init(
|
|||
}: StartCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_filter(log).init();
|
||||
crate::telemetry::builder().with_filter(log).init();
|
||||
|
||||
// Check if a banner should be outputted
|
||||
if !no_banner {
|
||||
|
|
|
@ -47,7 +47,7 @@ impl UpgradeCommandArguments {
|
|||
|
||||
pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
|
||||
// Upgrading overwrites the existing executable
|
||||
let exe = std::env::current_exe()?;
|
||||
|
|
|
@ -31,6 +31,10 @@ impl TypedValueParser for CustomEnvFilterParser {
|
|||
arg: Option<&clap::Arg>,
|
||||
value: &std::ffi::OsStr,
|
||||
) -> Result<Self::Value, clap::Error> {
|
||||
if let Ok(dirs) = std::env::var("RUST_LOG") {
|
||||
return Ok(CustomEnvFilter(EnvFilter::builder().parse_lossy(dirs)));
|
||||
}
|
||||
|
||||
let inner = NonEmptyStringValueParser::new();
|
||||
let v = inner.parse_ref(cmd, arg, value)?;
|
||||
let filter = (match v.as_str() {
|
||||
|
|
|
@ -18,7 +18,7 @@ pub async fn init(
|
|||
}: VersionCommandArguments,
|
||||
) -> Result<(), Error> {
|
||||
// Initialize opentelemetry and logging
|
||||
crate::o11y::builder().with_log_level("error").init();
|
||||
crate::telemetry::builder().with_log_level("error").init();
|
||||
// Print server version if endpoint supplied else CLI version
|
||||
if let Some(e) = endpoint {
|
||||
// Print remote server version
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use base64::DecodeError as Base64Error;
|
||||
use http::StatusCode;
|
||||
use reqwest::Error as ReqwestError;
|
||||
use serde::Serialize;
|
||||
use serde_cbor::error::Error as CborError;
|
||||
|
@ -6,6 +9,7 @@ use serde_json::error::Error as JsonError;
|
|||
use serde_pack::encode::Error as PackError;
|
||||
use std::io::Error as IoError;
|
||||
use std::string::FromUtf8Error as Utf8Error;
|
||||
use surrealdb::error::Db as SurrealDbError;
|
||||
use surrealdb::Error as SurrealError;
|
||||
use thiserror::Error;
|
||||
|
||||
|
@ -51,8 +55,6 @@ pub enum Error {
|
|||
Remote(#[from] ReqwestError),
|
||||
}
|
||||
|
||||
impl warp::reject::Reject for Error {}
|
||||
|
||||
impl From<Error> for String {
|
||||
fn from(e: Error) -> String {
|
||||
e.to_string()
|
||||
|
@ -85,3 +87,56 @@ impl Serialize for Error {
|
|||
serializer.serialize_str(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(super) struct Message {
|
||||
code: u16,
|
||||
details: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
information: Option<String>,
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
err @ Error::InvalidAuth | err @ Error::Db(SurrealError::Db(SurrealDbError::InvalidAuth)) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(Message {
|
||||
code: 401,
|
||||
details: Some("Authentication failed".to_string()),
|
||||
description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()),
|
||||
information: Some(err.to_string()),
|
||||
})
|
||||
),
|
||||
Error::InvalidType => (
|
||||
StatusCode::UNSUPPORTED_MEDIA_TYPE,
|
||||
Json(Message {
|
||||
code: 415,
|
||||
details: Some("Unsupported media type".to_string()),
|
||||
description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
),
|
||||
Error::InvalidStorage => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(Message {
|
||||
code: 500,
|
||||
details: Some("Health check failed".to_string()),
|
||||
description: Some("The database health check for this instance failed. There was an issue with the underlying storage engine.".to_string()),
|
||||
information: Some(self.to_string()),
|
||||
}),
|
||||
),
|
||||
_ => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(Message {
|
||||
code: 400,
|
||||
details: Some("Request problems detected".to_string()),
|
||||
description: Some("There is a problem with your request. Refer to the documentation for further information.".to_string()),
|
||||
information: Some(self.to_string()),
|
||||
}),
|
||||
),
|
||||
}.into_response()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@ pub mod verify;
|
|||
use crate::cli::CF;
|
||||
use crate::err::Error;
|
||||
|
||||
pub const BASIC: &str = "Basic ";
|
||||
|
||||
pub async fn init() -> Result<(), Error> {
|
||||
// Get local copy of options
|
||||
let opt = CF.get().unwrap();
|
||||
|
|
|
@ -1,76 +1,63 @@
|
|||
use crate::cli::CF;
|
||||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::iam::BASIC;
|
||||
use argon2::password_hash::{PasswordHash, PasswordVerifier};
|
||||
use argon2::Argon2;
|
||||
use std::sync::Arc;
|
||||
use surrealdb::dbs::Auth;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::iam::base::{Engine, BASE64};
|
||||
|
||||
pub async fn basic(session: &mut Session, auth: String) -> Result<(), Error> {
|
||||
// Log the authentication type
|
||||
trace!("Attempting basic authentication");
|
||||
// Retrieve just the auth data
|
||||
let auth = auth.trim_start_matches(BASIC).trim();
|
||||
pub async fn basic(session: &mut Session, user: &str, pass: &str) -> Result<(), Error> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
// Get the config options
|
||||
let opts = CF.get().unwrap();
|
||||
// Decode the encoded auth data
|
||||
let auth = BASE64.decode(auth)?;
|
||||
// Convert the auth data to String
|
||||
let auth = String::from_utf8(auth)?;
|
||||
// Split the auth data into user and pass
|
||||
if let Some((user, pass)) = auth.split_once(':') {
|
||||
// Check that the details are not empty
|
||||
if user.is_empty() || pass.is_empty() {
|
||||
return Err(Error::InvalidAuth);
|
||||
|
||||
// Check that the details are not empty
|
||||
if user.is_empty() || pass.is_empty() {
|
||||
return Err(Error::InvalidAuth);
|
||||
}
|
||||
// Check if this is root authentication
|
||||
if let Some(root) = &opts.pass {
|
||||
if user == opts.user && pass == root {
|
||||
// Log the authentication type
|
||||
debug!("Authenticated as super user");
|
||||
// Store the authentication data
|
||||
session.au = Arc::new(Auth::Kv);
|
||||
return Ok(());
|
||||
}
|
||||
// Check if this is root authentication
|
||||
if let Some(root) = &opts.pass {
|
||||
if user == opts.user && pass == root {
|
||||
// Log the authentication type
|
||||
debug!("Authenticated as super user");
|
||||
}
|
||||
// Check if this is NS authentication
|
||||
if let Some(ns) = &session.ns {
|
||||
// Create a new readonly transaction
|
||||
let mut tx = kvs.transaction(false, false).await?;
|
||||
// Check if the supplied NS Login exists
|
||||
if let Ok(nl) = tx.get_nl(ns, user).await {
|
||||
// Compute the hash and verify the password
|
||||
let hash = PasswordHash::new(&nl.hash).unwrap();
|
||||
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
|
||||
// Log the successful namespace authentication
|
||||
debug!("Authenticated as namespace user: {}", user);
|
||||
// Store the authentication data
|
||||
session.au = Arc::new(Auth::Kv);
|
||||
session.au = Arc::new(Auth::Ns(ns.to_owned()));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
// Check if this is NS authentication
|
||||
if let Some(ns) = &session.ns {
|
||||
// Create a new readonly transaction
|
||||
let mut tx = kvs.transaction(false, false).await?;
|
||||
// Check if the supplied NS Login exists
|
||||
if let Ok(nl) = tx.get_nl(ns, user).await {
|
||||
};
|
||||
// Check if this is DB authentication
|
||||
if let Some(db) = &session.db {
|
||||
// Check if the supplied DB Login exists
|
||||
if let Ok(dl) = tx.get_dl(ns, db, user).await {
|
||||
// Compute the hash and verify the password
|
||||
let hash = PasswordHash::new(&nl.hash).unwrap();
|
||||
let hash = PasswordHash::new(&dl.hash).unwrap();
|
||||
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
|
||||
// Log the successful namespace authentication
|
||||
debug!("Authenticated as namespace user: {}", user);
|
||||
// Store the authentication data
|
||||
session.au = Arc::new(Auth::Ns(ns.to_owned()));
|
||||
session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned()));
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
// Check if this is DB authentication
|
||||
if let Some(db) = &session.db {
|
||||
// Check if the supplied DB Login exists
|
||||
if let Ok(dl) = tx.get_dl(ns, db, user).await {
|
||||
// Compute the hash and verify the password
|
||||
let hash = PasswordHash::new(&dl.hash).unwrap();
|
||||
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
|
||||
// Log the successful namespace authentication
|
||||
debug!("Authenticated as database user: {}", user);
|
||||
// Store the authentication data
|
||||
session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned()));
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
// There was an auth error
|
||||
Err(Error::InvalidAuth)
|
||||
}
|
||||
|
|
|
@ -27,9 +27,9 @@ mod err;
|
|||
mod iam;
|
||||
#[cfg(feature = "has-storage")]
|
||||
mod net;
|
||||
mod o11y;
|
||||
#[cfg(feature = "has-storage")]
|
||||
mod rpc;
|
||||
mod telemetry;
|
||||
|
||||
use std::future::Future;
|
||||
use std::process::ExitCode;
|
||||
|
|
106
src/net/auth.rs
Normal file
106
src/net/auth.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use axum::{
|
||||
body::{boxed, Body, BoxBody},
|
||||
headers::{
|
||||
authorization::{Basic, Bearer},
|
||||
Authorization, Origin,
|
||||
},
|
||||
Extension, RequestPartsExt, TypedHeader,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use http::{request::Parts, StatusCode};
|
||||
use hyper::{Request, Response};
|
||||
use surrealdb::{dbs::Session, iam::verify::token};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::{dbs::DB, err::Error, iam::verify::basic};
|
||||
|
||||
use super::{client_ip::ExtractClientIP, AppState};
|
||||
|
||||
///
|
||||
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
|
||||
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
|
||||
///
|
||||
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
|
||||
///
|
||||
/// ```rust
|
||||
/// use tower_http::auth::RequireAuthorizationLayer;
|
||||
/// use surrealdb::net::SurrealAuth;
|
||||
/// use axum::Router;
|
||||
///
|
||||
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/version", get(|| async { "0.1.0" }))
|
||||
/// .layer(auth);
|
||||
/// ```
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct SurrealAuth;
|
||||
|
||||
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = BoxBody;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
|
||||
fn authorize(&mut self, request: Request<B>) -> Self::Future {
|
||||
Box::pin(async {
|
||||
let (mut parts, body) = request.into_parts();
|
||||
match check_auth(&mut parts).await {
|
||||
Ok(sess) => {
|
||||
parts.extensions.insert(sess);
|
||||
Ok(Request::from_parts(parts, body))
|
||||
}
|
||||
Err(err) => {
|
||||
let unauthorized_response = Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(boxed(Body::from(err.to_string())))
|
||||
.unwrap();
|
||||
Err(unauthorized_response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
|
||||
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
|
||||
if !or.is_null() {
|
||||
Some(or.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
|
||||
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
|
||||
tracing::error!("Error extracting the app state: {:?}", err);
|
||||
Error::InvalidAuth
|
||||
})?;
|
||||
let ExtractClientIP(ip) =
|
||||
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
|
||||
|
||||
// Create session
|
||||
#[rustfmt::skip]
|
||||
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
|
||||
|
||||
// If Basic authentication data was supplied
|
||||
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
|
||||
basic(&mut session, au.username(), au.password()).await?;
|
||||
};
|
||||
|
||||
// If Token authentication data was supplied
|
||||
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
|
||||
token(kvs, &mut session, au.token().into()).await?;
|
||||
};
|
||||
|
||||
Ok(session)
|
||||
}
|
|
@ -1,10 +1,21 @@
|
|||
use crate::cli::CF;
|
||||
use axum::async_trait;
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::extract::FromRef;
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use axum::Extension;
|
||||
use axum::RequestPartsExt;
|
||||
use clap::ValueEnum;
|
||||
use std::net::IpAddr;
|
||||
use http::request::Parts;
|
||||
use http::Request;
|
||||
use http::StatusCode;
|
||||
use std::net::SocketAddr;
|
||||
use warp::Filter;
|
||||
|
||||
use super::AppState;
|
||||
|
||||
// TODO: Support Forwarded, X-Forwarded-For headers.
|
||||
// Get inspiration from https://github.com/imbolc/axum-client-ip or simply use it
|
||||
#[derive(ValueEnum, Clone, Copy, Debug)]
|
||||
pub enum ClientIp {
|
||||
/// Don't use client IP
|
||||
|
@ -25,31 +36,95 @@ pub enum ClientIp {
|
|||
XRealIp,
|
||||
}
|
||||
|
||||
/// Creates an string represenation of the client's IP address
|
||||
pub fn build() -> impl Filter<Extract = (Option<String>,), Error = warp::Rejection> + Clone {
|
||||
// Get configured client IP source
|
||||
let client_ip = CF.get().unwrap().client_ip;
|
||||
// Enable on any path
|
||||
let conf = warp::any();
|
||||
// Add raw remote IP address
|
||||
let conf =
|
||||
conf.and(warp::filters::addr::remote().and_then(move |s: Option<SocketAddr>| async move {
|
||||
match client_ip {
|
||||
ClientIp::None => Ok(None),
|
||||
ClientIp::Socket => Ok(s.map(|s| s.ip())),
|
||||
// Move on to parsing selected IP header.
|
||||
_ => Err(warp::reject::reject()),
|
||||
}
|
||||
}));
|
||||
// Add selected IP header
|
||||
let conf = conf.or(warp::header::optional::<IpAddr>(match client_ip {
|
||||
ClientIp::CfConectingIp => "Cf-Connecting-IP",
|
||||
ClientIp::FlyClientIp => "Fly-Client-IP",
|
||||
ClientIp::TrueClientIP => "True-Client-IP",
|
||||
ClientIp::XRealIp => "X-Real-IP",
|
||||
// none and socket are already handled so this will never be used
|
||||
_ => "unreachable",
|
||||
}));
|
||||
// Join the two filters
|
||||
conf.unify().map(|ip: Option<IpAddr>| ip.map(|ip| ip.to_string()))
|
||||
impl std::fmt::Display for ClientIp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ClientIp::None => write!(f, "None"),
|
||||
ClientIp::Socket => write!(f, "Socket"),
|
||||
ClientIp::CfConectingIp => write!(f, "CF-Connecting-IP"),
|
||||
ClientIp::FlyClientIp => write!(f, "Fly-Client-IP"),
|
||||
ClientIp::TrueClientIP => write!(f, "True-Client-IP"),
|
||||
ClientIp::XRealIp => write!(f, "X-Real-IP"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientIp {
|
||||
fn is_header(&self) -> bool {
|
||||
match self {
|
||||
ClientIp::None => false,
|
||||
ClientIp::Socket => false,
|
||||
ClientIp::CfConectingIp => true,
|
||||
ClientIp::FlyClientIp => true,
|
||||
ClientIp::TrueClientIP => true,
|
||||
ClientIp::XRealIp => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct ExtractClientIP(pub Option<String>);
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for ExtractClientIP
|
||||
where
|
||||
AppState: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = (StatusCode, &'static str);
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let app_state = AppState::from_ref(state);
|
||||
|
||||
let res = match app_state.client_ip {
|
||||
ClientIp::None => ExtractClientIP(None),
|
||||
ClientIp::Socket => {
|
||||
if let Ok(ConnectInfo(addr)) =
|
||||
ConnectInfo::<SocketAddr>::from_request_parts(parts, state).await
|
||||
{
|
||||
ExtractClientIP(Some(addr.ip().to_string()))
|
||||
} else {
|
||||
ExtractClientIP(None)
|
||||
}
|
||||
}
|
||||
// Get the IP from the corresponding header
|
||||
var if var.is_header() => {
|
||||
if let Some(ip) = parts.headers.get(var.to_string()) {
|
||||
ip.to_str().map(|s| ExtractClientIP(Some(s.to_string()))).unwrap_or_else(
|
||||
|err| {
|
||||
debug!("Invalid header value for {}: {}", var, err);
|
||||
ExtractClientIP(None)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
ExtractClientIP(None)
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected ClientIp variant: {:?}", app_state.client_ip);
|
||||
ExtractClientIP(None)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn client_ip_middleware<B>(
|
||||
request: Request<B>,
|
||||
next: Next<B>,
|
||||
) -> Result<Response, StatusCode>
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
let (mut parts, body) = request.into_parts();
|
||||
|
||||
if let Ok(Extension(state)) = parts.extract::<Extension<AppState>>().await {
|
||||
if let Ok(client_ip) = parts.extract_with_state::<ExtractClientIP, AppState>(&state).await {
|
||||
parts.extensions.insert(client_ip);
|
||||
}
|
||||
} else {
|
||||
trace!("No AppState found, skipping client_ip_middleware");
|
||||
}
|
||||
|
||||
Ok(next.run(Request::from_parts(parts, body)).await)
|
||||
}
|
||||
|
|
|
@ -1,21 +1,25 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::session;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use axum::{response::Response, Extension};
|
||||
use bytes::Bytes;
|
||||
use http::StatusCode;
|
||||
use http_body::Body as HttpBody;
|
||||
use hyper::body::Body;
|
||||
use surrealdb::dbs::Session;
|
||||
use warp::Filter;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("export")
|
||||
.and(warp::path::end())
|
||||
.and(warp::get())
|
||||
.and(session::build())
|
||||
.and_then(handler)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new().route("/export", get(handler))
|
||||
}
|
||||
|
||||
async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
async fn handler(
|
||||
Extension(session): Extension<Session>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Check the permissions
|
||||
match session.au.is_db() {
|
||||
true => {
|
||||
|
@ -24,12 +28,12 @@ async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection>
|
|||
// Extract the NS header value
|
||||
let nsv = match session.ns {
|
||||
Some(ns) => ns,
|
||||
None => return Err(warp::reject::custom(Error::NoNsHeader)),
|
||||
None => return Err((StatusCode::BAD_REQUEST, "No namespace provided")),
|
||||
};
|
||||
// Extract the DB header value
|
||||
let dbv = match session.db {
|
||||
Some(db) => db,
|
||||
None => return Err(warp::reject::custom(Error::NoDbHeader)),
|
||||
None => return Err((StatusCode::BAD_REQUEST, "No database provided")),
|
||||
};
|
||||
// Create a chunked response
|
||||
let (mut chn, bdy) = Body::channel();
|
||||
|
@ -44,9 +48,9 @@ async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection>
|
|||
}
|
||||
});
|
||||
// Return the chunked body
|
||||
Ok(warp::reply::Response::new(bdy))
|
||||
Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap())
|
||||
}
|
||||
// There was an error with permissions
|
||||
_ => Err(warp::reject::custom(Error::InvalidAuth)),
|
||||
// The user does not have the correct permissions
|
||||
_ => Err((StatusCode::FORBIDDEN, "Invalid permissions")),
|
||||
}
|
||||
}
|
||||
|
|
126
src/net/fail.rs
126
src/net/fail.rs
|
@ -1,126 +0,0 @@
|
|||
use crate::err::Error;
|
||||
use serde::Serialize;
|
||||
use warp::http::StatusCode;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Message {
|
||||
code: u16,
|
||||
details: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
information: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn recover(err: warp::Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
if let Some(err) = err.find::<Error>() {
|
||||
match err {
|
||||
Error::InvalidAuth => Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 403,
|
||||
details: Some("Authentication failed".to_string()),
|
||||
description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()),
|
||||
information: Some(err.to_string()),
|
||||
}),
|
||||
StatusCode::FORBIDDEN,
|
||||
)),
|
||||
Error::InvalidType => Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 415,
|
||||
details: Some("Unsupported media type".to_string()),
|
||||
description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::UNSUPPORTED_MEDIA_TYPE,
|
||||
)),
|
||||
Error::InvalidStorage => Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 500,
|
||||
details: Some("Health check failed".to_string()),
|
||||
description: Some("The database health check for this instance failed. There was an issue with the underlying storage engine.".to_string()),
|
||||
information: Some(err.to_string()),
|
||||
}),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
)),
|
||||
_ => Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 400,
|
||||
details: Some("Request problems detected".to_string()),
|
||||
description: Some("There is a problem with your request. Refer to the documentation for further information.".to_string()),
|
||||
information: Some(err.to_string()),
|
||||
}),
|
||||
StatusCode::BAD_REQUEST,
|
||||
))
|
||||
}
|
||||
} else if err.is_not_found() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 404,
|
||||
details: Some("Requested resource not found".to_string()),
|
||||
description: Some("The requested resource does not exist. Check that you have entered the url correctly.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::NOT_FOUND,
|
||||
))
|
||||
} else if err.find::<warp::reject::MissingHeader>().is_some() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 412,
|
||||
details: Some("Request problems detected".to_string()),
|
||||
description: Some("The request appears to be missing a required header. Refer to the documentation for request requirements.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::PRECONDITION_FAILED,
|
||||
))
|
||||
} else if err.find::<warp::reject::PayloadTooLarge>().is_some() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 413,
|
||||
details: Some("Payload too large".to_string()),
|
||||
description: Some("The request has exceeded the maximum payload size. Refer to the documentation for the request limitations.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
))
|
||||
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 501,
|
||||
details: Some("Not implemented".to_string()),
|
||||
description: Some("The server either does not recognize the query, or it lacks the ability to fulfill the request.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
))
|
||||
} else if err.find::<warp::reject::InvalidHeader>().is_some() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 501,
|
||||
details: Some("Not implemented".to_string()),
|
||||
description: Some("The server either does not recognize a request header, or it lacks the ability to fulfill the request.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
))
|
||||
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 405,
|
||||
details: Some("Requested method not allowed".to_string()),
|
||||
description: Some("The requested http method is not allowed for this resource. Refer to the documentation for allowed methods.".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
))
|
||||
} else {
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&Message {
|
||||
code: 500,
|
||||
details: Some("Internal server error".to_string()),
|
||||
description: Some("There was a problem with our servers, and we have been notified. Refer to the documentation for further information".to_string()),
|
||||
information: None,
|
||||
}),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
))
|
||||
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
use crate::cnf::PKG_NAME;
|
||||
use crate::cnf::PKG_VERSION;
|
||||
use surrealdb::cnf::SERVER_NAME;
|
||||
|
||||
const ID: &str = "ID";
|
||||
const NS: &str = "NS";
|
||||
const DB: &str = "DB";
|
||||
const SERVER: &str = "Server";
|
||||
const VERSION: &str = "Version";
|
||||
|
||||
pub fn version() -> warp::filters::reply::WithHeader {
|
||||
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
|
||||
warp::reply::with::header(VERSION, val)
|
||||
}
|
||||
|
||||
pub fn server() -> warp::filters::reply::WithHeader {
|
||||
warp::reply::with::header(SERVER, SERVER_NAME)
|
||||
}
|
||||
|
||||
pub fn cors() -> warp::filters::cors::Builder {
|
||||
warp::cors()
|
||||
.max_age(86400)
|
||||
.allow_any_origin()
|
||||
.allow_methods(vec![
|
||||
http::Method::GET,
|
||||
http::Method::PUT,
|
||||
http::Method::POST,
|
||||
http::Method::PATCH,
|
||||
http::Method::DELETE,
|
||||
http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers(vec![
|
||||
http::header::ACCEPT,
|
||||
http::header::AUTHORIZATION,
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::ORIGIN,
|
||||
NS.parse().unwrap(),
|
||||
DB.parse().unwrap(),
|
||||
ID.parse().unwrap(),
|
||||
])
|
||||
}
|
95
src/net/headers.rs
Normal file
95
src/net/headers.rs
Normal file
|
@ -0,0 +1,95 @@
|
|||
use crate::cnf::PKG_NAME;
|
||||
use crate::cnf::PKG_VERSION;
|
||||
use axum::headers;
|
||||
use axum::headers::Header;
|
||||
use http::HeaderName;
|
||||
use http::HeaderValue;
|
||||
use surrealdb::cnf::SERVER_NAME;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
|
||||
pub(super) const ID: &str = "ID";
|
||||
pub(super) const NS: &str = "NS";
|
||||
pub(super) const DB: &str = "DB";
|
||||
const SERVER: &str = "server";
|
||||
const VERSION: &str = "version";
|
||||
|
||||
pub fn add_version_header() -> SetResponseHeaderLayer<HeaderValue> {
|
||||
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
|
||||
SetResponseHeaderLayer::if_not_present(
|
||||
HeaderName::from_static(VERSION),
|
||||
HeaderValue::try_from(val).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn add_server_header() -> SetResponseHeaderLayer<HeaderValue> {
|
||||
SetResponseHeaderLayer::if_not_present(
|
||||
HeaderName::from_static(SERVER),
|
||||
HeaderValue::try_from(SERVER_NAME).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Typed header implementation for the `Accept` header.
|
||||
pub enum Accept {
|
||||
TextPlain,
|
||||
ApplicationJson,
|
||||
ApplicationCbor,
|
||||
ApplicationPack,
|
||||
ApplicationOctetStream,
|
||||
Surrealdb,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Accept {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Accept::TextPlain => write!(f, "text/plain"),
|
||||
Accept::ApplicationJson => write!(f, "application/json"),
|
||||
Accept::ApplicationCbor => write!(f, "application/cbor"),
|
||||
Accept::ApplicationPack => write!(f, "application/pack"),
|
||||
Accept::ApplicationOctetStream => write!(f, "application/octet-stream"),
|
||||
Accept::Surrealdb => write!(f, "application/surrealdb"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Header for Accept {
|
||||
fn name() -> &'static HeaderName {
|
||||
&http::header::ACCEPT
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
|
||||
where
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
let value = values.next().ok_or_else(headers::Error::invalid)?;
|
||||
|
||||
match value.to_str().map_err(|_| headers::Error::invalid())? {
|
||||
"text/plain" => Ok(Accept::TextPlain),
|
||||
"application/json" => Ok(Accept::ApplicationJson),
|
||||
"application/cbor" => Ok(Accept::ApplicationCbor),
|
||||
"application/pack" => Ok(Accept::ApplicationPack),
|
||||
"application/octet-stream" => Ok(Accept::ApplicationOctetStream),
|
||||
"application/surrealdb" => Ok(Accept::Surrealdb),
|
||||
// TODO: Support more (all?) mime-types
|
||||
_ => Err(headers::Error::invalid()),
|
||||
}
|
||||
}
|
||||
|
||||
fn encode<E>(&self, values: &mut E)
|
||||
where
|
||||
E: Extend<HeaderValue>,
|
||||
{
|
||||
values.extend(std::iter::once(self.into()));
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Accept> for HeaderValue {
|
||||
fn from(value: Accept) -> Self {
|
||||
HeaderValue::from(&value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Accept> for HeaderValue {
|
||||
fn from(value: &Accept) -> Self {
|
||||
HeaderValue::from_str(value.to_string().as_str()).unwrap()
|
||||
}
|
||||
}
|
|
@ -1,25 +1,31 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use warp::Filter;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use http_body::Body as HttpBody;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("health").and(warp::path::end()).and(warp::get()).and_then(handler)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new().route("/health", get(handler))
|
||||
}
|
||||
|
||||
async fn handler() -> Result<impl warp::Reply, warp::Rejection> {
|
||||
async fn handler() -> impl IntoResponse {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Attempt to open a transaction
|
||||
match db.transaction(false, false).await {
|
||||
// The transaction failed to start
|
||||
Err(_) => Err(warp::reject::custom(Error::InvalidStorage)),
|
||||
Err(_) => Err(Error::InvalidStorage),
|
||||
// The transaction was successful
|
||||
Ok(mut tx) => {
|
||||
// Cancel the transaction
|
||||
let _ = tx.cancel().await;
|
||||
// Return the response
|
||||
Ok(warp::reply())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,31 +2,39 @@ use crate::dbs::DB;
|
|||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
use crate::net::session;
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::post;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use axum::TypedHeader;
|
||||
use bytes::Bytes;
|
||||
use http_body::Body as HttpBody;
|
||||
use surrealdb::dbs::Session;
|
||||
use warp::http;
|
||||
use warp::Filter;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
const MAX: u64 = 1024 * 1024 * 1024 * 4; // 4 GiB
|
||||
use super::headers::Accept;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("import")
|
||||
.and(warp::path::end())
|
||||
.and(warp::post())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(session::build())
|
||||
.and_then(handler)
|
||||
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
|
||||
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new()
|
||||
.route("/import", post(handler))
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX))
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
output: String,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
sql: Bytes,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Check the permissions
|
||||
match session.au.is_db() {
|
||||
true => {
|
||||
|
@ -36,22 +44,22 @@ async fn handler(
|
|||
let sql = bytes_to_utf8(&sql)?;
|
||||
// Execute the sql query in the database
|
||||
match db.execute(sql, &session, None).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// Return nothing
|
||||
"application/octet-stream" => Ok(output::none()),
|
||||
Some(Accept::ApplicationOctetStream) => Ok(output::none()),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
_ => Err(warp::reject::custom(Error::InvalidAuth)),
|
||||
_ => Err(Error::InvalidAuth),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
use crate::cnf;
|
||||
use warp::http::Uri;
|
||||
use warp::Filter;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path::end()
|
||||
.and(warp::get())
|
||||
.map(|| warp::redirect::temporary(Uri::from_static(cnf::APP_ENDPOINT)))
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use crate::err::Error;
|
||||
use bytes::Bytes;
|
||||
|
||||
pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, warp::Rejection> {
|
||||
std::str::from_utf8(bytes).map_err(|_| warp::reject::custom(Error::Request))
|
||||
pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, Error> {
|
||||
std::str::from_utf8(bytes).map_err(|_| Error::Request)
|
||||
}
|
||||
|
|
426
src/net/key.rs
426
src/net/key.rs
|
@ -2,145 +2,62 @@ use crate::dbs::DB;
|
|||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
use crate::net::params::{Param, Params};
|
||||
use crate::net::session;
|
||||
use crate::net::params::Params;
|
||||
use axum::extract::{DefaultBodyLimit, Path, Query};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::options;
|
||||
use axum::{Extension, Router, TypedHeader};
|
||||
use bytes::Bytes;
|
||||
use http_body::Body as HttpBody;
|
||||
use serde::Deserialize;
|
||||
use std::str;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::sql::Value;
|
||||
use warp::path;
|
||||
use warp::Filter;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
const MAX: u64 = 1024 * 16; // 16 KiB
|
||||
use super::headers::Accept;
|
||||
|
||||
const MAX: usize = 1024 * 16; // 16 KiB
|
||||
|
||||
#[derive(Default, Deserialize, Debug, Clone)]
|
||||
struct Query {
|
||||
struct QueryOptions {
|
||||
pub limit: Option<String>,
|
||||
pub start: Option<String>,
|
||||
}
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
// ------------------------------
|
||||
// Routes for OPTIONS
|
||||
// ------------------------------
|
||||
|
||||
let base = warp::path("key");
|
||||
// Set opts method
|
||||
let opts = base.and(warp::options()).map(warp::reply);
|
||||
|
||||
// ------------------------------
|
||||
// Routes for a table
|
||||
// ------------------------------
|
||||
|
||||
// Set select method
|
||||
let select = warp::any()
|
||||
.and(warp::get())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param).and(warp::path::end()))
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(select_all);
|
||||
// Set create method
|
||||
let create = warp::any()
|
||||
.and(warp::post())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(create_all);
|
||||
// Set update method
|
||||
let update = warp::any()
|
||||
.and(warp::put())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(update_all);
|
||||
// Set modify method
|
||||
let modify = warp::any()
|
||||
.and(warp::patch())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(modify_all);
|
||||
// Set delete method
|
||||
let delete = warp::any()
|
||||
.and(warp::delete())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param).and(warp::path::end()))
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(delete_all);
|
||||
// Specify route
|
||||
let all = select.or(create).or(update).or(modify).or(delete);
|
||||
|
||||
// ------------------------------
|
||||
// Routes for a thing
|
||||
// ------------------------------
|
||||
|
||||
// Set select method
|
||||
let select = warp::any()
|
||||
.and(warp::get())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param / Param).and(warp::path::end()))
|
||||
.and(session::build())
|
||||
.and_then(select_one);
|
||||
// Set create method
|
||||
let create = warp::any()
|
||||
.and(warp::post())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(create_one);
|
||||
// Set update method
|
||||
let update = warp::any()
|
||||
.and(warp::put())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(update_one);
|
||||
// Set modify method
|
||||
let modify = warp::any()
|
||||
.and(warp::patch())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param / Param).and(warp::path::end()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(modify_one);
|
||||
// Set delete method
|
||||
let delete = warp::any()
|
||||
.and(warp::delete())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(path!("key" / Param / Param).and(warp::path::end()))
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(delete_one);
|
||||
// Specify route
|
||||
let one = select.or(create).or(update).or(modify).or(delete);
|
||||
|
||||
// ------------------------------
|
||||
// All routes
|
||||
// ------------------------------
|
||||
|
||||
// Specify route
|
||||
opts.or(all).or(one)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new()
|
||||
.route(
|
||||
"/key/:table",
|
||||
options(|| async {})
|
||||
.get(select_all)
|
||||
.post(create_all)
|
||||
.put(update_all)
|
||||
.patch(modify_all)
|
||||
.delete(delete_all),
|
||||
)
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX))
|
||||
.merge(
|
||||
Router::new()
|
||||
.route(
|
||||
"/key/:table/:key",
|
||||
options(|| async {})
|
||||
.get(select_one)
|
||||
.post(create_one)
|
||||
.put(update_one)
|
||||
.patch(modify_one)
|
||||
.delete(delete_one),
|
||||
)
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX)),
|
||||
)
|
||||
}
|
||||
|
||||
// ------------------------------
|
||||
|
@ -148,11 +65,11 @@ pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejecti
|
|||
// ------------------------------
|
||||
|
||||
async fn select_all(
|
||||
output: String,
|
||||
table: Param,
|
||||
query: Query,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(query): Query<QueryOptions>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Specify the request statement
|
||||
|
@ -167,28 +84,28 @@ async fn select_all(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql.as_str(), &session, Some(vars)).await {
|
||||
Ok(ref res) => match output.as_ref() {
|
||||
Ok(ref res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_all(
|
||||
output: String,
|
||||
table: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<Params>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -206,31 +123,31 @@ async fn create_all(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_all(
|
||||
output: String,
|
||||
table: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<Params>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -248,31 +165,31 @@ async fn update_all(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn modify_all(
|
||||
output: String,
|
||||
table: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<Params>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -290,30 +207,30 @@ async fn modify_all(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_all(
|
||||
output: String,
|
||||
table: Param,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Specify the request statement
|
||||
|
@ -325,18 +242,18 @@ async fn delete_all(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -345,11 +262,10 @@ async fn delete_all(
|
|||
// ------------------------------
|
||||
|
||||
async fn select_one(
|
||||
output: String,
|
||||
table: Param,
|
||||
id: Param,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Specify the request statement
|
||||
|
@ -366,29 +282,28 @@ async fn select_one(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_one(
|
||||
output: String,
|
||||
table: Param,
|
||||
id: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -412,32 +327,31 @@ async fn create_one(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_one(
|
||||
output: String,
|
||||
table: Param,
|
||||
id: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -461,32 +375,31 @@ async fn update_one(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn modify_one(
|
||||
output: String,
|
||||
table: Param,
|
||||
id: Param,
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
body: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the HTTP request body
|
||||
|
@ -510,31 +423,29 @@ async fn modify_one(
|
|||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
Err(_) => Err(warp::reject::custom(Error::Request)),
|
||||
Err(_) => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_one(
|
||||
output: String,
|
||||
table: Param,
|
||||
id: Param,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Extension(session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
// Specify the request statement
|
||||
|
@ -548,21 +459,20 @@ async fn delete_one(
|
|||
let vars = map! {
|
||||
String::from("table") => Value::from(table),
|
||||
String::from("id") => rid,
|
||||
=> params.parse()
|
||||
};
|
||||
// Execute the query and return the result
|
||||
match db.execute(sql, &session, Some(vars)).await {
|
||||
Ok(res) => match output.as_ref() {
|
||||
Ok(res) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
use std::fmt;
|
||||
use tracing::Level;
|
||||
|
||||
struct OptFmt<T>(Option<T>);
|
||||
|
||||
impl<T: fmt::Display> fmt::Display for OptFmt<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if let Some(ref t) = self.0 {
|
||||
fmt::Display::fmt(t, f)
|
||||
} else {
|
||||
f.write_str("-")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write() -> warp::filters::log::Log<impl Fn(warp::filters::log::Info) + Copy> {
|
||||
warp::log::custom(|info| {
|
||||
event!(
|
||||
Level::INFO,
|
||||
"{} {} {} {:?} {} \"{}\" {:?}",
|
||||
OptFmt(info.remote_addr()),
|
||||
info.method(),
|
||||
info.path(),
|
||||
info.version(),
|
||||
info.status().as_u16(),
|
||||
OptFmt(info.user_agent()),
|
||||
info.elapsed(),
|
||||
);
|
||||
})
|
||||
}
|
209
src/net/mod.rs
209
src/net/mod.rs
|
@ -1,105 +1,164 @@
|
|||
mod auth;
|
||||
pub mod client_ip;
|
||||
mod export;
|
||||
mod fail;
|
||||
mod head;
|
||||
mod headers;
|
||||
mod health;
|
||||
mod import;
|
||||
mod index;
|
||||
mod input;
|
||||
mod key;
|
||||
mod log;
|
||||
mod output;
|
||||
mod params;
|
||||
mod rpc;
|
||||
mod session;
|
||||
mod signals;
|
||||
mod signin;
|
||||
mod signup;
|
||||
mod sql;
|
||||
mod status;
|
||||
mod sync;
|
||||
mod tracer;
|
||||
mod version;
|
||||
|
||||
use axum::response::Redirect;
|
||||
use axum::routing::get;
|
||||
use axum::{middleware, Router};
|
||||
use axum_server::Handle;
|
||||
use http::header;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::add_extension::AddExtensionLayer;
|
||||
use tower_http::auth::AsyncRequireAuthorizationLayer;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::request_id::MakeRequestUuid;
|
||||
use tower_http::sensitive_headers::{
|
||||
SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer,
|
||||
};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::ServiceBuilderExt;
|
||||
|
||||
use crate::cli::CF;
|
||||
use crate::cnf;
|
||||
use crate::err::Error;
|
||||
use warp::Filter;
|
||||
use crate::net::signals::graceful_shutdown;
|
||||
use crate::telemetry::metrics::HttpMetricsLayer;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
|
||||
const LOG: &str = "surrealdb::net";
|
||||
|
||||
///
|
||||
/// AppState is used to share data between routes.
|
||||
///
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
client_ip: client_ip::ClientIp,
|
||||
}
|
||||
|
||||
pub async fn init() -> Result<(), Error> {
|
||||
// Setup web routes
|
||||
let net = index::config()
|
||||
// Version endpoint
|
||||
.or(version::config())
|
||||
// Status endpoint
|
||||
.or(status::config())
|
||||
// Health endpoint
|
||||
.or(health::config())
|
||||
// Signup endpoint
|
||||
.or(signup::config())
|
||||
// Signin endpoint
|
||||
.or(signin::config())
|
||||
// Export endpoint
|
||||
.or(export::config())
|
||||
// Import endpoint
|
||||
.or(import::config())
|
||||
// Backup endpoint
|
||||
.or(sync::config())
|
||||
// RPC query endpoint
|
||||
.or(rpc::config())
|
||||
// SQL query endpoint
|
||||
.or(sql::config())
|
||||
// API query endpoint
|
||||
.or(key::config())
|
||||
// Catch all errors
|
||||
.recover(fail::recover)
|
||||
// End routes setup
|
||||
;
|
||||
// Specify a generic version header
|
||||
let net = net.with(head::version());
|
||||
// Specify a generic server header
|
||||
let net = net.with(head::server());
|
||||
// Set cors headers on all requests
|
||||
let net = net.with(head::cors());
|
||||
// Log all requests to the console
|
||||
let net = net.with(log::write());
|
||||
// Trace requests
|
||||
let net = net.with(warp::trace::request());
|
||||
|
||||
// Get local copy of options
|
||||
let opt = CF.get().unwrap();
|
||||
|
||||
let app_state = AppState {
|
||||
client_ip: opt.client_ip,
|
||||
};
|
||||
|
||||
// Specify headers to be obfuscated from all requests/responses
|
||||
let headers: Arc<[_]> = Arc::new([
|
||||
header::AUTHORIZATION,
|
||||
header::PROXY_AUTHORIZATION,
|
||||
header::COOKIE,
|
||||
header::SET_COOKIE,
|
||||
]);
|
||||
|
||||
// Build the middleware to our service.
|
||||
let service = ServiceBuilder::new()
|
||||
.catch_panic()
|
||||
.set_x_request_id(MakeRequestUuid)
|
||||
.propagate_x_request_id()
|
||||
.layer(AddExtensionLayer::new(app_state))
|
||||
.layer(middleware::from_fn(client_ip::client_ip_middleware))
|
||||
.layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone(&headers)))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(tracer::HttpTraceLayerHooks)
|
||||
.on_request(tracer::HttpTraceLayerHooks)
|
||||
.on_response(tracer::HttpTraceLayerHooks)
|
||||
.on_failure(tracer::HttpTraceLayerHooks),
|
||||
)
|
||||
.layer(HttpMetricsLayer)
|
||||
.layer(SetSensitiveResponseHeadersLayer::from_shared(headers))
|
||||
.layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth))
|
||||
.layer(headers::add_server_header())
|
||||
.layer(headers::add_version_header())
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_methods([
|
||||
http::Method::GET,
|
||||
http::Method::PUT,
|
||||
http::Method::POST,
|
||||
http::Method::PATCH,
|
||||
http::Method::DELETE,
|
||||
http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers([
|
||||
http::header::ACCEPT,
|
||||
http::header::AUTHORIZATION,
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::ORIGIN,
|
||||
headers::NS.parse().unwrap(),
|
||||
headers::DB.parse().unwrap(),
|
||||
headers::ID.parse().unwrap(),
|
||||
])
|
||||
// allow requests from any origin
|
||||
.allow_origin(Any)
|
||||
.max_age(Duration::from_secs(86400)),
|
||||
);
|
||||
|
||||
let axum_app = Router::new()
|
||||
// Redirect until we provide a UI
|
||||
.route("/", get(|| async { Redirect::temporary(cnf::APP_ENDPOINT) }))
|
||||
.route("/status", get(|| async {}))
|
||||
.merge(health::router())
|
||||
.merge(export::router())
|
||||
.merge(import::router())
|
||||
.merge(rpc::router())
|
||||
.merge(version::router())
|
||||
.merge(sync::router())
|
||||
.merge(sql::router())
|
||||
.merge(signin::router())
|
||||
.merge(signup::router())
|
||||
.merge(key::router())
|
||||
.layer(service);
|
||||
|
||||
info!("Starting web server on {}", &opt.bind);
|
||||
|
||||
if let (Some(c), Some(k)) = (&opt.crt, &opt.key) {
|
||||
// Bind the server to the desired port
|
||||
let (adr, srv) = warp::serve(net)
|
||||
.tls()
|
||||
.cert_path(c)
|
||||
.key_path(k)
|
||||
.bind_with_graceful_shutdown(opt.bind, async move {
|
||||
// Capture the shutdown signals and log that the graceful shutdown has started
|
||||
let result = signals::listen().await.expect("Failed to listen to shutdown signal");
|
||||
info!("{} received. Start graceful shutdown...", result);
|
||||
});
|
||||
// Log the server startup status
|
||||
info!("Started web server on {}", &adr);
|
||||
// Run the server forever
|
||||
srv.await;
|
||||
// Log the server shutdown event
|
||||
info!("Shutdown complete. Bye!")
|
||||
// Setup the graceful shutdown with no timeout
|
||||
let handle = Handle::new();
|
||||
graceful_shutdown(handle.clone(), None);
|
||||
|
||||
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
|
||||
// configure certificate and private key used by https
|
||||
let tls = RustlsConfig::from_pem_file(cert, key).await.unwrap();
|
||||
|
||||
let server = axum_server::bind_rustls(opt.bind, tls);
|
||||
|
||||
info!(target: LOG, "Started web server on {}", &opt.bind);
|
||||
|
||||
server
|
||||
.handle(handle)
|
||||
.serve(axum_app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await?;
|
||||
} else {
|
||||
// Bind the server to the desired port
|
||||
let (adr, srv) = warp::serve(net).bind_with_graceful_shutdown(opt.bind, async move {
|
||||
// Capture the shutdown signals and log that the graceful shutdown has started
|
||||
let result = signals::listen().await.expect("Failed to listen to shutdown signal");
|
||||
info!("{} received. Start graceful shutdown...", result);
|
||||
});
|
||||
// Log the server startup status
|
||||
info!("Started web server on {}", &adr);
|
||||
// Run the server forever
|
||||
srv.await;
|
||||
// Log the server shutdown event
|
||||
info!("Shutdown complete. Bye!")
|
||||
let server = axum_server::bind(opt.bind);
|
||||
|
||||
info!(target: LOG, "Started web server on {}", &opt.bind);
|
||||
|
||||
server
|
||||
.handle(handle)
|
||||
.serve(axum_app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await?;
|
||||
};
|
||||
|
||||
info!(target: LOG, "Web server stopped. Bye!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use axum::response::{IntoResponse, Response};
|
||||
use http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use http::StatusCode;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as Json;
|
||||
use surrealdb::sql;
|
||||
|
||||
use super::headers::Accept;
|
||||
|
||||
pub enum Output {
|
||||
None,
|
||||
Fail,
|
||||
|
@ -67,38 +70,23 @@ pub fn simplify<T: Serialize>(v: T) -> Json {
|
|||
sql::to_value(v).unwrap().into()
|
||||
}
|
||||
|
||||
impl warp::Reply for Output {
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
impl IntoResponse for Output {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Output::Text(v) => {
|
||||
let mut res = warp::reply::Response::new(v.into());
|
||||
let con = HeaderValue::from_static("text/plain");
|
||||
res.headers_mut().insert(CONTENT_TYPE, con);
|
||||
res
|
||||
([(CONTENT_TYPE, HeaderValue::from(Accept::TextPlain))], v).into_response()
|
||||
}
|
||||
Output::Json(v) => {
|
||||
let mut res = warp::reply::Response::new(v.into());
|
||||
let con = HeaderValue::from_static("application/json");
|
||||
res.headers_mut().insert(CONTENT_TYPE, con);
|
||||
res
|
||||
([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationJson))], v).into_response()
|
||||
}
|
||||
Output::Cbor(v) => {
|
||||
let mut res = warp::reply::Response::new(v.into());
|
||||
let con = HeaderValue::from_static("application/cbor");
|
||||
res.headers_mut().insert(CONTENT_TYPE, con);
|
||||
res
|
||||
([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationCbor))], v).into_response()
|
||||
}
|
||||
Output::Pack(v) => {
|
||||
let mut res = warp::reply::Response::new(v.into());
|
||||
let con = HeaderValue::from_static("application/pack");
|
||||
res.headers_mut().insert(CONTENT_TYPE, con);
|
||||
res
|
||||
([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationPack))], v).into_response()
|
||||
}
|
||||
Output::Full(v) => {
|
||||
let mut res = warp::reply::Response::new(v.into());
|
||||
let con = HeaderValue::from_static("application/surrealdb");
|
||||
res.headers_mut().insert(CONTENT_TYPE, con);
|
||||
res
|
||||
([(CONTENT_TYPE, HeaderValue::from(Accept::Surrealdb))], v).into_response()
|
||||
}
|
||||
Output::None => StatusCode::OK.into_response(),
|
||||
Output::Fail => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
|
|
|
@ -5,13 +5,16 @@ use crate::cnf::PKG_VERSION;
|
|||
use crate::cnf::WEBSOCKET_PING_FREQUENCY;
|
||||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::session;
|
||||
use crate::rpc::args::Take;
|
||||
use crate::rpc::paths::{ID, METHOD, PARAMS};
|
||||
use crate::rpc::res;
|
||||
use crate::rpc::res::Failure;
|
||||
use crate::rpc::res::Output;
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use http_body::Body as HttpBody;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
|
@ -27,8 +30,11 @@ use surrealdb::sql::Value;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::instrument;
|
||||
use uuid::Uuid;
|
||||
use warp::ws::{Message, WebSocket, Ws};
|
||||
use warp::Filter;
|
||||
|
||||
use axum::{
|
||||
extract::ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
// Mapping of WebSocketID to WebSocket
|
||||
type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>;
|
||||
|
@ -38,17 +44,22 @@ type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
|
|||
static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
|
||||
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("rpc")
|
||||
.and(warp::path::end())
|
||||
.and(warp::ws())
|
||||
.and(session::build())
|
||||
.map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session)))
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new().route("/rpc", get(handler))
|
||||
}
|
||||
|
||||
async fn socket(ws: WebSocket, session: Session) {
|
||||
let rpc = Rpc::new(session);
|
||||
async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse {
|
||||
// finalize the upgrade process by returning upgrade callback.
|
||||
// we can customize the callback by sending additional info such as address.
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, sess))
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, sess: Session) {
|
||||
let rpc = Rpc::new(sess);
|
||||
Rpc::serve(rpc, ws).await
|
||||
}
|
||||
|
||||
|
@ -89,7 +100,7 @@ impl Rpc {
|
|||
let png = chn.clone();
|
||||
// The WebSocket has connected
|
||||
Rpc::connected(rpc.clone(), chn.clone()).await;
|
||||
// Send messages to the client
|
||||
// Send Ping messages to the client
|
||||
tokio::task::spawn(async move {
|
||||
// Create the interval ticker
|
||||
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
|
||||
|
@ -98,7 +109,7 @@ impl Rpc {
|
|||
// Wait for the timer
|
||||
interval.tick().await;
|
||||
// Create the ping message
|
||||
let msg = Message::ping(vec![]);
|
||||
let msg = Message::Ping(vec![]);
|
||||
// Send the message to the client
|
||||
if png.send(msg).await.is_err() {
|
||||
// Exit out of the loop
|
||||
|
@ -146,20 +157,18 @@ impl Rpc {
|
|||
while let Some(msg) = wrx.next().await {
|
||||
match msg {
|
||||
// We've received a message from the client
|
||||
// Ping is automatically handled by the WebSocket library
|
||||
Ok(msg) => match msg {
|
||||
msg if msg.is_ping() => {
|
||||
let _ = chn.send(Message::pong(vec![])).await;
|
||||
}
|
||||
msg if msg.is_text() => {
|
||||
Message::Text(_) => {
|
||||
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
|
||||
}
|
||||
msg if msg.is_binary() => {
|
||||
Message::Binary(_) => {
|
||||
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
|
||||
}
|
||||
msg if msg.is_close() => {
|
||||
Message::Close(_) => {
|
||||
break;
|
||||
}
|
||||
msg if msg.is_pong() => {
|
||||
Message::Pong(_) => {
|
||||
continue;
|
||||
}
|
||||
_ => {
|
||||
|
@ -214,16 +223,14 @@ impl Rpc {
|
|||
// Parse the request
|
||||
let req = match msg {
|
||||
// This is a binary message
|
||||
m if m.is_binary() => {
|
||||
Message::Binary(val) => {
|
||||
// Use binary output
|
||||
out = Output::Full;
|
||||
// Deserialize the input
|
||||
Value::from(m.into_bytes())
|
||||
Value::from(val)
|
||||
}
|
||||
// This is a text message
|
||||
m if m.is_text() => {
|
||||
// This won't panic due to the check above
|
||||
let val = m.to_str().unwrap();
|
||||
Message::Text(ref val) => {
|
||||
// Parse the SurrealQL object
|
||||
match surrealdb::sql::value(val) {
|
||||
// The SurrealQL message parsed ok
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::iam::verify::basic;
|
||||
use crate::iam::BASIC;
|
||||
use crate::net::client_ip;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::iam::verify::token;
|
||||
use surrealdb::iam::TOKEN;
|
||||
use warp::Filter;
|
||||
|
||||
pub fn build() -> impl Filter<Extract = (Session,), Error = warp::Rejection> + Clone {
|
||||
// Enable on any path
|
||||
let conf = warp::any();
|
||||
// Add remote ip address
|
||||
let conf = conf.and(client_ip::build());
|
||||
// Add authorization header
|
||||
let conf = conf.and(warp::header::optional::<String>("authorization"));
|
||||
// Add http origin header
|
||||
let conf = conf.and(warp::header::optional::<String>("origin"));
|
||||
// Add session id header
|
||||
let conf = conf.and(warp::header::optional::<String>("id"));
|
||||
// Add namespace header
|
||||
let conf = conf.and(warp::header::optional::<String>("ns"));
|
||||
// Add database header
|
||||
let conf = conf.and(warp::header::optional::<String>("db"));
|
||||
// Process all headers
|
||||
conf.and_then(process)
|
||||
}
|
||||
|
||||
async fn process(
|
||||
ip: Option<String>,
|
||||
au: Option<String>,
|
||||
or: Option<String>,
|
||||
id: Option<String>,
|
||||
ns: Option<String>,
|
||||
db: Option<String>,
|
||||
) -> Result<Session, warp::Rejection> {
|
||||
let kvs = DB.get().unwrap();
|
||||
// Create session
|
||||
#[rustfmt::skip]
|
||||
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
|
||||
// Parse the authentication header
|
||||
match au {
|
||||
// Basic authentication data was supplied
|
||||
Some(auth) if auth.starts_with(BASIC) => basic(&mut session, auth).await,
|
||||
// Token authentication data was supplied
|
||||
Some(auth) if auth.starts_with(TOKEN) => {
|
||||
token(kvs, &mut session, auth).await.map_err(Error::from)
|
||||
}
|
||||
// Wrong authentication data was supplied
|
||||
Some(_) => Err(Error::InvalidAuth),
|
||||
// No authentication data was supplied
|
||||
None => Ok(()),
|
||||
}?;
|
||||
// Pass the authenticated session through
|
||||
Ok(session)
|
||||
}
|
|
@ -1,5 +1,19 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use axum_server::Handle;
|
||||
|
||||
use crate::err::Error;
|
||||
|
||||
/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received.
|
||||
pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) {
|
||||
tokio::spawn(async move {
|
||||
let result = listen().await.expect("Failed to listen to shutdown signal");
|
||||
info!(target: super::LOG, "{} received. Start graceful shutdown...", result);
|
||||
|
||||
handle.graceful_shutdown(dur)
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub async fn listen() -> Result<String, Error> {
|
||||
// Import the OS signals
|
||||
|
@ -11,7 +25,7 @@ pub async fn listen() -> Result<String, Error> {
|
|||
let mut sigterm = signal(SignalKind::terminate())?;
|
||||
// Listen and wait for the system signals
|
||||
tokio::select! {
|
||||
// Wait for a SIGQUIT signal
|
||||
// Wait for a SIGHUP signal
|
||||
_ = sighup.recv() => {
|
||||
Ok(String::from("SIGHUP"))
|
||||
}
|
||||
|
|
|
@ -2,16 +2,24 @@ use crate::dbs::DB;
|
|||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
use crate::net::session;
|
||||
use crate::net::CF;
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::options;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use axum::TypedHeader;
|
||||
use bytes::Bytes;
|
||||
use http_body::Body as HttpBody;
|
||||
use serde::Serialize;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::opt::auth::Root;
|
||||
use surrealdb::sql::Value;
|
||||
use warp::Filter;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
const MAX: u64 = 1024; // 1 KiB
|
||||
use super::headers::Accept;
|
||||
|
||||
const MAX: usize = 1024; // 1 KiB
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Success {
|
||||
|
@ -30,29 +38,24 @@ impl Success {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
// Set base path
|
||||
let base = warp::path("signin").and(warp::path::end());
|
||||
// Set opts method
|
||||
let opts = base.and(warp::options()).map(warp::reply);
|
||||
// Set post method
|
||||
let post = base
|
||||
.and(warp::post())
|
||||
.and(warp::header::optional::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(session::build())
|
||||
.and_then(handler);
|
||||
// Specify route
|
||||
opts.or(post)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new()
|
||||
.route("/signin", options(|| async {}).post(handler))
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX))
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
output: Option<String>,
|
||||
Extension(mut session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
body: Bytes,
|
||||
mut session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
// Get the config options
|
||||
|
@ -72,25 +75,25 @@ async fn handler(
|
|||
.map_err(Error::from)
|
||||
{
|
||||
// Authentication was successful
|
||||
Ok(v) => match output.as_deref() {
|
||||
Ok(v) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
||||
// Internal serialization
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
||||
// Text serialization
|
||||
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
||||
// Return nothing
|
||||
None => Ok(output::none()),
|
||||
// Text serialization
|
||||
Some("text/plain") => Ok(output::text(v.unwrap_or_default())),
|
||||
// Simple serialization
|
||||
Some("application/json") => Ok(output::json(&Success::new(v))),
|
||||
Some("application/cbor") => Ok(output::cbor(&Success::new(v))),
|
||||
Some("application/pack") => Ok(output::pack(&Success::new(v))),
|
||||
// Internal serialization
|
||||
Some("application/surrealdb") => Ok(output::full(&Success::new(v))),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error with authentication
|
||||
Err(e) => Err(warp::reject::custom(e)),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
// The provided value was not an object
|
||||
_ => Err(warp::reject::custom(Error::Request)),
|
||||
_ => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,14 +2,20 @@ use crate::dbs::DB;
|
|||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
use crate::net::session;
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::options;
|
||||
use axum::{Extension, Router, TypedHeader};
|
||||
use bytes::Bytes;
|
||||
use http_body::Body as HttpBody;
|
||||
use serde::Serialize;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::sql::Value;
|
||||
use warp::Filter;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
const MAX: u64 = 1024; // 1 KiB
|
||||
use super::headers::Accept;
|
||||
|
||||
const MAX: usize = 1024; // 1 KiB
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Success {
|
||||
|
@ -28,29 +34,24 @@ impl Success {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
// Set base path
|
||||
let base = warp::path("signup").and(warp::path::end());
|
||||
// Set opts method
|
||||
let opts = base.and(warp::options()).map(warp::reply);
|
||||
// Set post method
|
||||
let post = base
|
||||
.and(warp::post())
|
||||
.and(warp::header::optional::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(session::build())
|
||||
.and_then(handler);
|
||||
// Specify route
|
||||
opts.or(post)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new()
|
||||
.route("/signup", options(|| async {}).post(handler))
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX))
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
output: Option<String>,
|
||||
Extension(mut session): Extension<Session>,
|
||||
maybe_output: Option<TypedHeader<Accept>>,
|
||||
body: Bytes,
|
||||
mut session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
// Convert the HTTP body into text
|
||||
|
@ -62,25 +63,25 @@ async fn handler(
|
|||
match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from)
|
||||
{
|
||||
// Authentication was successful
|
||||
Ok(v) => match output.as_deref() {
|
||||
Ok(v) => match maybe_output.as_deref() {
|
||||
// Simple serialization
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
||||
// Internal serialization
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
||||
// Text serialization
|
||||
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
||||
// Return nothing
|
||||
None => Ok(output::none()),
|
||||
// Text serialization
|
||||
Some("text/plain") => Ok(output::text(v.unwrap_or_default())),
|
||||
// Simple serialization
|
||||
Some("application/json") => Ok(output::json(&Success::new(v))),
|
||||
Some("application/cbor") => Ok(output::cbor(&Success::new(v))),
|
||||
Some("application/pack") => Ok(output::pack(&Success::new(v))),
|
||||
// Internal serialization
|
||||
Some("application/surrealdb") => Ok(output::full(&Success::new(v))),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error with authentication
|
||||
Err(e) => Err(warp::reject::custom(e)),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
// The provided value was not an object
|
||||
_ => Err(warp::reject::custom(Error::Request)),
|
||||
_ => Err(Error::Request),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,74 +3,80 @@ use crate::err::Error;
|
|||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
use crate::net::params::Params;
|
||||
use crate::net::session;
|
||||
use axum::extract::ws::Message;
|
||||
use axum::extract::ws::WebSocket;
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::extract::Query;
|
||||
use axum::extract::WebSocketUpgrade;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::options;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use axum::TypedHeader;
|
||||
use bytes::Bytes;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use http_body::Body as HttpBody;
|
||||
use surrealdb::dbs::Session;
|
||||
use warp::ws::{Message, WebSocket, Ws};
|
||||
use warp::Filter;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
const MAX: u64 = 1024 * 1024; // 1 MiB
|
||||
use super::headers::Accept;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
// Set base path
|
||||
let base = warp::path("sql").and(warp::path::end());
|
||||
// Set opts method
|
||||
let opts = base.and(warp::options()).map(warp::reply);
|
||||
// Set post method
|
||||
let post = base
|
||||
.and(warp::post())
|
||||
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
|
||||
.and(warp::body::content_length_limit(MAX))
|
||||
.and(warp::body::bytes())
|
||||
.and(warp::query())
|
||||
.and(session::build())
|
||||
.and_then(handler);
|
||||
// Set sock method
|
||||
let sock = base
|
||||
.and(warp::ws())
|
||||
.and(session::build())
|
||||
.map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session)));
|
||||
// Specify route
|
||||
opts.or(post).or(sock)
|
||||
const MAX: usize = 1024 * 1024; // 1 MiB
|
||||
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new()
|
||||
.route("/sql", options(|| async {}).get(ws_handler).post(post_handler))
|
||||
.route_layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(MAX))
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
output: String,
|
||||
async fn post_handler(
|
||||
Extension(session): Extension<Session>,
|
||||
output: Option<TypedHeader<Accept>>,
|
||||
params: Query<Params>,
|
||||
sql: Bytes,
|
||||
params: Params,
|
||||
session: Session,
|
||||
) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let db = DB.get().unwrap();
|
||||
// Convert the received sql query
|
||||
let sql = bytes_to_utf8(&sql)?;
|
||||
// Execute the received sql query
|
||||
match db.execute(sql, &session, params.parse().into()).await {
|
||||
// Convert the response to JSON
|
||||
Ok(res) => match output.as_ref() {
|
||||
match db.execute(sql, &session, params.0.parse().into()).await {
|
||||
Ok(res) => match output.as_deref() {
|
||||
// Simple serialization
|
||||
"application/json" => Ok(output::json(&output::simplify(res))),
|
||||
"application/cbor" => Ok(output::cbor(&output::simplify(res))),
|
||||
"application/pack" => Ok(output::pack(&output::simplify(res))),
|
||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||
// Internal serialization
|
||||
"application/surrealdb" => Ok(output::full(&res)),
|
||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||
// An incorrect content-type was requested
|
||||
_ => Err(warp::reject::custom(Error::InvalidType)),
|
||||
_ => Err(Error::InvalidType),
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(err) => Err(warp::reject::custom(Error::from(err))),
|
||||
Err(err) => Err(Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn socket(ws: WebSocket, session: Session) {
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(sess): Extension<Session>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, sess))
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, session: Session) {
|
||||
// Split the WebSocket connection
|
||||
let (mut tx, mut rx) = ws.split();
|
||||
// Wait to receive the next message
|
||||
while let Some(res) = rx.next().await {
|
||||
if let Ok(msg) = res {
|
||||
if let Ok(sql) = msg.to_str() {
|
||||
if let Ok(sql) = msg.to_text() {
|
||||
// Get a database reference
|
||||
let db = DB.get().unwrap();
|
||||
// Execute the received sql query
|
||||
|
@ -78,12 +84,12 @@ async fn socket(ws: WebSocket, session: Session) {
|
|||
// Convert the response to JSON
|
||||
Ok(v) => match serde_json::to_string(&v) {
|
||||
// Send the JSON response to the client
|
||||
Ok(v) => tx.send(Message::text(v)).await,
|
||||
Ok(v) => tx.send(Message::Text(v)).await,
|
||||
// There was an error converting to JSON
|
||||
Err(e) => tx.send(Message::text(Error::from(e))).await,
|
||||
Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await,
|
||||
},
|
||||
// There was an error when executing the query
|
||||
Err(e) => tx.send(Message::text(Error::from(e))).await,
|
||||
Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
use warp::Filter;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("status").and(warp::path::end()).and(warp::get()).map(warp::reply)
|
||||
}
|
|
@ -1,22 +1,20 @@
|
|||
use warp::http;
|
||||
use warp::Filter;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use http_body::Body as HttpBody;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
// Set base path
|
||||
let base = warp::path("sync").and(warp::path::end());
|
||||
// Set save method
|
||||
let save = base.and(warp::get()).and_then(save);
|
||||
// Set load method
|
||||
let load = base.and(warp::post()).and_then(load);
|
||||
// Specify route
|
||||
save.or(load)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new().route("/sync", get(save).post(load))
|
||||
}
|
||||
|
||||
pub async fn load() -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Ok(warp::reply::with_status("Load", http::StatusCode::OK))
|
||||
async fn load() -> impl IntoResponse {
|
||||
"Load"
|
||||
}
|
||||
|
||||
pub async fn save() -> Result<impl warp::Reply, warp::Rejection> {
|
||||
Ok(warp::reply::with_status("Save", http::StatusCode::OK))
|
||||
async fn save() -> impl IntoResponse {
|
||||
"Save"
|
||||
}
|
||||
|
|
235
src/net/tracer.rs
Normal file
235
src/net/tracer.rs
Normal file
|
@ -0,0 +1,235 @@
|
|||
use std::{fmt, time::Duration};
|
||||
|
||||
use axum::{
|
||||
body::{boxed, Body, BoxBody},
|
||||
extract::MatchedPath,
|
||||
headers::{
|
||||
authorization::{Basic, Bearer},
|
||||
Authorization, Origin,
|
||||
},
|
||||
Extension, RequestPartsExt, TypedHeader,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use http::{header, request::Parts, StatusCode};
|
||||
use hyper::{Request, Response};
|
||||
use surrealdb::{dbs::Session, iam::verify::token};
|
||||
use tower_http::{
|
||||
auth::AsyncAuthorizeRequest,
|
||||
request_id::RequestId,
|
||||
trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
|
||||
};
|
||||
use tracing::{field, Level, Span};
|
||||
|
||||
use crate::{dbs::DB, err::Error, iam::verify::basic};
|
||||
|
||||
use super::{client_ip::ExtractClientIP, AppState};
|
||||
|
||||
///
|
||||
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
|
||||
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
|
||||
///
|
||||
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
|
||||
///
|
||||
/// ```rust
|
||||
/// use tower_http::auth::RequireAuthorizationLayer;
|
||||
/// use surrealdb::net::SurrealAuth;
|
||||
/// use axum::Router;
|
||||
///
|
||||
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/version", get(|| async { "0.1.0" }))
|
||||
/// .layer(auth);
|
||||
/// ```
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct SurrealAuth;
|
||||
|
||||
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = BoxBody;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
|
||||
fn authorize(&mut self, request: Request<B>) -> Self::Future {
|
||||
Box::pin(async {
|
||||
let (mut parts, body) = request.into_parts();
|
||||
match check_auth(&mut parts).await {
|
||||
Ok(sess) => {
|
||||
parts.extensions.insert(sess);
|
||||
Ok(Request::from_parts(parts, body))
|
||||
}
|
||||
Err(err) => {
|
||||
let unauthorized_response = Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(boxed(Body::from(err.to_string())))
|
||||
.unwrap();
|
||||
Err(unauthorized_response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
|
||||
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
|
||||
if !or.is_null() {
|
||||
Some(or.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
|
||||
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
|
||||
tracing::error!("Error extracting the app state: {:?}", err);
|
||||
Error::InvalidAuth
|
||||
})?;
|
||||
let ExtractClientIP(ip) =
|
||||
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
|
||||
|
||||
// Create session
|
||||
#[rustfmt::skip]
|
||||
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
|
||||
|
||||
// If Basic authentication data was supplied
|
||||
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
|
||||
basic(&mut session, au.username(), au.password()).await
|
||||
} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
|
||||
token(kvs, &mut session, au.token().into()).await.map_err(|e| e.into())
|
||||
} else {
|
||||
Err(Error::InvalidAuth)
|
||||
}?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
///
|
||||
/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```rust
|
||||
/// use tower_http::trace::TraceLayer;
|
||||
/// use surrealdb::net::HttpTraceLayerHooks;
|
||||
/// use axum::Router;
|
||||
///
|
||||
/// let trace = TraceLayer::new_for_http().on_request(HttpTraceLayerHooks::default());
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/version", get(|| async { "0.1.0" }))
|
||||
/// .layer(trace);
|
||||
/// ```
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct HttpTraceLayerHooks;
|
||||
|
||||
impl<B> MakeSpan<B> for HttpTraceLayerHooks {
|
||||
fn make_span(&mut self, req: &Request<B>) -> Span {
|
||||
// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
|
||||
let span = tracing::info_span!(
|
||||
target: "surreal::http",
|
||||
"request",
|
||||
otel.name = field::Empty,
|
||||
otel.kind = "server",
|
||||
http.route = field::Empty,
|
||||
http.request.method = req.method().as_str(),
|
||||
http.request.body.size = field::Empty,
|
||||
url.path = req.uri().path(),
|
||||
url.query = field::Empty,
|
||||
url.scheme = field::Empty,
|
||||
http.request.id = field::Empty,
|
||||
user_agent.original = field::Empty,
|
||||
network.protocol.name = "http",
|
||||
network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"),
|
||||
client.address = field::Empty,
|
||||
client.port = field::Empty,
|
||||
client.socket.address = field::Empty,
|
||||
server.address = field::Empty,
|
||||
server.port = field::Empty,
|
||||
// set on the response hook
|
||||
http.latency.ms = field::Empty,
|
||||
http.response.status_code = field::Empty,
|
||||
http.response.body.size = field::Empty,
|
||||
// set on the failure hook
|
||||
error = field::Empty,
|
||||
error_message = field::Empty,
|
||||
);
|
||||
|
||||
req.uri().query().map(|v| span.record("url.query", v));
|
||||
req.uri().scheme().map(|v| span.record("url.scheme", v.as_str()));
|
||||
req.uri().host().map(|v| span.record("server.address", v));
|
||||
req.uri().port_u16().map(|v| span.record("server.port", v));
|
||||
|
||||
req.headers()
|
||||
.get(header::CONTENT_LENGTH)
|
||||
.map(|v| v.to_str().map(|v| span.record("http.request.body.size", v)));
|
||||
req.headers()
|
||||
.get(header::USER_AGENT)
|
||||
.map(|v| v.to_str().map(|v| span.record("user_agent.original", v)));
|
||||
|
||||
if let Some(path) = req.extensions().get::<MatchedPath>() {
|
||||
span.record("otel.name", format!("{} {}", req.method(), path.as_str()));
|
||||
span.record("http.route", path.as_str());
|
||||
} else {
|
||||
span.record("otel.name", format!("{} -", req.method()));
|
||||
};
|
||||
|
||||
if let Some(req_id) = req.extensions().get::<RequestId>() {
|
||||
match req_id.header_value().to_str() {
|
||||
Err(err) => tracing::error!(error = %err, "failed to parse request id"),
|
||||
Ok(request_id) => {
|
||||
span.record("http.request.id", request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(client_ip) = req.extensions().get::<ExtractClientIP>() {
|
||||
if let Some(ref client_ip) = client_ip.0 {
|
||||
span.record("client.address", client_ip);
|
||||
}
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> OnRequest<B> for HttpTraceLayerHooks {
|
||||
fn on_request(&mut self, _: &Request<B>, _: &Span) {
|
||||
tracing::event!(Level::INFO, "started processing request");
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> OnResponse<B> for HttpTraceLayerHooks {
|
||||
fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) {
|
||||
if let Some(size) = response.headers().get(header::CONTENT_LENGTH) {
|
||||
span.record("http.response.body.size", size.to_str().unwrap());
|
||||
}
|
||||
span.record("http.response.status_code", response.status().as_u16());
|
||||
|
||||
// Server errors are handled by the OnFailure hook
|
||||
if !response.status().is_server_error() {
|
||||
span.record("http.latency.ms", latency.as_millis());
|
||||
tracing::event!(Level::INFO, "finished processing request");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<FailureClass> OnFailure<FailureClass> for HttpTraceLayerHooks
|
||||
where
|
||||
FailureClass: fmt::Display,
|
||||
{
|
||||
fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) {
|
||||
span.record("error_message", &error.to_string());
|
||||
span.record("http.latency.ms", latency.as_millis());
|
||||
tracing::event!(Level::ERROR, error = error.to_string(), "response failed");
|
||||
}
|
||||
}
|
|
@ -1,14 +1,18 @@
|
|||
use crate::cnf::PKG_NAME;
|
||||
use crate::cnf::PKG_VERSION;
|
||||
use warp::http;
|
||||
use warp::Filter;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use http_body::Body as HttpBody;
|
||||
|
||||
#[allow(opaque_hidden_inferred_bound)]
|
||||
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path("version").and(warp::path::end()).and(warp::get()).and_then(handler)
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Router::new().route("/version", get(handler))
|
||||
}
|
||||
|
||||
pub async fn handler() -> Result<impl warp::Reply, warp::Rejection> {
|
||||
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
|
||||
Ok(warp::reply::with_status(val, http::StatusCode::OK))
|
||||
async fn handler() -> impl IntoResponse {
|
||||
format!("{PKG_NAME}-{}", *PKG_VERSION)
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use axum::extract::ws::Message;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value as Json};
|
||||
use std::borrow::Cow;
|
||||
|
@ -7,7 +8,6 @@ use surrealdb::dbs::Notification;
|
|||
use surrealdb::sql;
|
||||
use surrealdb::sql::Value;
|
||||
use tracing::instrument;
|
||||
use warp::ws::Message;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Output {
|
||||
|
@ -87,19 +87,19 @@ impl Response {
|
|||
let message = match out {
|
||||
Output::Json => {
|
||||
let res = serde_json::to_string(&self.simplify()).unwrap();
|
||||
Message::text(res)
|
||||
Message::Text(res)
|
||||
}
|
||||
Output::Cbor => {
|
||||
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
||||
Message::binary(res)
|
||||
Message::Binary(res)
|
||||
}
|
||||
Output::Pack => {
|
||||
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
||||
Message::binary(res)
|
||||
Message::Binary(res)
|
||||
}
|
||||
Output::Full => {
|
||||
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
||||
Message::binary(res)
|
||||
Message::Binary(res)
|
||||
}
|
||||
};
|
||||
let _ = chn.send(message).await;
|
||||
|
|
112
src/telemetry/metrics/http/mod.rs
Normal file
112
src/telemetry/metrics/http/mod.rs
Normal file
|
@ -0,0 +1,112 @@
|
|||
pub(super) mod tower_layer;
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use opentelemetry::{
|
||||
metrics::{Histogram, Meter, MeterProvider, ObservableUpDownCounter, Unit},
|
||||
runtime,
|
||||
sdk::{
|
||||
export::metrics::aggregation,
|
||||
metrics::{
|
||||
controllers::{self, BasicController},
|
||||
processors, selectors,
|
||||
},
|
||||
},
|
||||
Context,
|
||||
};
|
||||
use opentelemetry_otlp::MetricsExporterBuilder;
|
||||
|
||||
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
|
||||
|
||||
// Histogram buckets in milliseconds
|
||||
static HTTP_DURATION_MS_HISTOGRAM_BUCKETS: &[f64] = &[
|
||||
5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0,
|
||||
2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0,
|
||||
];
|
||||
|
||||
const KB: f64 = 1024.0;
|
||||
const MB: f64 = 1024.0 * KB;
|
||||
|
||||
const HTTP_SIZE_HISTOGRAM_BUCKETS: &[f64] = &[
|
||||
1.0 * KB, // 1 KB
|
||||
2.0 * KB, // 2 KB
|
||||
5.0 * KB, // 5 KB
|
||||
10.0 * KB, // 10 KB
|
||||
100.0 * KB, // 100 KB
|
||||
500.0 * KB, // 500 KB
|
||||
1.0 * MB, // 1 MB
|
||||
2.5 * MB, // 2 MB
|
||||
5.0 * MB, // 5 MB
|
||||
10.0 * MB, // 10 MB
|
||||
25.0 * MB, // 25 MB
|
||||
50.0 * MB, // 50 MB
|
||||
100.0 * MB, // 100 MB
|
||||
];
|
||||
|
||||
static METER_PROVIDER_HTTP_DURATION: Lazy<BasicController> = Lazy::new(|| {
|
||||
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
|
||||
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
|
||||
.unwrap();
|
||||
|
||||
let builder = controllers::basic(processors::factory(
|
||||
selectors::simple::histogram(HTTP_DURATION_MS_HISTOGRAM_BUCKETS),
|
||||
aggregation::cumulative_temporality_selector(),
|
||||
))
|
||||
.with_exporter(exporter)
|
||||
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
|
||||
|
||||
let controller = builder.build();
|
||||
controller.start(&Context::current(), runtime::Tokio).unwrap();
|
||||
controller
|
||||
});
|
||||
|
||||
static METER_PROVIDER_HTTP_SIZE: Lazy<BasicController> = Lazy::new(|| {
|
||||
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
|
||||
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
|
||||
.unwrap();
|
||||
|
||||
let builder = controllers::basic(processors::factory(
|
||||
selectors::simple::histogram(HTTP_SIZE_HISTOGRAM_BUCKETS),
|
||||
aggregation::cumulative_temporality_selector(),
|
||||
))
|
||||
.with_exporter(exporter)
|
||||
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
|
||||
|
||||
let controller = builder.build();
|
||||
controller.start(&Context::current(), runtime::Tokio).unwrap();
|
||||
controller
|
||||
});
|
||||
|
||||
static HTTP_DURATION_METER: Lazy<Meter> =
|
||||
Lazy::new(|| METER_PROVIDER_HTTP_DURATION.meter("http_duration"));
|
||||
static HTTP_SIZE_METER: Lazy<Meter> = Lazy::new(|| METER_PROVIDER_HTTP_SIZE.meter("http_size"));
|
||||
|
||||
pub static HTTP_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||
HTTP_DURATION_METER
|
||||
.u64_histogram("http.server.duration")
|
||||
.with_description("The HTTP server duration in milliseconds.")
|
||||
.with_unit(Unit::new("ms"))
|
||||
.init()
|
||||
});
|
||||
|
||||
pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy<ObservableUpDownCounter<i64>> = Lazy::new(|| {
|
||||
HTTP_DURATION_METER
|
||||
.i64_observable_up_down_counter("http.server.active_requests")
|
||||
.with_description("The number of active HTTP requests.")
|
||||
.init()
|
||||
});
|
||||
|
||||
pub static HTTP_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||
HTTP_SIZE_METER
|
||||
.u64_histogram("http.server.request.size")
|
||||
.with_description("Measures the size of HTTP request messages.")
|
||||
.with_unit(Unit::new("mb"))
|
||||
.init()
|
||||
});
|
||||
|
||||
pub static HTTP_SERVER_RESPONSE_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||
HTTP_SIZE_METER
|
||||
.u64_histogram("http.server.response.size")
|
||||
.with_description("Measures the size of HTTP response messages.")
|
||||
.with_unit(Unit::new("mb"))
|
||||
.init()
|
||||
});
|
310
src/telemetry/metrics/http/tower_layer.rs
Normal file
310
src/telemetry/metrics/http/tower_layer.rs
Normal file
|
@ -0,0 +1,310 @@
|
|||
use axum::extract::MatchedPath;
|
||||
use opentelemetry::{metrics::MetricsError, Context as TelemetryContext, KeyValue};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
cell::Cell,
|
||||
fmt,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use futures::Future;
|
||||
use http::{Request, Response, StatusCode, Version};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use super::{
|
||||
HTTP_DURATION_METER, HTTP_SERVER_ACTIVE_REQUESTS, HTTP_SERVER_DURATION,
|
||||
HTTP_SERVER_REQUEST_SIZE, HTTP_SERVER_RESPONSE_SIZE,
|
||||
};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct HttpMetricsLayer;
|
||||
|
||||
impl<S> Layer<S> for HttpMetricsLayer {
|
||||
type Service = HttpMetrics<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
HttpMetrics {
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HttpMetrics<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HttpMetrics<S>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||
ReqBody: http_body::Body,
|
||||
ResBody: http_body::Body,
|
||||
S::Error: fmt::Display + 'static,
|
||||
{
|
||||
type Response = Response<ResBody>;
|
||||
type Error = S::Error;
|
||||
type Future = HttpCallMetricsFuture<S::Future>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
|
||||
let tracker = HttpCallMetricTracker::new(&request);
|
||||
|
||||
HttpCallMetricsFuture::new(self.inner.call(request), tracker)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct HttpCallMetricsFuture<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
tracker: HttpCallMetricTracker,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> HttpCallMetricsFuture<F> {
|
||||
fn new(inner: F, tracker: HttpCallMetricTracker) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut, ResBody, E> Future for HttpCallMetricsFuture<Fut>
|
||||
where
|
||||
Fut: Future<Output = Result<Response<ResBody>, E>>,
|
||||
ResBody: http_body::Body,
|
||||
E: std::fmt::Display + 'static,
|
||||
{
|
||||
type Output = Result<Response<ResBody>, E>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
this.tracker.set_state(ResultState::Started);
|
||||
|
||||
if let Err(err) = on_request_start(this.tracker) {
|
||||
error!("Failed to setup metrics when request started: {}", err);
|
||||
// Consider this request not tracked: reset the state to None, so that the drop handler does not decrease the counter.
|
||||
this.tracker.set_state(ResultState::None);
|
||||
};
|
||||
|
||||
let response = futures_util::ready!(this.inner.poll(cx));
|
||||
|
||||
let result = match response {
|
||||
Ok(reply) => {
|
||||
this.tracker.set_state(ResultState::Result(
|
||||
reply.status(),
|
||||
reply.version(),
|
||||
reply.body().size_hint().exact(),
|
||||
));
|
||||
Ok(reply)
|
||||
}
|
||||
Err(e) => {
|
||||
this.tracker.set_state(ResultState::Failed);
|
||||
Err(e)
|
||||
}
|
||||
};
|
||||
Poll::Ready(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HttpCallMetricTracker {
|
||||
version: String,
|
||||
method: hyper::Method,
|
||||
scheme: Option<http::uri::Scheme>,
|
||||
host: Option<String>,
|
||||
route: Option<String>,
|
||||
state: Cell<ResultState>,
|
||||
status_code: Option<StatusCode>,
|
||||
request_size: Option<u64>,
|
||||
response_size: Option<u64>,
|
||||
start: Instant,
|
||||
finish: Option<Instant>,
|
||||
}
|
||||
|
||||
pub enum ResultState {
|
||||
/// The result was already processed.
|
||||
None,
|
||||
/// Request was started.
|
||||
Started,
|
||||
/// The result failed with an error.
|
||||
Failed,
|
||||
/// The result is an actual HTTP response.
|
||||
Result(StatusCode, Version, Option<u64>),
|
||||
}
|
||||
|
||||
impl HttpCallMetricTracker {
|
||||
fn new<B>(request: &Request<B>) -> Self
|
||||
where
|
||||
B: http_body::Body,
|
||||
{
|
||||
Self {
|
||||
version: format!("{:?}", request.version()),
|
||||
method: request.method().clone(),
|
||||
scheme: request.uri().scheme().cloned(),
|
||||
host: request.uri().host().map(|s| s.to_string()),
|
||||
route: request.extensions().get::<MatchedPath>().map(|v| v.as_str().to_string()),
|
||||
state: Cell::new(ResultState::None),
|
||||
status_code: None,
|
||||
request_size: request.body().size_hint().exact(),
|
||||
response_size: None,
|
||||
start: Instant::now(),
|
||||
finish: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_state(&self, state: ResultState) {
|
||||
self.state.set(state)
|
||||
}
|
||||
|
||||
pub fn duration(&self) -> Duration {
|
||||
self.finish.unwrap_or(Instant::now()) - self.start
|
||||
}
|
||||
|
||||
// Follows the OpenTelemetry semantic conventions for HTTP metrics define here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md
|
||||
fn olel_common_attrs(&self) -> Vec<KeyValue> {
|
||||
let mut res = vec![
|
||||
KeyValue::new("http.request.method", self.method.as_str().to_owned()),
|
||||
KeyValue::new("network.protocol.name", "http".to_owned()),
|
||||
];
|
||||
|
||||
if let Some(scheme) = &self.scheme {
|
||||
res.push(KeyValue::new("url.scheme", scheme.as_str().to_owned()));
|
||||
}
|
||||
|
||||
if let Some(host) = &self.host {
|
||||
res.push(KeyValue::new("server.address", host.to_owned()));
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub(super) fn active_req_attrs(&self) -> Vec<KeyValue> {
|
||||
self.olel_common_attrs()
|
||||
}
|
||||
|
||||
pub(super) fn request_duration_attrs(&self) -> Vec<KeyValue> {
|
||||
let mut res = self.olel_common_attrs();
|
||||
|
||||
res.push(KeyValue::new(
|
||||
"http.response.status_code",
|
||||
self.status_code.map(|v| v.as_str().to_owned()).unwrap_or("000".to_owned()),
|
||||
));
|
||||
|
||||
if let Some(v) = self.version.strip_prefix("HTTP/") {
|
||||
res.push(KeyValue::new("network.protocol.version", v.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(target) = &self.route {
|
||||
res.push(KeyValue::new("http.route", target.to_owned()));
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub(super) fn request_size_attrs(&self) -> Vec<KeyValue> {
|
||||
self.request_duration_attrs()
|
||||
}
|
||||
|
||||
pub(super) fn response_size_attrs(&self) -> Vec<KeyValue> {
|
||||
self.request_duration_attrs()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HttpCallMetricTracker {
|
||||
fn drop(&mut self) {
|
||||
match self.state.replace(ResultState::None) {
|
||||
ResultState::None => {
|
||||
// Request was not tracked, so no need to decrease the counter.
|
||||
return;
|
||||
}
|
||||
ResultState::Started => {
|
||||
// If the response was never processed, we can't get a valid status code
|
||||
}
|
||||
ResultState::Failed => {
|
||||
// If there's an error processing the request and we don't have a response, we can't get a valid status code
|
||||
}
|
||||
ResultState::Result(s, v, size) => {
|
||||
self.status_code = Some(s);
|
||||
self.version = format!("{:?}", v);
|
||||
self.response_size = size;
|
||||
}
|
||||
};
|
||||
|
||||
self.finish = Some(Instant::now());
|
||||
|
||||
if let Err(err) = on_request_finish(self) {
|
||||
error!(target: "surrealdb::telemetry", "Failed to setup metrics when request finished: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||
// Setup the active_requests observer
|
||||
observe_active_request_start(tracker)
|
||||
}
|
||||
|
||||
pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||
// Setup the active_requests observer
|
||||
observe_active_request_finish(tracker)?;
|
||||
|
||||
// Record the duration of the request.
|
||||
record_request_duration(tracker);
|
||||
|
||||
// Record the request size if known
|
||||
if let Some(size) = tracker.request_size {
|
||||
record_request_size(tracker, size)
|
||||
}
|
||||
|
||||
// Record the response size if known
|
||||
if let Some(size) = tracker.response_size {
|
||||
record_response_size(tracker, size)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn observe_active_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||
let attrs = tracker.active_req_attrs();
|
||||
// Setup the callback to observe the active requests.
|
||||
HTTP_DURATION_METER
|
||||
.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, 1, &attrs))
|
||||
}
|
||||
|
||||
fn observe_active_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||
let attrs = tracker.active_req_attrs();
|
||||
// Setup the callback to observe the active requests.
|
||||
HTTP_DURATION_METER
|
||||
.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, -1, &attrs))
|
||||
}
|
||||
|
||||
fn record_request_duration(tracker: &HttpCallMetricTracker) {
|
||||
// Record the duration of the request.
|
||||
HTTP_SERVER_DURATION.record(
|
||||
&TelemetryContext::current(),
|
||||
tracker.duration().as_millis() as u64,
|
||||
&tracker.request_duration_attrs(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) {
|
||||
HTTP_SERVER_REQUEST_SIZE.record(
|
||||
&TelemetryContext::current(),
|
||||
size,
|
||||
&tracker.request_size_attrs(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) {
|
||||
HTTP_SERVER_RESPONSE_SIZE.record(
|
||||
&TelemetryContext::current(),
|
||||
size,
|
||||
&tracker.response_size_attrs(),
|
||||
);
|
||||
}
|
3
src/telemetry/metrics/mod.rs
Normal file
3
src/telemetry/metrics/mod.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub mod http;
|
||||
|
||||
pub use self::http::tower_layer::HttpMetricsLayer;
|
|
@ -1,7 +1,16 @@
|
|||
mod logger;
|
||||
mod tracers;
|
||||
mod logs;
|
||||
pub mod metrics;
|
||||
mod traces;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
||||
use once_cell::sync::Lazy;
|
||||
use opentelemetry::sdk::resource::{
|
||||
EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector,
|
||||
};
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::KeyValue;
|
||||
use tracing::Subscriber;
|
||||
use tracing_subscriber::fmt::format::FmtSpan;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
@ -9,6 +18,28 @@ use tracing_subscriber::util::SubscriberInitExt;
|
|||
#[cfg(feature = "has-storage")]
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
|
||||
let res = Resource::from_detectors(
|
||||
Duration::from_secs(5),
|
||||
vec![
|
||||
// set service.name from env OTEL_SERVICE_NAME > env OTEL_RESOURCE_ATTRIBUTES > option_env! CARGO_BIN_NAME > unknown_service
|
||||
Box::new(SdkProvidedResourceDetector),
|
||||
// detect res from env OTEL_RESOURCE_ATTRIBUTES (resources string like key1=value1,key2=value2,...)
|
||||
Box::new(EnvResourceDetector::new()),
|
||||
// set telemetry.sdk.{name, language, version}
|
||||
Box::new(TelemetryResourceDetector),
|
||||
],
|
||||
);
|
||||
|
||||
// If no external service.name is set, set it to surrealdb
|
||||
if res.get("service.name".into()).unwrap_or("".into()).as_str() == "unknown_service" {
|
||||
debug!("No service.name detected, use 'surrealdb'");
|
||||
res.merge(&Resource::new([KeyValue::new("service.name", "surrealdb")]))
|
||||
} else {
|
||||
res
|
||||
}
|
||||
});
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Builder {
|
||||
log_level: Option<String>,
|
||||
|
@ -32,7 +63,8 @@ impl Builder {
|
|||
self.filter = Some(CustomEnvFilter(filter));
|
||||
self
|
||||
}
|
||||
/// Build a dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
|
||||
|
||||
/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
|
||||
pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
|
||||
let registry = tracing_subscriber::registry();
|
||||
let registry = registry.with(self.filter.map(|filter| {
|
||||
|
@ -44,11 +76,12 @@ impl Builder {
|
|||
.with_filter(filter.0)
|
||||
.boxed()
|
||||
}));
|
||||
let registry = registry.with(self.log_level.map(logger::new));
|
||||
let registry = registry.with(tracers::new());
|
||||
let registry = registry.with(self.log_level.map(logs::new));
|
||||
let registry = registry.with(traces::new());
|
||||
Box::new(registry)
|
||||
}
|
||||
/// Build a dispatcher and set it as global
|
||||
|
||||
/// tracing pipeline
|
||||
pub fn init(self) {
|
||||
self.build().init()
|
||||
}
|
||||
|
@ -60,10 +93,12 @@ mod tests {
|
|||
use tracing::{span, Level};
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
use crate::telemetry;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_otlp_tracer() {
|
||||
println!("Starting mock otlp server...");
|
||||
let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await;
|
||||
let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await;
|
||||
|
||||
{
|
||||
let otlp_endpoint = format!("http://{}", addr);
|
||||
|
@ -73,7 +108,7 @@ mod tests {
|
|||
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
|
||||
],
|
||||
|| {
|
||||
let _enter = super::builder().build().set_default();
|
||||
let _enter = telemetry::builder().build().set_default();
|
||||
|
||||
println!("Sending span...");
|
||||
|
||||
|
@ -90,16 +125,8 @@ mod tests {
|
|||
|
||||
println!("Waiting for request...");
|
||||
let req = req_rx.recv().await.expect("missing export request");
|
||||
let first_span = req
|
||||
.resource_spans
|
||||
.first()
|
||||
.unwrap()
|
||||
.instrumentation_library_spans
|
||||
.first()
|
||||
.unwrap()
|
||||
.spans
|
||||
.first()
|
||||
.unwrap();
|
||||
let first_span =
|
||||
req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
|
||||
assert_eq!("test-surreal-span", first_span.name);
|
||||
let first_event = first_span.events.first().unwrap();
|
||||
assert_eq!("test-surreal-event", first_event.name);
|
||||
|
@ -108,7 +135,7 @@ mod tests {
|
|||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_tracing_filter() {
|
||||
println!("Starting mock otlp server...");
|
||||
let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await;
|
||||
let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await;
|
||||
|
||||
{
|
||||
let otlp_endpoint = format!("http://{}", addr);
|
||||
|
@ -119,7 +146,7 @@ mod tests {
|
|||
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
|
||||
],
|
||||
|| {
|
||||
let _enter = super::builder().build().set_default();
|
||||
let _enter = telemetry::builder().build().set_default();
|
||||
|
||||
println!("Sending spans...");
|
||||
|
||||
|
@ -144,14 +171,7 @@ mod tests {
|
|||
|
||||
println!("Waiting for request...");
|
||||
let req = req_rx.recv().await.expect("missing export request");
|
||||
let spans = &req
|
||||
.resource_spans
|
||||
.first()
|
||||
.unwrap()
|
||||
.instrumentation_library_spans
|
||||
.first()
|
||||
.unwrap()
|
||||
.spans;
|
||||
let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;
|
||||
|
||||
assert_eq!(1, spans.len());
|
||||
assert_eq!("debug", spans.first().unwrap().name);
|
|
@ -59,7 +59,9 @@ pub mod tests {
|
|||
request: tonic::Request<ExportTraceServiceRequest>,
|
||||
) -> Result<tonic::Response<ExportTraceServiceResponse>, tonic::Status> {
|
||||
self.tx.lock().unwrap().try_send(request.into_inner()).expect("Channel full");
|
||||
Ok(tonic::Response::new(ExportTraceServiceResponse {}))
|
||||
Ok(tonic::Response::new(ExportTraceServiceResponse {
|
||||
partial_success: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
use opentelemetry::sdk::{trace::Tracer, Resource};
|
||||
use opentelemetry::sdk::trace::Tracer;
|
||||
use opentelemetry::trace::TraceError;
|
||||
use opentelemetry::KeyValue;
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tracing::{Level, Subscriber};
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
|
||||
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
|
||||
|
||||
const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER";
|
||||
|
||||
pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
|
||||
|
@ -15,12 +16,12 @@ where
|
|||
}
|
||||
|
||||
fn tracer() -> Result<Tracer, TraceError> {
|
||||
let resource = Resource::new(vec![KeyValue::new("service.name", "surrealdb")]);
|
||||
|
||||
opentelemetry_otlp::new_pipeline()
|
||||
.tracing()
|
||||
.with_exporter(opentelemetry_otlp::new_exporter().tonic().with_env())
|
||||
.with_trace_config(opentelemetry::sdk::trace::config().with_resource(resource))
|
||||
.with_trace_config(
|
||||
opentelemetry::sdk::trace::config().with_resource(OTEL_DEFAULT_RESOURCE.clone()),
|
||||
)
|
||||
.install_batch(opentelemetry::runtime::Tokio)
|
||||
}
|
||||
|
354
tests/cli.rs
354
tests/cli.rs
|
@ -1,354 +0,0 @@
|
|||
mod cli_integration {
|
||||
// cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture
|
||||
|
||||
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serial_test::serial;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
|
||||
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
||||
struct Child {
|
||||
inner: Option<std::process::Child>,
|
||||
}
|
||||
|
||||
impl Child {
|
||||
/// Send some thing to the child's stdin
|
||||
fn input(mut self, input: &str) -> Self {
|
||||
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
||||
use std::io::Write;
|
||||
stdin.write_all(input.as_bytes()).unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
fn kill(mut self) -> Self {
|
||||
self.inner.as_mut().unwrap().kill().unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||
/// returns successfully, Err otherwise.
|
||||
fn output(mut self) -> Result<String, String> {
|
||||
let output = self.inner.take().unwrap().wait_with_output().unwrap();
|
||||
|
||||
let mut buf = String::from_utf8(output.stdout).unwrap();
|
||||
buf.push_str(&String::from_utf8(output.stderr).unwrap());
|
||||
|
||||
if output.status.success() {
|
||||
Ok(buf)
|
||||
} else {
|
||||
Err(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Child {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = self.inner.as_mut() {
|
||||
let _ = inner.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||
let mut path = std::env::current_exe().unwrap();
|
||||
assert!(path.pop());
|
||||
if path.ends_with("deps") {
|
||||
assert!(path.pop());
|
||||
}
|
||||
|
||||
// Note: Cargo automatically builds this binary for integration tests.
|
||||
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
||||
|
||||
let mut cmd = Command::new(path);
|
||||
if let Some(dir) = current_dir {
|
||||
cmd.current_dir(&dir);
|
||||
}
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
cmd.args(args.split_ascii_whitespace());
|
||||
Child {
|
||||
inner: Some(cmd.spawn().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args
|
||||
fn run(args: &str) -> Child {
|
||||
run_internal::<String>(args, None)
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args inside a temporary directory
|
||||
fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||
run_internal(args, Some(current_dir))
|
||||
}
|
||||
|
||||
fn tmp_file(name: &str) -> String {
|
||||
let path = Path::new(env!("OUT_DIR")).join(name);
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn version() {
|
||||
assert!(run("version").output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn help() {
|
||||
assert!(run("help").output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn nonexistent_subcommand() {
|
||||
assert!(run("nonexistent").output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn nonexistent_option() {
|
||||
assert!(run("version --turbo").output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn start() {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let pass = rng.gen::<u64>().to_string();
|
||||
|
||||
let start_args =
|
||||
format!("start --bind {addr} --user root --pass {pass} memory --no-banner --log info");
|
||||
|
||||
println!("starting server with args: {start_args}");
|
||||
|
||||
let _server = run(&start_args);
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_millis(5000));
|
||||
|
||||
assert!(run(&format!("isready --conn http://{addr}")).output().is_ok());
|
||||
|
||||
// Create a record
|
||||
{
|
||||
let args =
|
||||
format!("sql --conn http://{addr} --user root --pass {pass} --ns N --db D --multi");
|
||||
assert_eq!(
|
||||
run(&args).input("CREATE thing:one;\n").output(),
|
||||
Ok("[{ id: thing:one }]\n\n".to_owned()),
|
||||
"failed to send sql: {args}"
|
||||
);
|
||||
}
|
||||
|
||||
// Export to stdout
|
||||
{
|
||||
let args =
|
||||
format!("export --conn http://{addr} --user root --pass {pass} --ns N --db D -");
|
||||
let output = run(&args).output().expect("failed to run stdout export: {args}");
|
||||
assert!(output.contains("DEFINE TABLE thing SCHEMALESS PERMISSIONS NONE;"));
|
||||
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
|
||||
}
|
||||
|
||||
// Export to file
|
||||
let exported = {
|
||||
let exported = tmp_file("exported.surql");
|
||||
let args = format!(
|
||||
"export --conn http://{addr} --user root --pass {pass} --ns N --db D {exported}"
|
||||
);
|
||||
run(&args).output().expect("failed to run file export: {args}");
|
||||
exported
|
||||
};
|
||||
|
||||
// Import the exported file
|
||||
{
|
||||
let args = format!(
|
||||
"import --conn http://{addr} --user root --pass {pass} --ns N --db D2 {exported}"
|
||||
);
|
||||
run(&args).output().expect("failed to run import: {args}");
|
||||
}
|
||||
|
||||
// Query from the import (pretty-printed this time)
|
||||
{
|
||||
let args = format!(
|
||||
"sql --conn http://{addr} --user root --pass {pass} --ns N --db D2 --pretty"
|
||||
);
|
||||
assert_eq!(
|
||||
run(&args).input("SELECT * FROM thing;\n").output(),
|
||||
Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()),
|
||||
"failed to send sql: {args}"
|
||||
);
|
||||
}
|
||||
|
||||
// Unfinished backup CLI
|
||||
{
|
||||
let file = tmp_file("backup.db");
|
||||
let args = format!("backup --user root --pass {pass} http://{addr} {file}");
|
||||
run(&args).output().expect("failed to run backup: {args}");
|
||||
|
||||
// TODO: Once backups are functional, update this test.
|
||||
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) query including error(s) over WS
|
||||
{
|
||||
let args = format!(
|
||||
"sql --conn ws://{addr} --user root --pass {pass} --ns N3 --db D3 --multi --pretty"
|
||||
);
|
||||
let output = run(&args)
|
||||
.input(
|
||||
r#"CREATE thing:success; \
|
||||
CREATE thing:fail SET bad=rand('evil'); \
|
||||
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
|
||||
CREATE thing:also_success;
|
||||
"#,
|
||||
)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
assert!(output.contains("thing:success"), "missing success in {output}");
|
||||
assert!(output.contains("rgument"), "missing argument error in {output}");
|
||||
assert!(
|
||||
output.contains("time") && output.contains("out"),
|
||||
"missing timeout error in {output}"
|
||||
);
|
||||
assert!(output.contains("thing:also_success"), "missing also_success in {output}")
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) transaction including error(s) over WS
|
||||
{
|
||||
let args = format!(
|
||||
"sql --conn ws://{addr} --user root --pass {pass} --ns N4 --db D4 --multi --pretty"
|
||||
);
|
||||
let output = run(&args)
|
||||
.input(
|
||||
r#"BEGIN; \
|
||||
CREATE thing:success; \
|
||||
CREATE thing:fail SET bad=rand('evil'); \
|
||||
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
|
||||
CREATE thing:also_success; \
|
||||
COMMIT;
|
||||
"#,
|
||||
)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
output.lines().filter(|s| s.contains("transaction")).count(),
|
||||
3,
|
||||
"missing failed txn errors in {output:?}"
|
||||
);
|
||||
assert!(output.contains("rgument"), "missing argument error in {output}");
|
||||
}
|
||||
|
||||
// Pass neither ns nor db
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} --user root --pass {pass}");
|
||||
let output = run(&args)
|
||||
.input("USE NS N5 DB D5; CREATE thing:one;\n")
|
||||
.output()
|
||||
.expect("neither ns nor db");
|
||||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only ns
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} --user root --pass {pass} --ns N5");
|
||||
let output = run(&args)
|
||||
.input("USE DB D5; SELECT * FROM thing:one;\n")
|
||||
.output()
|
||||
.expect("only ns");
|
||||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only db and expect an error
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} --user root --pass {pass} --db D5");
|
||||
run(&args).output().expect_err("only db");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn start_tls() {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let pass = rng.gen::<u128>().to_string();
|
||||
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
let crt_path = tmp_file("crt.crt");
|
||||
let key_path = tmp_file("key.pem");
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
||||
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
||||
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
||||
|
||||
let start_args = format!(
|
||||
"start --bind {addr} --user root --pass {pass} memory --log info --web-crt {crt_path} --web-key {key_path}"
|
||||
);
|
||||
|
||||
println!("starting server with args: {start_args}");
|
||||
|
||||
let server = run(&start_args);
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_millis(750));
|
||||
|
||||
let output = server.kill().output().unwrap_err();
|
||||
assert!(output.contains("Started web server"), "couldn't start web server: {output}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_found_no_files() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
temp_dir.child("file.txt").touch().unwrap();
|
||||
|
||||
assert!(run_in_dir("validate", &temp_dir).output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_succeed_for_valid_surql_files() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
let statement_file = temp_dir.child("statement.surql");
|
||||
|
||||
statement_file.touch().unwrap();
|
||||
statement_file.write_str("CREATE thing:success;").unwrap();
|
||||
|
||||
assert!(run_in_dir("validate", &temp_dir).output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_failed_due_to_invalid_glob_pattern() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
const WRONG_GLOB_PATTERN: &str = "**/*{.txt";
|
||||
|
||||
let args = format!("validate \"{}\"", WRONG_GLOB_PATTERN);
|
||||
|
||||
assert!(run_in_dir(&args, &temp_dir).output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_failed_due_to_invalid_surql_files_syntax() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
let statement_file = temp_dir.child("statement.surql");
|
||||
|
||||
statement_file.touch().unwrap();
|
||||
statement_file.write_str("CREATE $thing WHERE value = '';").unwrap();
|
||||
|
||||
assert!(run_in_dir("validate", &temp_dir).output().is_err());
|
||||
}
|
||||
}
|
273
tests/cli_integration.rs
Normal file
273
tests/cli_integration.rs
Normal file
|
@ -0,0 +1,273 @@
|
|||
// cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture
|
||||
|
||||
mod common;
|
||||
|
||||
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
||||
use serial_test::serial;
|
||||
use std::fs;
|
||||
|
||||
use common::{PASS, USER};
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn version() {
|
||||
assert!(common::run("version").output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn help() {
|
||||
assert!(common::run("help").output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn nonexistent_subcommand() {
|
||||
assert!(common::run("nonexistent").output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn nonexistent_option() {
|
||||
assert!(common::run("version --turbo").output().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn all_commands() {
|
||||
let (addr, _server) = common::start_server(false, true).await.unwrap();
|
||||
let creds = format!("--user {USER} --pass {PASS}");
|
||||
// Create a record
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
|
||||
assert_eq!(
|
||||
common::run(&args).input("CREATE thing:one;\n").output(),
|
||||
Ok("[{ id: thing:one }]\n\n".to_owned()),
|
||||
"failed to send sql: {args}"
|
||||
);
|
||||
}
|
||||
|
||||
// Export to stdout
|
||||
{
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
|
||||
let output = common::run(&args).output().expect("failed to run stdout export: {args}");
|
||||
assert!(output.contains("DEFINE TABLE thing SCHEMALESS PERMISSIONS NONE;"));
|
||||
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
|
||||
}
|
||||
|
||||
// Export to file
|
||||
let exported = {
|
||||
let exported = common::tmp_file("exported.surql");
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
|
||||
common::run(&args).output().expect("failed to run file export: {args}");
|
||||
exported
|
||||
};
|
||||
|
||||
// Import the exported file
|
||||
{
|
||||
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
|
||||
common::run(&args).output().expect("failed to run import: {args}");
|
||||
}
|
||||
|
||||
// Query from the import (pretty-printed this time)
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
|
||||
assert_eq!(
|
||||
common::run(&args).input("SELECT * FROM thing;\n").output(),
|
||||
Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()),
|
||||
"failed to send sql: {args}"
|
||||
);
|
||||
}
|
||||
|
||||
// Unfinished backup CLI
|
||||
{
|
||||
let file = common::tmp_file("backup.db");
|
||||
let args = format!("backup {creds} http://{addr} {file}");
|
||||
common::run(&args).output().expect("failed to run backup: {args}");
|
||||
|
||||
// TODO: Once backups are functional, update this test.
|
||||
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) query including error(s) over WS
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
|
||||
let output = common::run(&args)
|
||||
.input(
|
||||
r#"CREATE thing:success; \
|
||||
CREATE thing:fail SET bad=rand('evil'); \
|
||||
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
|
||||
CREATE thing:also_success;
|
||||
"#,
|
||||
)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
assert!(output.contains("thing:success"), "missing success in {output}");
|
||||
assert!(output.contains("rgument"), "missing argument error in {output}");
|
||||
assert!(
|
||||
output.contains("time") && output.contains("out"),
|
||||
"missing timeout error in {output}"
|
||||
);
|
||||
assert!(output.contains("thing:also_success"), "missing also_success in {output}")
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) transaction including error(s) over WS
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
|
||||
let output = common::run(&args)
|
||||
.input(
|
||||
r#"BEGIN; \
|
||||
CREATE thing:success; \
|
||||
CREATE thing:fail SET bad=rand('evil'); \
|
||||
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
|
||||
CREATE thing:also_success; \
|
||||
COMMIT;
|
||||
"#,
|
||||
)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
output.lines().filter(|s| s.contains("transaction")).count(),
|
||||
3,
|
||||
"missing failed txn errors in {output:?}"
|
||||
);
|
||||
assert!(output.contains("rgument"), "missing argument error in {output}");
|
||||
}
|
||||
|
||||
// Pass neither ns nor db
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds}");
|
||||
let output = common::run(&args)
|
||||
.input("USE NS N5 DB D5; CREATE thing:one;\n")
|
||||
.output()
|
||||
.expect("neither ns nor db");
|
||||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only ns
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N5");
|
||||
let output = common::run(&args)
|
||||
.input("USE DB D5; SELECT * FROM thing:one;\n")
|
||||
.output()
|
||||
.expect("only ns");
|
||||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only db and expect an error
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --db D5");
|
||||
common::run(&args).output().expect_err("only db");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn start_tls() {
|
||||
let (_, server) = common::start_server(true, false).await.unwrap();
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_millis(2000));
|
||||
let output = server.kill().output().err().unwrap();
|
||||
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
assert!(output.contains("Started web server"), "couldn't start web server: {output}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn with_root_auth() {
|
||||
let (addr, _server) = common::start_server(false, true).await.unwrap();
|
||||
let creds = format!("--user {USER} --pass {PASS}");
|
||||
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
|
||||
|
||||
// Can query /sql over HTTP
|
||||
{
|
||||
let args = format!("{sql_args} {creds}");
|
||||
let input = "INFO FOR ROOT;";
|
||||
let output = common::run(&args).input(input).output();
|
||||
assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
|
||||
}
|
||||
|
||||
// Can query /sql over WS
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
|
||||
let input = "INFO FOR ROOT;";
|
||||
let output = common::run(&args).input(input).output();
|
||||
assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
|
||||
}
|
||||
|
||||
// KV user can do exports
|
||||
let exported = {
|
||||
let exported = common::tmp_file("exported.surql");
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
|
||||
|
||||
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run export: {args}"));
|
||||
exported
|
||||
};
|
||||
|
||||
// KV user can do imports
|
||||
{
|
||||
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
|
||||
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
|
||||
}
|
||||
|
||||
// KV user can do backups
|
||||
{
|
||||
let file = common::tmp_file("backup.db");
|
||||
let args = format!("backup {creds} http://{addr} {file}");
|
||||
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run backup: {args}"));
|
||||
|
||||
// TODO: Once backups are functional, update this test.
|
||||
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_found_no_files() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
temp_dir.child("file.txt").touch().unwrap();
|
||||
|
||||
assert!(common::run_in_dir("validate", &temp_dir).output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_succeed_for_valid_surql_files() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
let statement_file = temp_dir.child("statement.surql");
|
||||
|
||||
statement_file.touch().unwrap();
|
||||
statement_file.write_str("CREATE thing:success;").unwrap();
|
||||
|
||||
assert!(common::run_in_dir("validate", &temp_dir).output().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_failed_due_to_invalid_glob_pattern() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
const WRONG_GLOB_PATTERN: &str = "**/*{.txt";
|
||||
|
||||
let args = format!("validate \"{}\"", WRONG_GLOB_PATTERN);
|
||||
|
||||
assert!(common::run_in_dir(&args, &temp_dir).output().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn validate_failed_due_to_invalid_surql_files_syntax() {
|
||||
let temp_dir = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
let statement_file = temp_dir.child("statement.surql");
|
||||
|
||||
statement_file.touch().unwrap();
|
||||
statement_file.write_str("CREATE $thing WHERE value = '';").unwrap();
|
||||
|
||||
assert!(common::run_in_dir("validate", &temp_dir).output().is_err());
|
||||
}
|
140
tests/common/mod.rs
Normal file
140
tests/common/mod.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
#![allow(dead_code)]
|
||||
use rand::{thread_rng, Rng};
|
||||
use std::error::Error;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use tokio::time;
|
||||
|
||||
pub const USER: &str = "root";
|
||||
pub const PASS: &str = "root";
|
||||
|
||||
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
||||
pub struct Child {
|
||||
inner: Option<std::process::Child>,
|
||||
}
|
||||
|
||||
impl Child {
|
||||
/// Send some thing to the child's stdin
|
||||
pub fn input(mut self, input: &str) -> Self {
|
||||
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
||||
use std::io::Write;
|
||||
stdin.write_all(input.as_bytes()).unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn kill(mut self) -> Self {
|
||||
self.inner.as_mut().unwrap().kill().unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||
/// returns successfully, Err otherwise.
|
||||
pub fn output(mut self) -> Result<String, String> {
|
||||
let output = self.inner.take().unwrap().wait_with_output().unwrap();
|
||||
|
||||
let mut buf = String::from_utf8(output.stdout).unwrap();
|
||||
buf.push_str(&String::from_utf8(output.stderr).unwrap());
|
||||
|
||||
if output.status.success() {
|
||||
Ok(buf)
|
||||
} else {
|
||||
Err(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Child {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = self.inner.as_mut() {
|
||||
let _ = inner.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||
let mut path = std::env::current_exe().unwrap();
|
||||
assert!(path.pop());
|
||||
if path.ends_with("deps") {
|
||||
assert!(path.pop());
|
||||
}
|
||||
|
||||
// Note: Cargo automatically builds this binary for integration tests.
|
||||
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
||||
|
||||
let mut cmd = Command::new(path);
|
||||
if let Some(dir) = current_dir {
|
||||
cmd.current_dir(&dir);
|
||||
}
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
cmd.args(args.split_ascii_whitespace());
|
||||
Child {
|
||||
inner: Some(cmd.spawn().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args
|
||||
pub fn run(args: &str) -> Child {
|
||||
run_internal::<String>(args, None)
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args inside a temporary directory
|
||||
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||
run_internal(args, Some(current_dir))
|
||||
}
|
||||
|
||||
pub fn tmp_file(name: &str) -> String {
|
||||
let path = Path::new(env!("OUT_DIR")).join(name);
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
pub async fn start_server(
|
||||
tls: bool,
|
||||
wait_is_ready: bool,
|
||||
) -> Result<(String, Child), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let mut extra_args = String::default();
|
||||
if tls {
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
let crt_path = tmp_file("crt.crt");
|
||||
let key_path = tmp_file("key.pem");
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
||||
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
||||
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
||||
|
||||
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
|
||||
}
|
||||
|
||||
let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}");
|
||||
|
||||
println!("starting server with args: {start_args}");
|
||||
|
||||
let server = run(&start_args);
|
||||
|
||||
if !wait_is_ready {
|
||||
return Ok((addr, server));
|
||||
}
|
||||
|
||||
// Wait 5 seconds for the server to start
|
||||
let mut interval = time::interval(time::Duration::from_millis(500));
|
||||
println!("Waiting for server to start...");
|
||||
for _i in 0..10 {
|
||||
interval.tick().await;
|
||||
|
||||
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
||||
println!("Server ready!");
|
||||
return Ok((addr, server));
|
||||
}
|
||||
}
|
||||
|
||||
let server_out = server.kill().output().err().unwrap();
|
||||
println!("server output: {server_out}");
|
||||
Err("server failed to start".into())
|
||||
}
|
1302
tests/http_integration.rs
Normal file
1302
tests/http_integration.rs
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue