[metrics] HTTP Layer + move to Axum (#2237)

This commit is contained in:
Salvador Girones Gil 2023-07-19 16:35:56 +02:00 committed by GitHub
parent eef9b755cb
commit 53702c247a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
60 changed files with 3873 additions and 1430 deletions

View file

@ -118,7 +118,29 @@ jobs:
sudo apt-get -y install protobuf-compiler libprotobuf-dev sudo apt-get -y install protobuf-compiler libprotobuf-dev
- name: Run cargo test - name: Run cargo test
run: cargo test --locked --no-default-features --features storage-mem --workspace --test cli run: cargo test --locked --no-default-features --features storage-mem --workspace --test cli_integration
http-server:
name: Test HTTP server
runs-on: ubuntu-20.04
steps:
- name: Install stable toolchain
uses: dtolnay/rust-toolchain@stable
- name: Checkout sources
uses: actions/checkout@v3
- name: Setup cache
uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt-get -y update
sudo apt-get -y install protobuf-compiler libprotobuf-dev
- name: Run cargo test
run: cargo test --locked --no-default-features --features storage-mem --workspace --test http_integration
test: test:
name: Test workspace name: Test workspace

301
Cargo.lock generated
View file

@ -405,20 +405,6 @@ dependencies = [
"futures-core", "futures-core",
] ]
[[package]]
name = "async-compression"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a"
dependencies = [
"brotli",
"flate2",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "async-executor" name = "async-executor"
version = "1.5.1" version = "1.5.1"
@ -543,9 +529,11 @@ checksum = "a6a1de45611fdb535bfde7b7de4fd54f4fd2b17b1737c0a59b69bf9b92074b8c"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"base64 0.21.2",
"bitflags 1.3.2", "bitflags 1.3.2",
"bytes", "bytes",
"futures-util", "futures-util",
"headers",
"http", "http",
"http-body", "http-body",
"hyper", "hyper",
@ -560,11 +548,25 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite 0.19.0",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
]
[[package]]
name = "axum-client-ip"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8e81eacc93f36480825da5f46a33b5fb2246ed024eacc9e8933425b80c5807"
dependencies = [
"axum",
"forwarded-header-value",
"serde",
] ]
[[package]] [[package]]
@ -582,6 +584,7 @@ dependencies = [
"rustversion", "rustversion",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
@ -595,6 +598,64 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "axum-extra"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "febf23ab04509bd7672e6abe76bd8277af31b679e89fa5ffc6087dc289a448a3"
dependencies = [
"axum",
"axum-core",
"axum-macros",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http-body",
"mime",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_html_form",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-macros"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bb524613be645939e280b7279f7b017f98cf7f5ef084ec374df373530e73277"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.26",
]
[[package]]
name = "axum-server"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
dependencies = [
"arc-swap",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"rustls 0.21.5",
"rustls-pemfile",
"tokio",
"tokio-rustls 0.24.1",
"tower-service",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.68" version = "0.3.68"
@ -616,6 +677,12 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.21.2" version = "0.21.2"
@ -1349,6 +1416,12 @@ dependencies = [
"parking_lot_core 0.9.8", "parking_lot_core 0.9.8",
] ]
[[package]]
name = "data-encoding"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308"
[[package]] [[package]]
name = "debugid" name = "debugid"
version = "0.8.0" version = "0.8.0"
@ -1654,6 +1727,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror",
]
[[package]] [[package]]
name = "foundationdb" name = "foundationdb"
version = "0.8.0" version = "0.8.0"
@ -2170,6 +2253,12 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]] [[package]]
name = "httparse" name = "httparse"
version = "1.8.0" version = "1.8.0"
@ -2741,24 +2830,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.8",
"version_check",
]
[[package]] [[package]]
name = "multimap" name = "multimap"
version = "0.8.3" version = "0.8.3"
@ -2844,6 +2915,12 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.46.0" version = "0.46.0"
@ -2982,9 +3059,9 @@ dependencies = [
[[package]] [[package]]
name = "opentelemetry" name = "opentelemetry"
version = "0.18.0" version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f"
dependencies = [ dependencies = [
"opentelemetry_api", "opentelemetry_api",
"opentelemetry_sdk", "opentelemetry_sdk",
@ -2992,9 +3069,9 @@ dependencies = [
[[package]] [[package]]
name = "opentelemetry-otlp" name = "opentelemetry-otlp"
version = "0.11.0" version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde" checksum = "8af72d59a4484654ea8eb183fea5ae4eb6a41d7ac3e3bae5f4d2a282a3a7d3ca"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"futures 0.3.28", "futures 0.3.28",
@ -3010,39 +3087,38 @@ dependencies = [
[[package]] [[package]]
name = "opentelemetry-proto" name = "opentelemetry-proto"
version = "0.1.0" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28" checksum = "045f8eea8c0fa19f7d48e7bc3128a39c2e5c533d5c61298c548dfefc1064474c"
dependencies = [ dependencies = [
"futures 0.3.28", "futures 0.3.28",
"futures-util", "futures-util",
"opentelemetry", "opentelemetry",
"prost 0.11.9", "prost 0.11.9",
"tonic", "tonic",
"tonic-build",
] ]
[[package]] [[package]]
name = "opentelemetry_api" name = "opentelemetry_api"
version = "0.18.0" version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2"
dependencies = [ dependencies = [
"fnv", "fnv",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"indexmap 1.9.3", "indexmap 1.9.3",
"js-sys",
"once_cell", "once_cell",
"pin-project-lite", "pin-project-lite",
"thiserror", "thiserror",
"urlencoding",
] ]
[[package]] [[package]]
name = "opentelemetry_sdk" name = "opentelemetry_sdk"
version = "0.18.0" version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"crossbeam-channel", "crossbeam-channel",
@ -4241,12 +4317,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.1.0" version = "1.1.0"
@ -4358,6 +4428,19 @@ dependencies = [
"syn 2.0.26", "syn 2.0.26",
] ]
[[package]]
name = "serde_html_form"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7d9b8af723e4199801a643c622daa7aae8cf4a772dc2b3efcb3a95add6cb91a"
dependencies = [
"form_urlencoded",
"indexmap 2.0.0",
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.103" version = "1.0.103"
@ -4607,12 +4690,18 @@ version = "1.0.0-beta.9"
dependencies = [ dependencies = [
"argon2", "argon2",
"assert_fs", "assert_fs",
"axum",
"axum-client-ip",
"axum-extra",
"axum-server",
"base64 0.21.2", "base64 0.21.2",
"bytes", "bytes",
"clap 4.3.12", "clap 4.3.12",
"futures 0.3.28", "futures 0.3.28",
"futures-util",
"glob", "glob",
"http", "http",
"http-body",
"hyper", "hyper",
"ipnet", "ipnet",
"nix", "nix",
@ -4620,6 +4709,7 @@ dependencies = [
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"opentelemetry-proto", "opentelemetry-proto",
"pin-project-lite",
"rand 0.8.5", "rand 0.8.5",
"rcgen", "rcgen",
"reqwest", "reqwest",
@ -4637,12 +4727,13 @@ dependencies = [
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tonic", "tonic",
"tower",
"tower-http",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
"tracing-subscriber", "tracing-subscriber",
"urlencoding", "urlencoding",
"uuid", "uuid",
"warp",
] ]
[[package]] [[package]]
@ -4715,7 +4806,7 @@ dependencies = [
"thiserror", "thiserror",
"time 0.3.23", "time 0.3.23",
"tokio", "tokio",
"tokio-tungstenite", "tokio-tungstenite 0.18.0",
"tokio-util", "tokio-util",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -5141,11 +5232,23 @@ dependencies = [
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-rustls 0.23.4", "tokio-rustls 0.23.4",
"tungstenite", "tungstenite 0.18.0",
"webpki", "webpki",
"webpki-roots", "webpki-roots",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.19.0",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.8" version = "0.7.8"
@ -5219,19 +5322,6 @@ dependencies = [
"tracing-futures", "tracing-futures",
] ]
[[package]]
name = "tonic-build"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4"
dependencies = [
"prettyplease 0.1.25",
"proc-macro2",
"prost-build 0.11.9",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "tower" name = "tower"
version = "0.4.13" version = "0.4.13"
@ -5252,6 +5342,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c"
dependencies = [
"base64 0.20.0",
"bitflags 2.3.3",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"mime",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
"tracing",
"uuid",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.2"
@ -5321,9 +5434,9 @@ dependencies = [
[[package]] [[package]]
name = "tracing-opentelemetry" name = "tracing-opentelemetry"
version = "0.18.0" version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"opentelemetry", "opentelemetry",
@ -5396,6 +5509,25 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "tungstenite"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.16.0" version = "1.16.0"
@ -5544,39 +5676,6 @@ dependencies = [
"try-lock", "try-lock",
] ]
[[package]]
name = "warp"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69"
dependencies = [
"async-compression",
"bytes",
"futures-channel",
"futures-util",
"headers",
"http",
"hyper",
"log",
"mime",
"mime_guess",
"multer",
"percent-encoding",
"pin-project",
"rustls-pemfile",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-rustls 0.23.4",
"tokio-stream",
"tokio-tungstenite",
"tokio-util",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.9.0+wasi-snapshot-preview1" version = "0.9.0+wasi-snapshot-preview1"

View file

@ -34,17 +34,24 @@ strip = false
[dependencies] [dependencies]
argon2 = "0.5.1" argon2 = "0.5.1"
axum = { version = "0.6.18", features = ["tracing", "ws", "headers"] }
axum-client-ip = "0.4.1"
axum-extra = { version = "0.7.4", features = ["typed-routing"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
base64 = "0.21.2" base64 = "0.21.2"
bytes = "1.4.0" bytes = "1.4.0"
clap = { version = "4.3.12", features = ["env", "derive", "wrap_help", "unicode"] } clap = { version = "4.3.12", features = ["env", "derive", "wrap_help", "unicode"] }
futures = "0.3.28" futures = "0.3.28"
futures-util = "0.3.28"
glob = "0.3.1" glob = "0.3.1"
http = "0.2.9" http = "0.2.9"
http-body = "0.4.5"
hyper = "0.14.27" hyper = "0.14.27"
ipnet = "2.8.0" ipnet = "2.8.0"
once_cell = "1.18.0" once_cell = "1.18.0"
opentelemetry = { version = "0.18", features = ["rt-tokio"] } opentelemetry = { version = "0.19", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
pin-project-lite = "0.2.9"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.18", features = ["blocking"] } reqwest = { version = "0.11.18", features = ["blocking"] }
rustyline = { version = "11.0.0", features = ["derive"] } rustyline = { version = "11.0.0", features = ["derive"] }
@ -57,19 +64,20 @@ tempfile = "3.6.0"
thiserror = "1.0.43" thiserror = "1.0.43"
tokio = { version = "1.29.1", features = ["macros", "signal"] } tokio = { version = "1.29.1", features = ["macros", "signal"] }
tokio-util = { version = "0.7.8", features = ["io"] } tokio-util = { version = "0.7.8", features = ["io"] }
tower = "0.4.13"
tower-http = { version = "0.4.1", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
urlencoding = "2.1.2" urlencoding = "2.1.2"
uuid = { version = "1.4.0", features = ["serde", "js", "v4", "v7"] } uuid = { version = "1.4.0", features = ["serde", "js", "v4", "v7"] }
warp = { version = "0.3.5", features = ["compression", "tls", "websocket"] }
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix = "0.26.2" nix = "0.26.2"
[dev-dependencies] [dev-dependencies]
assert_fs = "1.0.13" assert_fs = "1.0.13"
opentelemetry-proto = { version = "0.1.0", features = ["gen-tonic", "traces", "build-server"] } opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
rcgen = "0.10.0" rcgen = "0.10.0"
serial_test = "2.0.0" serial_test = "2.0.0"
temp-env = "0.3.4" temp-env = "0.3.4"

69
dev/docker/compose.yaml Normal file
View file

@ -0,0 +1,69 @@
---
version: "3"
services:
grafana:
image: "grafana/grafana-oss:latest"
expose:
- "3000"
ports:
- "3000:3000"
volumes:
- "grafana:/var/lib/grafana"
- "./grafana.ini:/etc/grafana/grafana.ini"
- "./grafana-datasource.yaml:/etc/grafana/provisioning/datasources/grafana-datasource.yaml"
- "./grafana-dashboards.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboards.yaml"
- "./dashboards:/dashboards"
healthcheck:
test:
- CMD-SHELL
- bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/3001; exit $$?;'
interval: 1s
timeout: 5s
retries: 5
prometheus:
image: "prom/prometheus:latest"
command:
- "--config.file=/etc/prometheus/prometheus.yaml"
- "--storage.tsdb.path=/prometheus"
- "--web.console.libraries=/usr/share/prometheus/console_libraries"
- "--web.console.templates=/usr/share/prometheus/consoles"
- "--web.listen-address=0.0.0.0:9090"
- "--web.enable-remote-write-receiver"
- "--storage.tsdb.retention.time=1d"
expose:
- "9090"
ports:
- "9090:9090"
volumes:
- "prometheus:/prometheus"
- "./prometheus.yaml:/etc/prometheus/prometheus.yaml"
tempo:
image: grafana/tempo:latest
command: [ "-config.file=/etc/tempo.yaml" ]
volumes:
- ./tempo.yaml:/etc/tempo.yaml
- tempo:/tmp/tempo
ports:
- "3200" # tempo
- "4317" # otlp grpc
otel-collector:
image: "otel/opentelemetry-collector-contrib"
command:
- "--config=/etc/otel-collector.yaml"
expose:
- "4317"
ports:
- "4317:4317" # otlp grpc
- "9090" # for prometheus
volumes: ["./otel-collector.yaml:/etc/otel-collector.yaml"]
volumes:
grafana:
external: false
prometheus:
external: false
tempo:
external: false

View file

@ -0,0 +1,15 @@
apiVersion: 1
providers:
- name: 'surrealdb-grafana'
orgId: 1
folder: ''
folderUid: ''
type: file
disableDeletion: false
updateIntervalSeconds: 1
allowUiUpdates: true
options:
path: /dashboards
foldersFromFilesStructure: false

View file

@ -0,0 +1,30 @@
apiVersion: 1
deleteDatasources:
- name: Prometheus
- name: Tempo
datasources:
- name: Prometheus
type: prometheus
access: proxy
url: http://prometheus:9090
withCredentials: false
isDefault: true
tlsAuth: false
tlsAuthWithCACert: false
version: 1
editable: true
- name: Tempo
type: tempo
access: proxy
orgId: 1
url: http://tempo:3200
basicAuth: false
isDefault: false
version: 1
editable: false
apiVersion: 1
uid: tempo
jsonData:
httpMethod: GET
serviceMap:
datasourceUid: prometheus

6
dev/docker/grafana.ini Normal file
View file

@ -0,0 +1,6 @@
[server]
http_addr = 0.0.0.0
http_port = 3000
[users]
default_theme = light

View file

@ -0,0 +1,36 @@
receivers:
otlp:
protocols:
grpc:
exporters:
otlp:
endpoint: 'tempo:4317'
tls:
insecure: true
prometheus:
endpoint: ':9090'
send_timestamps: true
metric_expiration: 60m
resource_to_telemetry_conversion:
enabled: true
logging: # add to a pipeline for debugging
loglevel: debug
# processors:
# batch:
# timeout: 1s
# span:
# name:
# from_attributes: ["name"]
service:
pipelines:
traces:
receivers: [otlp]
exporters: [otlp, logging]
metrics:
receivers: [otlp]
exporters: [prometheus]

View file

@ -0,0 +1,17 @@
global:
scrape_interval: 5s
evaluation_interval: 10s
scrape_configs:
- job_name: prometheus
static_configs:
- targets: ["prometheus:9500"]
- job_name: 'tempo'
static_configs:
- targets: ["tempo:3200"]
- job_name: otel-collector
static_configs:
# Scrape the SurrealDB metrics sent to the OpenTelemetry collector
- targets: ["otel-collector:9090"]

23
dev/docker/tempo.yaml Normal file
View file

@ -0,0 +1,23 @@
server:
http_listen_port: 3200
distributor:
receivers:
otlp:
protocols:
grpc:
ingester:
max_block_duration: 5m # cut the headblock when this much time passes. this is being set for dev purposes and should probably be left alone normally
compactor:
compaction:
block_retention: 1h # overall Tempo trace retention. set for dev purposes
storage:
trace:
backend: local
wal:
path: /tmp/tempo/wal
local:
path: /tmp/tempo/blocks

16
doc/TELEMETRY.md Normal file
View file

@ -0,0 +1,16 @@
# Telemetry
SurrealDB leverages the tracing and opentelemetry libraries to instrument the code.
Both metrics and traces are pushed to an OTEL compatible receiver.
For local development, you can start the observability stack defined in `dev/docker`. It spins up an instance of the OpenTelemetry Collector, Grafana, Prometheus, and Tempo:
```
$ docker-compose -f dev/docker/compose.yaml up -d
$ SURREAL_TRACING_TRACER=otlp OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317" surreal start
```
Now you can use the SurrealDB server and view the telemetry data by opening this URL in the browser: http://localhost:3000
To log in to Grafana, use the default user `admin` and password `admin`.

View file

@ -37,7 +37,7 @@ pub async fn init(
}: BackupCommandArguments, }: BackupCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
// Process the source->destination response // Process the source->destination response
let into_local = into.ends_with(".db"); let into_local = into.ends_with(".db");

View file

@ -38,7 +38,7 @@ pub async fn init(
}: ExportCommandArguments, }: ExportCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
let root = Root { let root = Root {
username: &username, username: &username,

View file

@ -36,7 +36,7 @@ pub async fn init(
}: ImportCommandArguments, }: ImportCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
let root = Root { let root = Root {
username: &username, username: &username,

View file

@ -17,7 +17,7 @@ pub async fn init(
}: IsReadyCommandArguments, }: IsReadyCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
// Connect to the database engine // Connect to the database engine
connect(endpoint).await?; connect(endpoint).await?;
println!("OK"); println!("OK");

View file

@ -49,7 +49,7 @@ pub async fn init(
}: SqlCommandArguments, }: SqlCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("warn").init(); crate::telemetry::builder().with_log_level("warn").init();
let root = Root { let root = Root {
username: &username, username: &username,

View file

@ -101,7 +101,7 @@ pub async fn init(
}: StartCommandArguments, }: StartCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_filter(log).init(); crate::telemetry::builder().with_filter(log).init();
// Check if a banner should be outputted // Check if a banner should be outputted
if !no_banner { if !no_banner {

View file

@ -47,7 +47,7 @@ impl UpgradeCommandArguments {
pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> { pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
// Upgrading overwrites the existing executable // Upgrading overwrites the existing executable
let exe = std::env::current_exe()?; let exe = std::env::current_exe()?;

View file

@ -31,6 +31,10 @@ impl TypedValueParser for CustomEnvFilterParser {
arg: Option<&clap::Arg>, arg: Option<&clap::Arg>,
value: &std::ffi::OsStr, value: &std::ffi::OsStr,
) -> Result<Self::Value, clap::Error> { ) -> Result<Self::Value, clap::Error> {
if let Ok(dirs) = std::env::var("RUST_LOG") {
return Ok(CustomEnvFilter(EnvFilter::builder().parse_lossy(dirs)));
}
let inner = NonEmptyStringValueParser::new(); let inner = NonEmptyStringValueParser::new();
let v = inner.parse_ref(cmd, arg, value)?; let v = inner.parse_ref(cmd, arg, value)?;
let filter = (match v.as_str() { let filter = (match v.as_str() {

View file

@ -18,7 +18,7 @@ pub async fn init(
}: VersionCommandArguments, }: VersionCommandArguments,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Initialize opentelemetry and logging // Initialize opentelemetry and logging
crate::o11y::builder().with_log_level("error").init(); crate::telemetry::builder().with_log_level("error").init();
// Print server version if endpoint supplied else CLI version // Print server version if endpoint supplied else CLI version
if let Some(e) = endpoint { if let Some(e) = endpoint {
// Print remote server version // Print remote server version

View file

@ -1,4 +1,7 @@
use axum::response::{IntoResponse, Response};
use axum::Json;
use base64::DecodeError as Base64Error; use base64::DecodeError as Base64Error;
use http::StatusCode;
use reqwest::Error as ReqwestError; use reqwest::Error as ReqwestError;
use serde::Serialize; use serde::Serialize;
use serde_cbor::error::Error as CborError; use serde_cbor::error::Error as CborError;
@ -6,6 +9,7 @@ use serde_json::error::Error as JsonError;
use serde_pack::encode::Error as PackError; use serde_pack::encode::Error as PackError;
use std::io::Error as IoError; use std::io::Error as IoError;
use std::string::FromUtf8Error as Utf8Error; use std::string::FromUtf8Error as Utf8Error;
use surrealdb::error::Db as SurrealDbError;
use surrealdb::Error as SurrealError; use surrealdb::Error as SurrealError;
use thiserror::Error; use thiserror::Error;
@ -51,8 +55,6 @@ pub enum Error {
Remote(#[from] ReqwestError), Remote(#[from] ReqwestError),
} }
impl warp::reject::Reject for Error {}
impl From<Error> for String { impl From<Error> for String {
fn from(e: Error) -> String { fn from(e: Error) -> String {
e.to_string() e.to_string()
@ -85,3 +87,56 @@ impl Serialize for Error {
serializer.serialize_str(self.to_string().as_str()) serializer.serialize_str(self.to_string().as_str())
} }
} }
#[derive(Serialize)]
pub(super) struct Message {
code: u16,
details: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
information: Option<String>,
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
match self {
err @ Error::InvalidAuth | err @ Error::Db(SurrealError::Db(SurrealDbError::InvalidAuth)) => (
StatusCode::UNAUTHORIZED,
Json(Message {
code: 401,
details: Some("Authentication failed".to_string()),
description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()),
information: Some(err.to_string()),
})
),
Error::InvalidType => (
StatusCode::UNSUPPORTED_MEDIA_TYPE,
Json(Message {
code: 415,
details: Some("Unsupported media type".to_string()),
description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()),
information: None,
}),
),
Error::InvalidStorage => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(Message {
code: 500,
details: Some("Health check failed".to_string()),
description: Some("The database health check for this instance failed. There was an issue with the underlying storage engine.".to_string()),
information: Some(self.to_string()),
}),
),
_ => (
StatusCode::BAD_REQUEST,
Json(Message {
code: 400,
details: Some("Request problems detected".to_string()),
description: Some("There is a problem with your request. Refer to the documentation for further information.".to_string()),
information: Some(self.to_string()),
}),
),
}.into_response()
}
}

View file

@ -3,8 +3,6 @@ pub mod verify;
use crate::cli::CF; use crate::cli::CF;
use crate::err::Error; use crate::err::Error;
pub const BASIC: &str = "Basic ";
pub async fn init() -> Result<(), Error> { pub async fn init() -> Result<(), Error> {
// Get local copy of options // Get local copy of options
let opt = CF.get().unwrap(); let opt = CF.get().unwrap();

View file

@ -1,76 +1,63 @@
use crate::cli::CF; use crate::cli::CF;
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::iam::BASIC;
use argon2::password_hash::{PasswordHash, PasswordVerifier}; use argon2::password_hash::{PasswordHash, PasswordVerifier};
use argon2::Argon2; use argon2::Argon2;
use std::sync::Arc; use std::sync::Arc;
use surrealdb::dbs::Auth; use surrealdb::dbs::Auth;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::iam::base::{Engine, BASE64};
pub async fn basic(session: &mut Session, auth: String) -> Result<(), Error> { pub async fn basic(session: &mut Session, user: &str, pass: &str) -> Result<(), Error> {
// Log the authentication type
trace!("Attempting basic authentication");
// Retrieve just the auth data
let auth = auth.trim_start_matches(BASIC).trim();
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
// Get the config options // Get the config options
let opts = CF.get().unwrap(); let opts = CF.get().unwrap();
// Decode the encoded auth data
let auth = BASE64.decode(auth)?; // Check that the details are not empty
// Convert the auth data to String if user.is_empty() || pass.is_empty() {
let auth = String::from_utf8(auth)?; return Err(Error::InvalidAuth);
// Split the auth data into user and pass }
if let Some((user, pass)) = auth.split_once(':') { // Check if this is root authentication
// Check that the details are not empty if let Some(root) = &opts.pass {
if user.is_empty() || pass.is_empty() { if user == opts.user && pass == root {
return Err(Error::InvalidAuth); // Log the authentication type
debug!("Authenticated as super user");
// Store the authentication data
session.au = Arc::new(Auth::Kv);
return Ok(());
} }
// Check if this is root authentication }
if let Some(root) = &opts.pass { // Check if this is NS authentication
if user == opts.user && pass == root { if let Some(ns) = &session.ns {
// Log the authentication type // Create a new readonly transaction
debug!("Authenticated as super user"); let mut tx = kvs.transaction(false, false).await?;
// Check if the supplied NS Login exists
if let Ok(nl) = tx.get_nl(ns, user).await {
// Compute the hash and verify the password
let hash = PasswordHash::new(&nl.hash).unwrap();
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
// Log the successful namespace authentication
debug!("Authenticated as namespace user: {}", user);
// Store the authentication data // Store the authentication data
session.au = Arc::new(Auth::Kv); session.au = Arc::new(Auth::Ns(ns.to_owned()));
return Ok(()); return Ok(());
} }
} };
// Check if this is NS authentication // Check if this is DB authentication
if let Some(ns) = &session.ns { if let Some(db) = &session.db {
// Create a new readonly transaction // Check if the supplied DB Login exists
let mut tx = kvs.transaction(false, false).await?; if let Ok(dl) = tx.get_dl(ns, db, user).await {
// Check if the supplied NS Login exists
if let Ok(nl) = tx.get_nl(ns, user).await {
// Compute the hash and verify the password // Compute the hash and verify the password
let hash = PasswordHash::new(&nl.hash).unwrap(); let hash = PasswordHash::new(&dl.hash).unwrap();
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() { if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
// Log the successful namespace authentication // Log the successful namespace authentication
debug!("Authenticated as namespace user: {}", user); debug!("Authenticated as namespace user: {}", user);
// Store the authentication data // Store the authentication data
session.au = Arc::new(Auth::Ns(ns.to_owned())); session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned()));
return Ok(()); return Ok(());
} }
}; };
// Check if this is DB authentication
if let Some(db) = &session.db {
// Check if the supplied DB Login exists
if let Ok(dl) = tx.get_dl(ns, db, user).await {
// Compute the hash and verify the password
let hash = PasswordHash::new(&dl.hash).unwrap();
if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() {
// Log the successful namespace authentication
debug!("Authenticated as database user: {}", user);
// Store the authentication data
session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned()));
return Ok(());
}
};
}
} }
} }
// There was an auth error
Err(Error::InvalidAuth) Err(Error::InvalidAuth)
} }

View file

@ -27,9 +27,9 @@ mod err;
mod iam; mod iam;
#[cfg(feature = "has-storage")] #[cfg(feature = "has-storage")]
mod net; mod net;
mod o11y;
#[cfg(feature = "has-storage")] #[cfg(feature = "has-storage")]
mod rpc; mod rpc;
mod telemetry;
use std::future::Future; use std::future::Future;
use std::process::ExitCode; use std::process::ExitCode;

106
src/net/auth.rs Normal file
View file

@ -0,0 +1,106 @@
use axum::{
body::{boxed, Body, BoxBody},
headers::{
authorization::{Basic, Bearer},
Authorization, Origin,
},
Extension, RequestPartsExt, TypedHeader,
};
use futures_util::future::BoxFuture;
use http::{request::Parts, StatusCode};
use hyper::{Request, Response};
use surrealdb::{dbs::Session, iam::verify::token};
use tower_http::auth::AsyncAuthorizeRequest;
use crate::{dbs::DB, err::Error, iam::verify::basic};
use super::{client_ip::ExtractClientIP, AppState};
/// A tower layer that authorizes requests before they reach a handler.
///
/// `SurrealAuth` implements the `tower_http` `AsyncAuthorizeRequest` trait:
/// it authenticates each request against SurrealDB using Basic or Bearer
/// token credentials and, on success, stores the resulting session in the
/// request extensions so downstream handlers can extract it.
///
/// It has to be used in conjunction with the matching *async* authorization
/// layer from `tower_http` (note: the sync `RequireAuthorizationLayer` does
/// not accept an `AsyncAuthorizeRequest` — TODO confirm exact layer name for
/// the pinned tower-http version):
///
/// ```ignore
/// use axum::{routing::get, Router};
/// use tower_http::auth::AsyncRequireAuthorizationLayer;
///
/// let auth = AsyncRequireAuthorizationLayer::new(SurrealAuth);
///
/// let app = Router::new()
///     .route("/version", get(|| async { "0.1.0" }))
///     .layer(auth);
/// ```
#[derive(Clone, Copy)]
pub(super) struct SurrealAuth;
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
where
	B: Send + Sync + 'static,
{
	type RequestBody = B;
	type ResponseBody = BoxBody;
	type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

	/// Authenticates the request and, on success, attaches the session to it.
	/// On failure, short-circuits with a `401 Unauthorized` response carrying
	/// the error message as the body.
	fn authorize(&mut self, request: Request<B>) -> Self::Future {
		Box::pin(async move {
			// Split the request so the headers can be inspected mutably.
			let (mut parts, body) = request.into_parts();
			check_auth(&mut parts)
				.await
				.map(|sess| {
					// Expose the authenticated session to downstream handlers.
					parts.extensions.insert(sess);
					Request::from_parts(parts, body)
				})
				.map_err(|err| {
					Response::builder()
						.status(StatusCode::UNAUTHORIZED)
						.body(boxed(Body::from(err.to_string())))
						.unwrap()
				})
		})
	}
}
/// Builds a [`Session`] for the incoming request from its headers.
///
/// Reads the `Origin`, `id`, `ns` and `db` headers plus the resolved client
/// IP, then applies Basic or Bearer authentication if the corresponding
/// `Authorization` header is present. Returns [`Error::InvalidAuth`] when the
/// app state is missing or the supplied credentials are rejected.
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
	let kvs = DB.get().unwrap();

	// Record the Origin header, treating an explicit `null` origin as absent.
	let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
		if !or.is_null() {
			Some(or.to_string())
		} else {
			None
		}
	} else {
		None
	};

	// Read the session selector headers. These are attacker-controlled, so a
	// non-UTF-8 value is ignored instead of panicking (the previous
	// `to_str().unwrap()` would abort the request task on invalid bytes).
	let id = parts.headers.get("id").and_then(|v| v.to_str().ok()).map(str::to_string); // TODO: Use a TypedHeader
	let ns = parts.headers.get("ns").and_then(|v| v.to_str().ok()).map(str::to_string); // TODO: Use a TypedHeader
	let db = parts.headers.get("db").and_then(|v| v.to_str().ok()).map(str::to_string); // TODO: Use a TypedHeader

	// Without the shared app state the client IP source is unknown; fail closed.
	let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
		tracing::error!("Error extracting the app state: {:?}", err);
		Error::InvalidAuth
	})?;
	let ExtractClientIP(ip) =
		parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));

	// Create session
	#[rustfmt::skip]
	let mut session = Session { ip, or, id, ns, db, ..Default::default() };

	// If Basic authentication data was supplied
	if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
		basic(&mut session, au.username(), au.password()).await?;
	};

	// If Token authentication data was supplied
	if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
		token(kvs, &mut session, au.token().into()).await?;
	};

	Ok(session)
}

View file

@ -1,10 +1,21 @@
use crate::cli::CF; use axum::async_trait;
use axum::extract::ConnectInfo;
use axum::extract::FromRef;
use axum::extract::FromRequestParts;
use axum::middleware::Next;
use axum::response::Response;
use axum::Extension;
use axum::RequestPartsExt;
use clap::ValueEnum; use clap::ValueEnum;
use std::net::IpAddr; use http::request::Parts;
use http::Request;
use http::StatusCode;
use std::net::SocketAddr; use std::net::SocketAddr;
use warp::Filter;
use super::AppState;
// TODO: Support Forwarded, X-Forwarded-For headers. // TODO: Support Forwarded, X-Forwarded-For headers.
// Get inspiration from https://github.com/imbolc/axum-client-ip or simply use it
#[derive(ValueEnum, Clone, Copy, Debug)] #[derive(ValueEnum, Clone, Copy, Debug)]
pub enum ClientIp { pub enum ClientIp {
/// Don't use client IP /// Don't use client IP
@ -25,31 +36,95 @@ pub enum ClientIp {
XRealIp, XRealIp,
} }
/// Creates an string represenation of the client's IP address impl std::fmt::Display for ClientIp {
pub fn build() -> impl Filter<Extract = (Option<String>,), Error = warp::Rejection> + Clone { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Get configured client IP source match self {
let client_ip = CF.get().unwrap().client_ip; ClientIp::None => write!(f, "None"),
// Enable on any path ClientIp::Socket => write!(f, "Socket"),
let conf = warp::any(); ClientIp::CfConectingIp => write!(f, "CF-Connecting-IP"),
// Add raw remote IP address ClientIp::FlyClientIp => write!(f, "Fly-Client-IP"),
let conf = ClientIp::TrueClientIP => write!(f, "True-Client-IP"),
conf.and(warp::filters::addr::remote().and_then(move |s: Option<SocketAddr>| async move { ClientIp::XRealIp => write!(f, "X-Real-IP"),
match client_ip { }
ClientIp::None => Ok(None), }
ClientIp::Socket => Ok(s.map(|s| s.ip())), }
// Move on to parsing selected IP header.
_ => Err(warp::reject::reject()), impl ClientIp {
} fn is_header(&self) -> bool {
})); match self {
// Add selected IP header ClientIp::None => false,
let conf = conf.or(warp::header::optional::<IpAddr>(match client_ip { ClientIp::Socket => false,
ClientIp::CfConectingIp => "Cf-Connecting-IP", ClientIp::CfConectingIp => true,
ClientIp::FlyClientIp => "Fly-Client-IP", ClientIp::FlyClientIp => true,
ClientIp::TrueClientIP => "True-Client-IP", ClientIp::TrueClientIP => true,
ClientIp::XRealIp => "X-Real-IP", ClientIp::XRealIp => true,
// none and socket are already handled so this will never be used }
_ => "unreachable", }
})); }
// Join the two filters
conf.unify().map(|ip: Option<IpAddr>| ip.map(|ip| ip.to_string())) pub(super) struct ExtractClientIP(pub Option<String>);
#[async_trait]
impl<S> FromRequestParts<S> for ExtractClientIP
where
AppState: FromRef<S>,
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let app_state = AppState::from_ref(state);
let res = match app_state.client_ip {
ClientIp::None => ExtractClientIP(None),
ClientIp::Socket => {
if let Ok(ConnectInfo(addr)) =
ConnectInfo::<SocketAddr>::from_request_parts(parts, state).await
{
ExtractClientIP(Some(addr.ip().to_string()))
} else {
ExtractClientIP(None)
}
}
// Get the IP from the corresponding header
var if var.is_header() => {
if let Some(ip) = parts.headers.get(var.to_string()) {
ip.to_str().map(|s| ExtractClientIP(Some(s.to_string()))).unwrap_or_else(
|err| {
debug!("Invalid header value for {}: {}", var, err);
ExtractClientIP(None)
},
)
} else {
ExtractClientIP(None)
}
}
_ => {
warn!("Unexpected ClientIp variant: {:?}", app_state.client_ip);
ExtractClientIP(None)
}
};
Ok(res)
}
}
pub(super) async fn client_ip_middleware<B>(
request: Request<B>,
next: Next<B>,
) -> Result<Response, StatusCode>
where
B: Send,
{
let (mut parts, body) = request.into_parts();
if let Ok(Extension(state)) = parts.extract::<Extension<AppState>>().await {
if let Ok(client_ip) = parts.extract_with_state::<ExtractClientIP, AppState>(&state).await {
parts.extensions.insert(client_ip);
}
} else {
trace!("No AppState found, skipping client_ip_middleware");
}
Ok(next.run(Request::from_parts(parts, body)).await)
} }

View file

@ -1,21 +1,25 @@
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use axum::response::IntoResponse;
use crate::net::session; use axum::routing::get;
use axum::Router;
use axum::{response::Response, Extension};
use bytes::Bytes; use bytes::Bytes;
use http::StatusCode;
use http_body::Body as HttpBody;
use hyper::body::Body; use hyper::body::Body;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use warp::Filter;
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
warp::path("export") B: HttpBody + Send + 'static,
.and(warp::path::end()) S: Clone + Send + Sync + 'static,
.and(warp::get()) {
.and(session::build()) Router::new().route("/export", get(handler))
.and_then(handler)
} }
async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection> { async fn handler(
Extension(session): Extension<Session>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Check the permissions // Check the permissions
match session.au.is_db() { match session.au.is_db() {
true => { true => {
@ -24,12 +28,12 @@ async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection>
// Extract the NS header value // Extract the NS header value
let nsv = match session.ns { let nsv = match session.ns {
Some(ns) => ns, Some(ns) => ns,
None => return Err(warp::reject::custom(Error::NoNsHeader)), None => return Err((StatusCode::BAD_REQUEST, "No namespace provided")),
}; };
// Extract the DB header value // Extract the DB header value
let dbv = match session.db { let dbv = match session.db {
Some(db) => db, Some(db) => db,
None => return Err(warp::reject::custom(Error::NoDbHeader)), None => return Err((StatusCode::BAD_REQUEST, "No database provided")),
}; };
// Create a chunked response // Create a chunked response
let (mut chn, bdy) = Body::channel(); let (mut chn, bdy) = Body::channel();
@ -44,9 +48,9 @@ async fn handler(session: Session) -> Result<impl warp::Reply, warp::Rejection>
} }
}); });
// Return the chunked body // Return the chunked body
Ok(warp::reply::Response::new(bdy)) Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap())
} }
// There was an error with permissions // The user does not have the correct permissions
_ => Err(warp::reject::custom(Error::InvalidAuth)), _ => Err((StatusCode::FORBIDDEN, "Invalid permissions")),
} }
} }

View file

@ -1,126 +0,0 @@
use crate::err::Error;
use serde::Serialize;
use warp::http::StatusCode;
#[derive(Serialize)]
struct Message {
code: u16,
details: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
information: Option<String>,
}
pub async fn recover(err: warp::Rejection) -> Result<impl warp::Reply, warp::Rejection> {
if let Some(err) = err.find::<Error>() {
match err {
Error::InvalidAuth => Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 403,
details: Some("Authentication failed".to_string()),
description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()),
information: Some(err.to_string()),
}),
StatusCode::FORBIDDEN,
)),
Error::InvalidType => Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 415,
details: Some("Unsupported media type".to_string()),
description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()),
information: None,
}),
StatusCode::UNSUPPORTED_MEDIA_TYPE,
)),
Error::InvalidStorage => Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 500,
details: Some("Health check failed".to_string()),
description: Some("The database health check for this instance failed. There was an issue with the underlying storage engine.".to_string()),
information: Some(err.to_string()),
}),
StatusCode::INTERNAL_SERVER_ERROR,
)),
_ => Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 400,
details: Some("Request problems detected".to_string()),
description: Some("There is a problem with your request. Refer to the documentation for further information.".to_string()),
information: Some(err.to_string()),
}),
StatusCode::BAD_REQUEST,
))
}
} else if err.is_not_found() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 404,
details: Some("Requested resource not found".to_string()),
description: Some("The requested resource does not exist. Check that you have entered the url correctly.".to_string()),
information: None,
}),
StatusCode::NOT_FOUND,
))
} else if err.find::<warp::reject::MissingHeader>().is_some() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 412,
details: Some("Request problems detected".to_string()),
description: Some("The request appears to be missing a required header. Refer to the documentation for request requirements.".to_string()),
information: None,
}),
StatusCode::PRECONDITION_FAILED,
))
} else if err.find::<warp::reject::PayloadTooLarge>().is_some() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 413,
details: Some("Payload too large".to_string()),
description: Some("The request has exceeded the maximum payload size. Refer to the documentation for the request limitations.".to_string()),
information: None,
}),
StatusCode::PAYLOAD_TOO_LARGE,
))
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 501,
details: Some("Not implemented".to_string()),
description: Some("The server either does not recognize the query, or it lacks the ability to fulfill the request.".to_string()),
information: None,
}),
StatusCode::NOT_IMPLEMENTED,
))
} else if err.find::<warp::reject::InvalidHeader>().is_some() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 501,
details: Some("Not implemented".to_string()),
description: Some("The server either does not recognize a request header, or it lacks the ability to fulfill the request.".to_string()),
information: None,
}),
StatusCode::NOT_IMPLEMENTED,
))
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 405,
details: Some("Requested method not allowed".to_string()),
description: Some("The requested http method is not allowed for this resource. Refer to the documentation for allowed methods.".to_string()),
information: None,
}),
StatusCode::METHOD_NOT_ALLOWED,
))
} else {
Ok(warp::reply::with_status(
warp::reply::json(&Message {
code: 500,
details: Some("Internal server error".to_string()),
description: Some("There was a problem with our servers, and we have been notified. Refer to the documentation for further information".to_string()),
information: None,
}),
StatusCode::INTERNAL_SERVER_ERROR,
))
}
}

View file

@ -1,41 +0,0 @@
use crate::cnf::PKG_NAME;
use crate::cnf::PKG_VERSION;
use surrealdb::cnf::SERVER_NAME;
const ID: &str = "ID";
const NS: &str = "NS";
const DB: &str = "DB";
const SERVER: &str = "Server";
const VERSION: &str = "Version";
pub fn version() -> warp::filters::reply::WithHeader {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
warp::reply::with::header(VERSION, val)
}
pub fn server() -> warp::filters::reply::WithHeader {
warp::reply::with::header(SERVER, SERVER_NAME)
}
pub fn cors() -> warp::filters::cors::Builder {
warp::cors()
.max_age(86400)
.allow_any_origin()
.allow_methods(vec![
http::Method::GET,
http::Method::PUT,
http::Method::POST,
http::Method::PATCH,
http::Method::DELETE,
http::Method::OPTIONS,
])
.allow_headers(vec![
http::header::ACCEPT,
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
http::header::ORIGIN,
NS.parse().unwrap(),
DB.parse().unwrap(),
ID.parse().unwrap(),
])
}

95
src/net/headers.rs Normal file
View file

@ -0,0 +1,95 @@
use crate::cnf::PKG_NAME;
use crate::cnf::PKG_VERSION;
use axum::headers;
use axum::headers::Header;
use http::HeaderName;
use http::HeaderValue;
use surrealdb::cnf::SERVER_NAME;
use tower_http::set_header::SetResponseHeaderLayer;
// Custom request headers read by the auth layer to select the session id,
// namespace, and database for a request (header lookup is case-insensitive).
pub(super) const ID: &str = "ID";
pub(super) const NS: &str = "NS";
pub(super) const DB: &str = "DB";
// Standard response headers added to every HTTP response by the layers below.
const SERVER: &str = "server";
const VERSION: &str = "version";
/// Returns a layer that sets the `version` response header to
/// `<package name>-<package version>`, unless a handler already set one.
pub fn add_version_header() -> SetResponseHeaderLayer<HeaderValue> {
	// Package name/version are compile-time constants, so this cannot fail.
	let value = HeaderValue::try_from(format!("{PKG_NAME}-{}", *PKG_VERSION)).unwrap();
	SetResponseHeaderLayer::if_not_present(HeaderName::from_static(VERSION), value)
}
/// Returns a layer that sets the `server` response header to the SurrealDB
/// server name, unless a handler already set one.
pub fn add_server_header() -> SetResponseHeaderLayer<HeaderValue> {
	let name = HeaderName::from_static(SERVER);
	// SERVER_NAME is a static, valid header value, so this cannot fail.
	let value = HeaderValue::try_from(SERVER_NAME).unwrap();
	SetResponseHeaderLayer::if_not_present(name, value)
}
/// Typed header implementation for the `Accept` header.
///
/// Enumerates the serialization formats the HTTP endpoints can produce;
/// the exact MIME strings are defined by the `Display` and `Header` impls.
pub enum Accept {
	// "text/plain"
	TextPlain,
	// "application/json"
	ApplicationJson,
	// "application/cbor"
	ApplicationCbor,
	// "application/pack"
	ApplicationPack,
	// "application/octet-stream" — used to request no response body
	ApplicationOctetStream,
	// "application/surrealdb" — internal full serialization
	Surrealdb,
}
impl std::fmt::Display for Accept {
	/// Writes the canonical MIME type string for this variant.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let mime = match self {
			Accept::TextPlain => "text/plain",
			Accept::ApplicationJson => "application/json",
			Accept::ApplicationCbor => "application/cbor",
			Accept::ApplicationPack => "application/pack",
			Accept::ApplicationOctetStream => "application/octet-stream",
			Accept::Surrealdb => "application/surrealdb",
		};
		f.write_str(mime)
	}
}
impl Header for Accept {
fn name() -> &'static HeaderName {
&http::header::ACCEPT
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
I: Iterator<Item = &'i HeaderValue>,
{
let value = values.next().ok_or_else(headers::Error::invalid)?;
match value.to_str().map_err(|_| headers::Error::invalid())? {
"text/plain" => Ok(Accept::TextPlain),
"application/json" => Ok(Accept::ApplicationJson),
"application/cbor" => Ok(Accept::ApplicationCbor),
"application/pack" => Ok(Accept::ApplicationPack),
"application/octet-stream" => Ok(Accept::ApplicationOctetStream),
"application/surrealdb" => Ok(Accept::Surrealdb),
// TODO: Support more (all?) mime-types
_ => Err(headers::Error::invalid()),
}
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
values.extend(std::iter::once(self.into()));
}
}
// Converting an owned `Accept` simply delegates to the borrowed conversion.
impl From<Accept> for HeaderValue {
	fn from(value: Accept) -> Self {
		Self::from(&value)
	}
}

impl From<&Accept> for HeaderValue {
	fn from(value: &Accept) -> Self {
		// The `Display` output is a fixed ASCII MIME string, so this cannot fail.
		HeaderValue::from_str(&value.to_string()).unwrap()
	}
}

View file

@ -1,25 +1,31 @@
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use warp::Filter; use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use http_body::Body as HttpBody;
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
warp::path("health").and(warp::path::end()).and(warp::get()).and_then(handler) B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
Router::new().route("/health", get(handler))
} }
async fn handler() -> Result<impl warp::Reply, warp::Rejection> { async fn handler() -> impl IntoResponse {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Attempt to open a transaction // Attempt to open a transaction
match db.transaction(false, false).await { match db.transaction(false, false).await {
// The transaction failed to start // The transaction failed to start
Err(_) => Err(warp::reject::custom(Error::InvalidStorage)), Err(_) => Err(Error::InvalidStorage),
// The transaction was successful // The transaction was successful
Ok(mut tx) => { Ok(mut tx) => {
// Cancel the transaction // Cancel the transaction
let _ = tx.cancel().await; let _ = tx.cancel().await;
// Return the response // Return the response
Ok(warp::reply()) Ok(())
} }
} }
} }

View file

@ -2,31 +2,39 @@ use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
use crate::net::session; use axum::extract::DefaultBodyLimit;
use axum::response::IntoResponse;
use axum::routing::post;
use axum::Extension;
use axum::Router;
use axum::TypedHeader;
use bytes::Bytes; use bytes::Bytes;
use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use warp::http; use tower_http::limit::RequestBodyLimitLayer;
use warp::Filter;
const MAX: u64 = 1024 * 1024 * 1024 * 4; // 4 GiB use super::headers::Accept;
#[allow(opaque_hidden_inferred_bound)] const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
warp::path("import") pub(super) fn router<S, B>() -> Router<S, B>
.and(warp::path::end()) where
.and(warp::post()) B: HttpBody + Send + 'static,
.and(warp::header::<String>(http::header::ACCEPT.as_str())) B::Data: Send,
.and(warp::body::content_length_limit(MAX)) B::Error: std::error::Error + Send + Sync + 'static,
.and(warp::body::bytes()) S: Clone + Send + Sync + 'static,
.and(session::build()) {
.and_then(handler) Router::new()
.route("/import", post(handler))
.route_layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(MAX))
} }
async fn handler( async fn handler(
output: String, Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
sql: Bytes, sql: Bytes,
session: Session, ) -> Result<impl IntoResponse, impl IntoResponse> {
) -> Result<impl warp::Reply, warp::Rejection> {
// Check the permissions // Check the permissions
match session.au.is_db() { match session.au.is_db() {
true => { true => {
@ -36,22 +44,22 @@ async fn handler(
let sql = bytes_to_utf8(&sql)?; let sql = bytes_to_utf8(&sql)?;
// Execute the sql query in the database // Execute the sql query in the database
match db.execute(sql, &session, None).await { match db.execute(sql, &session, None).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// Return nothing // Return nothing
"application/octet-stream" => Ok(output::none()), Some(Accept::ApplicationOctetStream) => Ok(output::none()),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
_ => Err(warp::reject::custom(Error::InvalidAuth)), _ => Err(Error::InvalidAuth),
} }
} }

View file

@ -1,10 +0,0 @@
use crate::cnf;
use warp::http::Uri;
use warp::Filter;
#[allow(opaque_hidden_inferred_bound)]
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
warp::path::end()
.and(warp::get())
.map(|| warp::redirect::temporary(Uri::from_static(cnf::APP_ENDPOINT)))
}

View file

@ -1,6 +1,6 @@
use crate::err::Error; use crate::err::Error;
use bytes::Bytes; use bytes::Bytes;
pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, warp::Rejection> { pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, Error> {
std::str::from_utf8(bytes).map_err(|_| warp::reject::custom(Error::Request)) std::str::from_utf8(bytes).map_err(|_| Error::Request)
} }

View file

@ -2,145 +2,62 @@ use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
use crate::net::params::{Param, Params}; use crate::net::params::Params;
use crate::net::session; use axum::extract::{DefaultBodyLimit, Path, Query};
use axum::response::IntoResponse;
use axum::routing::options;
use axum::{Extension, Router, TypedHeader};
use bytes::Bytes; use bytes::Bytes;
use http_body::Body as HttpBody;
use serde::Deserialize; use serde::Deserialize;
use std::str; use std::str;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use warp::path; use tower_http::limit::RequestBodyLimitLayer;
use warp::Filter;
const MAX: u64 = 1024 * 16; // 16 KiB use super::headers::Accept;
const MAX: usize = 1024 * 16; // 16 KiB
#[derive(Default, Deserialize, Debug, Clone)] #[derive(Default, Deserialize, Debug, Clone)]
struct Query { struct QueryOptions {
pub limit: Option<String>, pub limit: Option<String>,
pub start: Option<String>, pub start: Option<String>,
} }
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
// ------------------------------ B: HttpBody + Send + 'static,
// Routes for OPTIONS B::Data: Send,
// ------------------------------ B::Error: std::error::Error + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
let base = warp::path("key"); {
// Set opts method Router::new()
let opts = base.and(warp::options()).map(warp::reply); .route(
"/key/:table",
// ------------------------------ options(|| async {})
// Routes for a table .get(select_all)
// ------------------------------ .post(create_all)
.put(update_all)
// Set select method .patch(modify_all)
let select = warp::any() .delete(delete_all),
.and(warp::get()) )
.and(warp::header::<String>(http::header::ACCEPT.as_str())) .route_layer(DefaultBodyLimit::disable())
.and(path!("key" / Param).and(warp::path::end())) .layer(RequestBodyLimitLayer::new(MAX))
.and(warp::query()) .merge(
.and(session::build()) Router::new()
.and_then(select_all); .route(
// Set create method "/key/:table/:key",
let create = warp::any() options(|| async {})
.and(warp::post()) .get(select_one)
.and(warp::header::<String>(http::header::ACCEPT.as_str())) .post(create_one)
.and(path!("key" / Param).and(warp::path::end())) .put(update_one)
.and(warp::body::content_length_limit(MAX)) .patch(modify_one)
.and(warp::body::bytes()) .delete(delete_one),
.and(warp::query()) )
.and(session::build()) .route_layer(DefaultBodyLimit::disable())
.and_then(create_all); .layer(RequestBodyLimitLayer::new(MAX)),
// Set update method )
let update = warp::any()
.and(warp::put())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param).and(warp::path::end()))
.and(warp::body::content_length_limit(MAX))
.and(warp::body::bytes())
.and(warp::query())
.and(session::build())
.and_then(update_all);
// Set modify method
let modify = warp::any()
.and(warp::patch())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param).and(warp::path::end()))
.and(warp::body::content_length_limit(MAX))
.and(warp::body::bytes())
.and(warp::query())
.and(session::build())
.and_then(modify_all);
// Set delete method
let delete = warp::any()
.and(warp::delete())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param).and(warp::path::end()))
.and(warp::query())
.and(session::build())
.and_then(delete_all);
// Specify route
let all = select.or(create).or(update).or(modify).or(delete);
// ------------------------------
// Routes for a thing
// ------------------------------
// Set select method
let select = warp::any()
.and(warp::get())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param / Param).and(warp::path::end()))
.and(session::build())
.and_then(select_one);
// Set create method
let create = warp::any()
.and(warp::post())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param / Param).and(warp::path::end()))
.and(warp::body::content_length_limit(MAX))
.and(warp::body::bytes())
.and(warp::query())
.and(session::build())
.and_then(create_one);
// Set update method
let update = warp::any()
.and(warp::put())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param / Param).and(warp::path::end()))
.and(warp::body::content_length_limit(MAX))
.and(warp::body::bytes())
.and(warp::query())
.and(session::build())
.and_then(update_one);
// Set modify method
let modify = warp::any()
.and(warp::patch())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param / Param).and(warp::path::end()))
.and(warp::body::content_length_limit(MAX))
.and(warp::body::bytes())
.and(warp::query())
.and(session::build())
.and_then(modify_one);
// Set delete method
let delete = warp::any()
.and(warp::delete())
.and(warp::header::<String>(http::header::ACCEPT.as_str()))
.and(path!("key" / Param / Param).and(warp::path::end()))
.and(warp::query())
.and(session::build())
.and_then(delete_one);
// Specify route
let one = select.or(create).or(update).or(modify).or(delete);
// ------------------------------
// All routes
// ------------------------------
// Specify route
opts.or(all).or(one)
} }
// ------------------------------ // ------------------------------
@ -148,11 +65,11 @@ pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejecti
// ------------------------------ // ------------------------------
async fn select_all( async fn select_all(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
query: Query, Path(table): Path<String>,
session: Session, Query(query): Query<QueryOptions>,
) -> Result<impl warp::Reply, warp::Rejection> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Specify the request statement // Specify the request statement
@ -167,28 +84,28 @@ async fn select_all(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql.as_str(), &session, Some(vars)).await { match db.execute(sql.as_str(), &session, Some(vars)).await {
Ok(ref res) => match output.as_ref() { Ok(ref res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
async fn create_all( async fn create_all(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -206,31 +123,31 @@ async fn create_all(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn update_all( async fn update_all(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -248,31 +165,31 @@ async fn update_all(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn modify_all( async fn modify_all(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -290,30 +207,30 @@ async fn modify_all(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn delete_all( async fn delete_all(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
params: Params, Path(table): Path<String>,
session: Session, Query(params): Query<Params>,
) -> Result<impl warp::Reply, warp::Rejection> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Specify the request statement // Specify the request statement
@ -325,18 +242,18 @@ async fn delete_all(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
@ -345,11 +262,10 @@ async fn delete_all(
// ------------------------------ // ------------------------------
async fn select_one( async fn select_one(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
id: Param, Path((table, id)): Path<(String, String)>,
session: Session, ) -> Result<impl IntoResponse, impl IntoResponse> {
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Specify the request statement // Specify the request statement
@ -366,29 +282,28 @@ async fn select_one(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
async fn create_one( async fn create_one(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
id: Param, Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -412,32 +327,31 @@ async fn create_one(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn update_one( async fn update_one(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
id: Param, Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -461,32 +375,31 @@ async fn update_one(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn modify_one( async fn modify_one(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
id: Param, Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes, body: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the HTTP request body // Convert the HTTP request body
@ -510,31 +423,29 @@ async fn modify_one(
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
Err(_) => Err(warp::reject::custom(Error::Request)), Err(_) => Err(Error::Request),
} }
} }
async fn delete_one( async fn delete_one(
output: String, Extension(session): Extension<Session>,
table: Param, maybe_output: Option<TypedHeader<Accept>>,
id: Param, Path((table, id)): Path<(String, String)>,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Specify the request statement // Specify the request statement
@ -548,21 +459,20 @@ async fn delete_one(
let vars = map! { let vars = map! {
String::from("table") => Value::from(table), String::from("table") => Value::from(table),
String::from("id") => rid, String::from("id") => rid,
=> params.parse()
}; };
// Execute the query and return the result // Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await { match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match output.as_ref() { Ok(res) => match maybe_output.as_deref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }

View file

@ -1,30 +0,0 @@
use std::fmt;
use tracing::Level;
struct OptFmt<T>(Option<T>);
impl<T: fmt::Display> fmt::Display for OptFmt<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(ref t) = self.0 {
fmt::Display::fmt(t, f)
} else {
f.write_str("-")
}
}
}
pub fn write() -> warp::filters::log::Log<impl Fn(warp::filters::log::Info) + Copy> {
warp::log::custom(|info| {
event!(
Level::INFO,
"{} {} {} {:?} {} \"{}\" {:?}",
OptFmt(info.remote_addr()),
info.method(),
info.path(),
info.version(),
info.status().as_u16(),
OptFmt(info.user_agent()),
info.elapsed(),
);
})
}

View file

@ -1,105 +1,164 @@
mod auth;
pub mod client_ip; pub mod client_ip;
mod export; mod export;
mod fail; mod headers;
mod head;
mod health; mod health;
mod import; mod import;
mod index;
mod input; mod input;
mod key; mod key;
mod log;
mod output; mod output;
mod params; mod params;
mod rpc; mod rpc;
mod session;
mod signals; mod signals;
mod signin; mod signin;
mod signup; mod signup;
mod sql; mod sql;
mod status;
mod sync; mod sync;
mod tracer;
mod version; mod version;
use axum::response::Redirect;
use axum::routing::get;
use axum::{middleware, Router};
use axum_server::Handle;
use http::header;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::add_extension::AddExtensionLayer;
use tower_http::auth::AsyncRequireAuthorizationLayer;
use tower_http::cors::{Any, CorsLayer};
use tower_http::request_id::MakeRequestUuid;
use tower_http::sensitive_headers::{
SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer,
};
use tower_http::trace::TraceLayer;
use tower_http::ServiceBuilderExt;
use crate::cli::CF; use crate::cli::CF;
use crate::cnf;
use crate::err::Error; use crate::err::Error;
use warp::Filter; use crate::net::signals::graceful_shutdown;
use crate::telemetry::metrics::HttpMetricsLayer;
use axum_server::tls_rustls::RustlsConfig;
const LOG: &str = "surrealdb::net";
///
/// AppState is used to share data between routes.
///
#[derive(Clone)]
struct AppState {
client_ip: client_ip::ClientIp,
}
pub async fn init() -> Result<(), Error> { pub async fn init() -> Result<(), Error> {
// Setup web routes
let net = index::config()
// Version endpoint
.or(version::config())
// Status endpoint
.or(status::config())
// Health endpoint
.or(health::config())
// Signup endpoint
.or(signup::config())
// Signin endpoint
.or(signin::config())
// Export endpoint
.or(export::config())
// Import endpoint
.or(import::config())
// Backup endpoint
.or(sync::config())
// RPC query endpoint
.or(rpc::config())
// SQL query endpoint
.or(sql::config())
// API query endpoint
.or(key::config())
// Catch all errors
.recover(fail::recover)
// End routes setup
;
// Specify a generic version header
let net = net.with(head::version());
// Specify a generic server header
let net = net.with(head::server());
// Set cors headers on all requests
let net = net.with(head::cors());
// Log all requests to the console
let net = net.with(log::write());
// Trace requests
let net = net.with(warp::trace::request());
// Get local copy of options // Get local copy of options
let opt = CF.get().unwrap(); let opt = CF.get().unwrap();
let app_state = AppState {
client_ip: opt.client_ip,
};
// Specify headers to be obfuscated from all requests/responses
let headers: Arc<[_]> = Arc::new([
header::AUTHORIZATION,
header::PROXY_AUTHORIZATION,
header::COOKIE,
header::SET_COOKIE,
]);
// Build the middleware to our service.
let service = ServiceBuilder::new()
.catch_panic()
.set_x_request_id(MakeRequestUuid)
.propagate_x_request_id()
.layer(AddExtensionLayer::new(app_state))
.layer(middleware::from_fn(client_ip::client_ip_middleware))
.layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone(&headers)))
.layer(
TraceLayer::new_for_http()
.make_span_with(tracer::HttpTraceLayerHooks)
.on_request(tracer::HttpTraceLayerHooks)
.on_response(tracer::HttpTraceLayerHooks)
.on_failure(tracer::HttpTraceLayerHooks),
)
.layer(HttpMetricsLayer)
.layer(SetSensitiveResponseHeadersLayer::from_shared(headers))
.layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth))
.layer(headers::add_server_header())
.layer(headers::add_version_header())
.layer(
CorsLayer::new()
.allow_methods([
http::Method::GET,
http::Method::PUT,
http::Method::POST,
http::Method::PATCH,
http::Method::DELETE,
http::Method::OPTIONS,
])
.allow_headers([
http::header::ACCEPT,
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
http::header::ORIGIN,
headers::NS.parse().unwrap(),
headers::DB.parse().unwrap(),
headers::ID.parse().unwrap(),
])
// allow requests from any origin
.allow_origin(Any)
.max_age(Duration::from_secs(86400)),
);
let axum_app = Router::new()
// Redirect until we provide a UI
.route("/", get(|| async { Redirect::temporary(cnf::APP_ENDPOINT) }))
.route("/status", get(|| async {}))
.merge(health::router())
.merge(export::router())
.merge(import::router())
.merge(rpc::router())
.merge(version::router())
.merge(sync::router())
.merge(sql::router())
.merge(signin::router())
.merge(signup::router())
.merge(key::router())
.layer(service);
info!("Starting web server on {}", &opt.bind); info!("Starting web server on {}", &opt.bind);
if let (Some(c), Some(k)) = (&opt.crt, &opt.key) { // Setup the graceful shutdown with no timeout
// Bind the server to the desired port let handle = Handle::new();
let (adr, srv) = warp::serve(net) graceful_shutdown(handle.clone(), None);
.tls()
.cert_path(c) if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
.key_path(k) // configure certificate and private key used by https
.bind_with_graceful_shutdown(opt.bind, async move { let tls = RustlsConfig::from_pem_file(cert, key).await.unwrap();
// Capture the shutdown signals and log that the graceful shutdown has started
let result = signals::listen().await.expect("Failed to listen to shutdown signal"); let server = axum_server::bind_rustls(opt.bind, tls);
info!("{} received. Start graceful shutdown...", result);
}); info!(target: LOG, "Started web server on {}", &opt.bind);
// Log the server startup status
info!("Started web server on {}", &adr); server
// Run the server forever .handle(handle)
srv.await; .serve(axum_app.into_make_service_with_connect_info::<SocketAddr>())
// Log the server shutdown event .await?;
info!("Shutdown complete. Bye!")
} else { } else {
// Bind the server to the desired port let server = axum_server::bind(opt.bind);
let (adr, srv) = warp::serve(net).bind_with_graceful_shutdown(opt.bind, async move {
// Capture the shutdown signals and log that the graceful shutdown has started info!(target: LOG, "Started web server on {}", &opt.bind);
let result = signals::listen().await.expect("Failed to listen to shutdown signal");
info!("{} received. Start graceful shutdown...", result); server
}); .handle(handle)
// Log the server startup status .serve(axum_app.into_make_service_with_connect_info::<SocketAddr>())
info!("Started web server on {}", &adr); .await?;
// Run the server forever
srv.await;
// Log the server shutdown event
info!("Shutdown complete. Bye!")
}; };
info!(target: LOG, "Web server stopped. Bye!");
Ok(()) Ok(())
} }

View file

@ -1,9 +1,12 @@
use axum::response::{IntoResponse, Response};
use http::header::{HeaderValue, CONTENT_TYPE}; use http::header::{HeaderValue, CONTENT_TYPE};
use http::StatusCode; use http::StatusCode;
use serde::Serialize; use serde::Serialize;
use serde_json::Value as Json; use serde_json::Value as Json;
use surrealdb::sql; use surrealdb::sql;
use super::headers::Accept;
pub enum Output { pub enum Output {
None, None,
Fail, Fail,
@ -67,38 +70,23 @@ pub fn simplify<T: Serialize>(v: T) -> Json {
sql::to_value(v).unwrap().into() sql::to_value(v).unwrap().into()
} }
impl warp::Reply for Output { impl IntoResponse for Output {
fn into_response(self) -> warp::reply::Response { fn into_response(self) -> Response {
match self { match self {
Output::Text(v) => { Output::Text(v) => {
let mut res = warp::reply::Response::new(v.into()); ([(CONTENT_TYPE, HeaderValue::from(Accept::TextPlain))], v).into_response()
let con = HeaderValue::from_static("text/plain");
res.headers_mut().insert(CONTENT_TYPE, con);
res
} }
Output::Json(v) => { Output::Json(v) => {
let mut res = warp::reply::Response::new(v.into()); ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationJson))], v).into_response()
let con = HeaderValue::from_static("application/json");
res.headers_mut().insert(CONTENT_TYPE, con);
res
} }
Output::Cbor(v) => { Output::Cbor(v) => {
let mut res = warp::reply::Response::new(v.into()); ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationCbor))], v).into_response()
let con = HeaderValue::from_static("application/cbor");
res.headers_mut().insert(CONTENT_TYPE, con);
res
} }
Output::Pack(v) => { Output::Pack(v) => {
let mut res = warp::reply::Response::new(v.into()); ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationPack))], v).into_response()
let con = HeaderValue::from_static("application/pack");
res.headers_mut().insert(CONTENT_TYPE, con);
res
} }
Output::Full(v) => { Output::Full(v) => {
let mut res = warp::reply::Response::new(v.into()); ([(CONTENT_TYPE, HeaderValue::from(Accept::Surrealdb))], v).into_response()
let con = HeaderValue::from_static("application/surrealdb");
res.headers_mut().insert(CONTENT_TYPE, con);
res
} }
Output::None => StatusCode::OK.into_response(), Output::None => StatusCode::OK.into_response(),
Output::Fail => StatusCode::INTERNAL_SERVER_ERROR.into_response(), Output::Fail => StatusCode::INTERNAL_SERVER_ERROR.into_response(),

View file

@ -5,13 +5,16 @@ use crate::cnf::PKG_VERSION;
use crate::cnf::WEBSOCKET_PING_FREQUENCY; use crate::cnf::WEBSOCKET_PING_FREQUENCY;
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::session;
use crate::rpc::args::Take; use crate::rpc::args::Take;
use crate::rpc::paths::{ID, METHOD, PARAMS}; use crate::rpc::paths::{ID, METHOD, PARAMS};
use crate::rpc::res; use crate::rpc::res;
use crate::rpc::res::Failure; use crate::rpc::res::Failure;
use crate::rpc::res::Output; use crate::rpc::res::Output;
use axum::routing::get;
use axum::Extension;
use axum::Router;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use http_body::Body as HttpBody;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::collections::HashMap; use std::collections::HashMap;
@ -27,8 +30,11 @@ use surrealdb::sql::Value;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::instrument; use tracing::instrument;
use uuid::Uuid; use uuid::Uuid;
use warp::ws::{Message, WebSocket, Ws};
use warp::Filter; use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
response::IntoResponse,
};
// Mapping of WebSocketID to WebSocket // Mapping of WebSocketID to WebSocket
type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>; type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>;
@ -38,17 +44,22 @@ type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default); static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default); static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
warp::path("rpc") B: HttpBody + Send + 'static,
.and(warp::path::end()) S: Clone + Send + Sync + 'static,
.and(warp::ws()) {
.and(session::build()) Router::new().route("/rpc", get(handler))
.map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session)))
} }
async fn socket(ws: WebSocket, session: Session) { async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse {
let rpc = Rpc::new(session); // finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, sess))
}
async fn handle_socket(ws: WebSocket, sess: Session) {
let rpc = Rpc::new(sess);
Rpc::serve(rpc, ws).await Rpc::serve(rpc, ws).await
} }
@ -89,7 +100,7 @@ impl Rpc {
let png = chn.clone(); let png = chn.clone();
// The WebSocket has connected // The WebSocket has connected
Rpc::connected(rpc.clone(), chn.clone()).await; Rpc::connected(rpc.clone(), chn.clone()).await;
// Send messages to the client // Send Ping messages to the client
tokio::task::spawn(async move { tokio::task::spawn(async move {
// Create the interval ticker // Create the interval ticker
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY); let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
@ -98,7 +109,7 @@ impl Rpc {
// Wait for the timer // Wait for the timer
interval.tick().await; interval.tick().await;
// Create the ping message // Create the ping message
let msg = Message::ping(vec![]); let msg = Message::Ping(vec![]);
// Send the message to the client // Send the message to the client
if png.send(msg).await.is_err() { if png.send(msg).await.is_err() {
// Exit out of the loop // Exit out of the loop
@ -146,20 +157,18 @@ impl Rpc {
while let Some(msg) = wrx.next().await { while let Some(msg) = wrx.next().await {
match msg { match msg {
// We've received a message from the client // We've received a message from the client
// Ping is automatically handled by the WebSocket library
Ok(msg) => match msg { Ok(msg) => match msg {
msg if msg.is_ping() => { Message::Text(_) => {
let _ = chn.send(Message::pong(vec![])).await;
}
msg if msg.is_text() => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone())); tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
} }
msg if msg.is_binary() => { Message::Binary(_) => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone())); tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
} }
msg if msg.is_close() => { Message::Close(_) => {
break; break;
} }
msg if msg.is_pong() => { Message::Pong(_) => {
continue; continue;
} }
_ => { _ => {
@ -214,16 +223,14 @@ impl Rpc {
// Parse the request // Parse the request
let req = match msg { let req = match msg {
// This is a binary message // This is a binary message
m if m.is_binary() => { Message::Binary(val) => {
// Use binary output // Use binary output
out = Output::Full; out = Output::Full;
// Deserialize the input // Deserialize the input
Value::from(m.into_bytes()) Value::from(val)
} }
// This is a text message // This is a text message
m if m.is_text() => { Message::Text(ref val) => {
// This won't panic due to the check above
let val = m.to_str().unwrap();
// Parse the SurrealQL object // Parse the SurrealQL object
match surrealdb::sql::value(val) { match surrealdb::sql::value(val) {
// The SurrealQL message parsed ok // The SurrealQL message parsed ok

View file

@ -1,57 +0,0 @@
use crate::dbs::DB;
use crate::err::Error;
use crate::iam::verify::basic;
use crate::iam::BASIC;
use crate::net::client_ip;
use surrealdb::dbs::Session;
use surrealdb::iam::verify::token;
use surrealdb::iam::TOKEN;
use warp::Filter;
pub fn build() -> impl Filter<Extract = (Session,), Error = warp::Rejection> + Clone {
// Enable on any path
let conf = warp::any();
// Add remote ip address
let conf = conf.and(client_ip::build());
// Add authorization header
let conf = conf.and(warp::header::optional::<String>("authorization"));
// Add http origin header
let conf = conf.and(warp::header::optional::<String>("origin"));
// Add session id header
let conf = conf.and(warp::header::optional::<String>("id"));
// Add namespace header
let conf = conf.and(warp::header::optional::<String>("ns"));
// Add database header
let conf = conf.and(warp::header::optional::<String>("db"));
// Process all headers
conf.and_then(process)
}
async fn process(
ip: Option<String>,
au: Option<String>,
or: Option<String>,
id: Option<String>,
ns: Option<String>,
db: Option<String>,
) -> Result<Session, warp::Rejection> {
let kvs = DB.get().unwrap();
// Create session
#[rustfmt::skip]
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
// Parse the authentication header
match au {
// Basic authentication data was supplied
Some(auth) if auth.starts_with(BASIC) => basic(&mut session, auth).await,
// Token authentication data was supplied
Some(auth) if auth.starts_with(TOKEN) => {
token(kvs, &mut session, auth).await.map_err(Error::from)
}
// Wrong authentication data was supplied
Some(_) => Err(Error::InvalidAuth),
// No authentication data was supplied
None => Ok(()),
}?;
// Pass the authenticated session through
Ok(session)
}

View file

@ -1,5 +1,19 @@
use std::time::Duration;
use axum_server::Handle;
use crate::err::Error; use crate::err::Error;
/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received.
///
/// Spawns a background task that blocks on [`listen`] for an OS shutdown
/// signal, logs which signal arrived, and then asks the Axum server `handle`
/// to shut down gracefully. `dur` is forwarded to
/// `Handle::graceful_shutdown` (presumably the maximum time to wait for
/// in-flight connections, per axum_server semantics — TODO confirm).
pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) {
	tokio::spawn(async move {
		// Panics inside the spawned task if the signal listener cannot be set up
		let result = listen().await.expect("Failed to listen to shutdown signal");
		info!(target: super::LOG, "{} received. Start graceful shutdown...", result);
		handle.graceful_shutdown(dur)
	});
}
#[cfg(unix)] #[cfg(unix)]
pub async fn listen() -> Result<String, Error> { pub async fn listen() -> Result<String, Error> {
// Import the OS signals // Import the OS signals
@ -11,7 +25,7 @@ pub async fn listen() -> Result<String, Error> {
let mut sigterm = signal(SignalKind::terminate())?; let mut sigterm = signal(SignalKind::terminate())?;
// Listen and wait for the system signals // Listen and wait for the system signals
tokio::select! { tokio::select! {
// Wait for a SIGQUIT signal // Wait for a SIGHUP signal
_ = sighup.recv() => { _ = sighup.recv() => {
Ok(String::from("SIGHUP")) Ok(String::from("SIGHUP"))
} }

View file

@ -2,16 +2,24 @@ use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
use crate::net::session;
use crate::net::CF; use crate::net::CF;
use axum::extract::DefaultBodyLimit;
use axum::response::IntoResponse;
use axum::routing::options;
use axum::Extension;
use axum::Router;
use axum::TypedHeader;
use bytes::Bytes; use bytes::Bytes;
use http_body::Body as HttpBody;
use serde::Serialize; use serde::Serialize;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::opt::auth::Root; use surrealdb::opt::auth::Root;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use warp::Filter; use tower_http::limit::RequestBodyLimitLayer;
const MAX: u64 = 1024; // 1 KiB use super::headers::Accept;
const MAX: usize = 1024; // 1 KiB
#[derive(Serialize)] #[derive(Serialize)]
struct Success { struct Success {
@ -30,29 +38,24 @@ impl Success {
} }
} }
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
// Set base path B: HttpBody + Send + 'static,
let base = warp::path("signin").and(warp::path::end()); B::Data: Send,
// Set opts method B::Error: std::error::Error + Send + Sync + 'static,
let opts = base.and(warp::options()).map(warp::reply); S: Clone + Send + Sync + 'static,
// Set post method {
let post = base Router::new()
.and(warp::post()) .route("/signin", options(|| async {}).post(handler))
.and(warp::header::optional::<String>(http::header::ACCEPT.as_str())) .route_layer(DefaultBodyLimit::disable())
.and(warp::body::content_length_limit(MAX)) .layer(RequestBodyLimitLayer::new(MAX))
.and(warp::body::bytes())
.and(session::build())
.and_then(handler);
// Specify route
opts.or(post)
} }
async fn handler( async fn handler(
output: Option<String>, Extension(mut session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
body: Bytes, body: Bytes,
mut session: Session, ) -> Result<impl IntoResponse, impl IntoResponse> {
) -> Result<impl warp::Reply, warp::Rejection> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
// Get the config options // Get the config options
@ -72,25 +75,25 @@ async fn handler(
.map_err(Error::from) .map_err(Error::from)
{ {
// Authentication was successful // Authentication was successful
Ok(v) => match output.as_deref() { Ok(v) => match maybe_output.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Text serialization
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
// Return nothing // Return nothing
None => Ok(output::none()), None => Ok(output::none()),
// Text serialization
Some("text/plain") => Ok(output::text(v.unwrap_or_default())),
// Simple serialization
Some("application/json") => Ok(output::json(&Success::new(v))),
Some("application/cbor") => Ok(output::cbor(&Success::new(v))),
Some("application/pack") => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some("application/surrealdb") => Ok(output::full(&Success::new(v))),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error with authentication // There was an error with authentication
Err(e) => Err(warp::reject::custom(e)), Err(err) => Err(err),
} }
} }
// The provided value was not an object // The provided value was not an object
_ => Err(warp::reject::custom(Error::Request)), _ => Err(Error::Request),
} }
} }

View file

@ -2,14 +2,20 @@ use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
use crate::net::session; use axum::extract::DefaultBodyLimit;
use axum::response::IntoResponse;
use axum::routing::options;
use axum::{Extension, Router, TypedHeader};
use bytes::Bytes; use bytes::Bytes;
use http_body::Body as HttpBody;
use serde::Serialize; use serde::Serialize;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use warp::Filter; use tower_http::limit::RequestBodyLimitLayer;
const MAX: u64 = 1024; // 1 KiB use super::headers::Accept;
const MAX: usize = 1024; // 1 KiB
#[derive(Serialize)] #[derive(Serialize)]
struct Success { struct Success {
@ -28,29 +34,24 @@ impl Success {
} }
} }
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
// Set base path B: HttpBody + Send + 'static,
let base = warp::path("signup").and(warp::path::end()); B::Data: Send,
// Set opts method B::Error: std::error::Error + Send + Sync + 'static,
let opts = base.and(warp::options()).map(warp::reply); S: Clone + Send + Sync + 'static,
// Set post method {
let post = base Router::new()
.and(warp::post()) .route("/signup", options(|| async {}).post(handler))
.and(warp::header::optional::<String>(http::header::ACCEPT.as_str())) .route_layer(DefaultBodyLimit::disable())
.and(warp::body::content_length_limit(MAX)) .layer(RequestBodyLimitLayer::new(MAX))
.and(warp::body::bytes())
.and(session::build())
.and_then(handler);
// Specify route
opts.or(post)
} }
async fn handler( async fn handler(
output: Option<String>, Extension(mut session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
body: Bytes, body: Bytes,
mut session: Session, ) -> Result<impl IntoResponse, impl IntoResponse> {
) -> Result<impl warp::Reply, warp::Rejection> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
// Convert the HTTP body into text // Convert the HTTP body into text
@ -62,25 +63,25 @@ async fn handler(
match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from) match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from)
{ {
// Authentication was successful // Authentication was successful
Ok(v) => match output.as_deref() { Ok(v) => match maybe_output.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Text serialization
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
// Return nothing // Return nothing
None => Ok(output::none()), None => Ok(output::none()),
// Text serialization
Some("text/plain") => Ok(output::text(v.unwrap_or_default())),
// Simple serialization
Some("application/json") => Ok(output::json(&Success::new(v))),
Some("application/cbor") => Ok(output::cbor(&Success::new(v))),
Some("application/pack") => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some("application/surrealdb") => Ok(output::full(&Success::new(v))),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error with authentication // There was an error with authentication
Err(e) => Err(warp::reject::custom(e)), Err(err) => Err(err),
} }
} }
// The provided value was not an object // The provided value was not an object
_ => Err(warp::reject::custom(Error::Request)), _ => Err(Error::Request),
} }
} }

View file

@ -3,74 +3,80 @@ use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
use crate::net::params::Params; use crate::net::params::Params;
use crate::net::session; use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::extract::DefaultBodyLimit;
use axum::extract::Query;
use axum::extract::WebSocketUpgrade;
use axum::response::IntoResponse;
use axum::routing::options;
use axum::Extension;
use axum::Router;
use axum::TypedHeader;
use bytes::Bytes; use bytes::Bytes;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use warp::ws::{Message, WebSocket, Ws}; use tower_http::limit::RequestBodyLimitLayer;
use warp::Filter;
const MAX: u64 = 1024 * 1024; // 1 MiB use super::headers::Accept;
#[allow(opaque_hidden_inferred_bound)] const MAX: usize = 1024 * 1024; // 1 MiB
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
// Set base path pub(super) fn router<S, B>() -> Router<S, B>
let base = warp::path("sql").and(warp::path::end()); where
// Set opts method B: HttpBody + Send + 'static,
let opts = base.and(warp::options()).map(warp::reply); B::Data: Send,
// Set post method B::Error: std::error::Error + Send + Sync + 'static,
let post = base S: Clone + Send + Sync + 'static,
.and(warp::post()) {
.and(warp::header::<String>(http::header::ACCEPT.as_str())) Router::new()
.and(warp::body::content_length_limit(MAX)) .route("/sql", options(|| async {}).get(ws_handler).post(post_handler))
.and(warp::body::bytes()) .route_layer(DefaultBodyLimit::disable())
.and(warp::query()) .layer(RequestBodyLimitLayer::new(MAX))
.and(session::build())
.and_then(handler);
// Set sock method
let sock = base
.and(warp::ws())
.and(session::build())
.map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session)));
// Specify route
opts.or(post).or(sock)
} }
async fn handler( async fn post_handler(
output: String, Extension(session): Extension<Session>,
output: Option<TypedHeader<Accept>>,
params: Query<Params>,
sql: Bytes, sql: Bytes,
params: Params, ) -> Result<impl IntoResponse, impl IntoResponse> {
session: Session,
) -> Result<impl warp::Reply, warp::Rejection> {
// Get a database reference // Get a database reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Convert the received sql query // Convert the received sql query
let sql = bytes_to_utf8(&sql)?; let sql = bytes_to_utf8(&sql)?;
// Execute the received sql query // Execute the received sql query
match db.execute(sql, &session, params.parse().into()).await { match db.execute(sql, &session, params.0.parse().into()).await {
// Convert the response to JSON Ok(res) => match output.as_deref() {
Ok(res) => match output.as_ref() {
// Simple serialization // Simple serialization
"application/json" => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
"application/cbor" => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
"application/pack" => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
"application/surrealdb" => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(warp::reject::custom(Error::InvalidType)), _ => Err(Error::InvalidType),
}, },
// There was an error when executing the query // There was an error when executing the query
Err(err) => Err(warp::reject::custom(Error::from(err))), Err(err) => Err(Error::from(err)),
} }
} }
async fn socket(ws: WebSocket, session: Session) { async fn ws_handler(
ws: WebSocketUpgrade,
Extension(sess): Extension<Session>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, sess))
}
async fn handle_socket(ws: WebSocket, session: Session) {
// Split the WebSocket connection // Split the WebSocket connection
let (mut tx, mut rx) = ws.split(); let (mut tx, mut rx) = ws.split();
// Wait to receive the next message // Wait to receive the next message
while let Some(res) = rx.next().await { while let Some(res) = rx.next().await {
if let Ok(msg) = res { if let Ok(msg) = res {
if let Ok(sql) = msg.to_str() { if let Ok(sql) = msg.to_text() {
// Get a database reference // Get a database reference
let db = DB.get().unwrap(); let db = DB.get().unwrap();
// Execute the received sql query // Execute the received sql query
@ -78,12 +84,12 @@ async fn socket(ws: WebSocket, session: Session) {
// Convert the response to JSON // Convert the response to JSON
Ok(v) => match serde_json::to_string(&v) { Ok(v) => match serde_json::to_string(&v) {
// Send the JSON response to the client // Send the JSON response to the client
Ok(v) => tx.send(Message::text(v)).await, Ok(v) => tx.send(Message::Text(v)).await,
// There was an error converting to JSON // There was an error converting to JSON
Err(e) => tx.send(Message::text(Error::from(e))).await, Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await,
}, },
// There was an error when executing the query // There was an error when executing the query
Err(e) => tx.send(Message::text(Error::from(e))).await, Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await,
}; };
} }
} }

View file

@ -1,6 +0,0 @@
use warp::Filter;
#[allow(opaque_hidden_inferred_bound)]
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
warp::path("status").and(warp::path::end()).and(warp::get()).map(warp::reply)
}

View file

@ -1,22 +1,20 @@
use warp::http; use axum::response::IntoResponse;
use warp::Filter; use axum::routing::get;
use axum::Router;
use http_body::Body as HttpBody;
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
// Set base path B: HttpBody + Send + 'static,
let base = warp::path("sync").and(warp::path::end()); S: Clone + Send + Sync + 'static,
// Set save method {
let save = base.and(warp::get()).and_then(save); Router::new().route("/sync", get(save).post(load))
// Set load method
let load = base.and(warp::post()).and_then(load);
// Specify route
save.or(load)
} }
pub async fn load() -> Result<impl warp::Reply, warp::Rejection> { async fn load() -> impl IntoResponse {
Ok(warp::reply::with_status("Load", http::StatusCode::OK)) "Load"
} }
pub async fn save() -> Result<impl warp::Reply, warp::Rejection> { async fn save() -> impl IntoResponse {
Ok(warp::reply::with_status("Save", http::StatusCode::OK)) "Save"
} }

235
src/net/tracer.rs Normal file
View file

@ -0,0 +1,235 @@
use std::{fmt, time::Duration};
use axum::{
body::{boxed, Body, BoxBody},
extract::MatchedPath,
headers::{
authorization::{Basic, Bearer},
Authorization, Origin,
},
Extension, RequestPartsExt, TypedHeader,
};
use futures_util::future::BoxFuture;
use http::{header, request::Parts, StatusCode};
use hyper::{Request, Response};
use surrealdb::{dbs::Session, iam::verify::token};
use tower_http::{
auth::AsyncAuthorizeRequest,
request_id::RequestId,
trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
};
use tracing::{field, Level, Span};
use crate::{dbs::DB, err::Error, iam::verify::basic};
use super::{client_ip::ExtractClientIP, AppState};
///
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
///
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
///
/// ```rust
/// use tower_http::auth::RequireAuthorizationLayer;
/// use surrealdb::net::SurrealAuth;
/// use axum::Router;
///
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
///
/// let app = Router::new()
/// .route("/version", get(|| async { "0.1.0" }))
/// .layer(auth);
/// ```
#[derive(Clone, Copy)]
pub(super) struct SurrealAuth;
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
where
	B: Send + Sync + 'static,
{
	type RequestBody = B;
	type ResponseBody = BoxBody;
	type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

	/// Authenticate the incoming request against SurrealDB.
	///
	/// On success the authenticated session is inserted into the request
	/// extensions and the request is passed through; on failure a
	/// `401 Unauthorized` response carrying the error text is returned.
	fn authorize(&mut self, request: Request<B>) -> Self::Future {
		Box::pin(async move {
			let (mut parts, body) = request.into_parts();
			// Bail out early with a 401 response when authentication fails
			let session = match check_auth(&mut parts).await {
				Ok(session) => session,
				Err(err) => {
					let response = Response::builder()
						.status(StatusCode::UNAUTHORIZED)
						.body(boxed(Body::from(err.to_string())))
						.unwrap();
					return Err(response);
				}
			};
			// Make the authenticated session available to downstream handlers
			parts.extensions.insert(session);
			Ok(Request::from_parts(parts, body))
		})
	}
}
/// Authenticate the request and build the SurrealDB [`Session`] for it.
///
/// Supports HTTP Basic and Bearer token authorization. The session is
/// populated from the `Origin`, `id`, `ns` and `db` request headers and the
/// resolved client IP address. Returns [`Error::InvalidAuth`] when no
/// recognised credentials are supplied, when the app state cannot be
/// extracted, or when credential verification fails.
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
	let kvs = DB.get().unwrap();

	// Extract the Origin header, treating a null origin as absent
	let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
		if !or.is_null() {
			Some(or.to_string())
		} else {
			None
		}
	} else {
		None
	};

	// Read the session metadata headers, skipping values that are not valid
	// UTF-8 instead of panicking on malformed client input
	let id = parts.headers.get("id").and_then(|v| v.to_str().ok()).map(str::to_owned); // TODO: Use a TypedHeader
	let ns = parts.headers.get("ns").and_then(|v| v.to_str().ok()).map(str::to_owned); // TODO: Use a TypedHeader
	let db = parts.headers.get("db").and_then(|v| v.to_str().ok()).map(str::to_owned); // TODO: Use a TypedHeader

	let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
		tracing::error!("Error extracting the app state: {:?}", err);
		Error::InvalidAuth
	})?;

	// Resolve the client IP; failing to resolve one is not an error
	let ExtractClientIP(ip) =
		parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));

	// Create session
	#[rustfmt::skip]
	let mut session = Session { ip, or, id, ns, db, ..Default::default() };

	// If Basic authentication data was supplied
	if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
		basic(&mut session, au.username(), au.password()).await
	} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
		// Otherwise verify the supplied Bearer token
		token(kvs, &mut session, au.token().into()).await.map_err(|e| e.into())
	} else {
		// No recognised authentication data was supplied
		Err(Error::InvalidAuth)
	}?;

	Ok(session)
}
///
/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
///
/// Example:
///
/// ```rust
/// use tower_http::trace::TraceLayer;
/// use surrealdb::net::HttpTraceLayerHooks;
/// use axum::Router;
///
/// let trace = TraceLayer::new_for_http().on_request(HttpTraceLayerHooks::default());
///
/// let app = Router::new()
/// .route("/version", get(|| async { "0.1.0" }))
/// .layer(trace);
/// ```
#[derive(Default, Clone)]
pub(crate) struct HttpTraceLayerHooks;
impl<B> MakeSpan<B> for HttpTraceLayerHooks {
	/// Create the per-request tracing span with all fields pre-declared.
	fn make_span(&mut self, req: &Request<B>) -> Span {
		// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
		// Fields not known at creation time are declared `field::Empty` and
		// recorded later — tracing requires every field to exist up front.
		let span = tracing::info_span!(
			target: "surreal::http",
			"request",
			otel.name = field::Empty,
			otel.kind = "server",
			http.route = field::Empty,
			http.request.method = req.method().as_str(),
			http.request.body.size = field::Empty,
			url.path = req.uri().path(),
			url.query = field::Empty,
			url.scheme = field::Empty,
			http.request.id = field::Empty,
			user_agent.original = field::Empty,
			network.protocol.name = "http",
			network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"),
			client.address = field::Empty,
			client.port = field::Empty,
			client.socket.address = field::Empty,
			server.address = field::Empty,
			server.port = field::Empty,
			// set on the response hook
			http.latency.ms = field::Empty,
			http.response.status_code = field::Empty,
			http.response.body.size = field::Empty,
			// set on the failure hook
			error = field::Empty,
			error_message = field::Empty,
		);

		// Record the optional URL components when present
		req.uri().query().map(|v| span.record("url.query", v));
		req.uri().scheme().map(|v| span.record("url.scheme", v.as_str()));
		req.uri().host().map(|v| span.record("server.address", v));
		req.uri().port_u16().map(|v| span.record("server.port", v));

		// Record body size and user agent from the headers when they parse as UTF-8
		req.headers()
			.get(header::CONTENT_LENGTH)
			.map(|v| v.to_str().map(|v| span.record("http.request.body.size", v)));
		req.headers()
			.get(header::USER_AGENT)
			.map(|v| v.to_str().map(|v| span.record("user_agent.original", v)));

		// Name the span after the matched route template rather than the raw
		// path, keeping the set of span names bounded
		if let Some(path) = req.extensions().get::<MatchedPath>() {
			span.record("otel.name", format!("{} {}", req.method(), path.as_str()));
			span.record("http.route", path.as_str());
		} else {
			span.record("otel.name", format!("{} -", req.method()));
		};

		// Propagate the request id assigned by the RequestId middleware, if any
		if let Some(req_id) = req.extensions().get::<RequestId>() {
			match req_id.header_value().to_str() {
				Err(err) => tracing::error!(error = %err, "failed to parse request id"),
				Ok(request_id) => {
					span.record("http.request.id", request_id);
				}
			}
		}

		// Record the client address resolved by the client_ip extractor, if present
		if let Some(client_ip) = req.extensions().get::<ExtractClientIP>() {
			if let Some(ref client_ip) = client_ip.0 {
				span.record("client.address", client_ip);
			}
		}

		span
	}
}
impl<B> OnRequest<B> for HttpTraceLayerHooks {
	/// Emit an event marking the start of request processing.
	fn on_request(&mut self, _: &Request<B>, _: &Span) {
		tracing::info!("started processing request");
	}
}
impl<B> OnResponse<B> for HttpTraceLayerHooks {
	/// Record the response outcome on the request span.
	fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) {
		// Record the response body size only when the Content-Length header
		// value is valid UTF-8, rather than panicking on a malformed header
		if let Some(size) = response.headers().get(header::CONTENT_LENGTH) {
			if let Ok(size) = size.to_str() {
				span.record("http.response.body.size", size);
			}
		}
		span.record("http.response.status_code", response.status().as_u16());
		// Server errors are handled by the OnFailure hook
		if !response.status().is_server_error() {
			span.record("http.latency.ms", latency.as_millis());
			tracing::event!(Level::INFO, "finished processing request");
		}
	}
}
impl<FailureClass> OnFailure<FailureClass> for HttpTraceLayerHooks
where
FailureClass: fmt::Display,
{
fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) {
span.record("error_message", &error.to_string());
span.record("http.latency.ms", latency.as_millis());
tracing::event!(Level::ERROR, error = error.to_string(), "response failed");
}
}

View file

@ -1,14 +1,18 @@
use crate::cnf::PKG_NAME; use crate::cnf::PKG_NAME;
use crate::cnf::PKG_VERSION; use crate::cnf::PKG_VERSION;
use warp::http; use axum::response::IntoResponse;
use warp::Filter; use axum::routing::get;
use axum::Router;
use http_body::Body as HttpBody;
#[allow(opaque_hidden_inferred_bound)] pub(super) fn router<S, B>() -> Router<S, B>
pub fn config() -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone { where
warp::path("version").and(warp::path::end()).and(warp::get()).and_then(handler) B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
Router::new().route("/version", get(handler))
} }
pub async fn handler() -> Result<impl warp::Reply, warp::Rejection> { async fn handler() -> impl IntoResponse {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION); format!("{PKG_NAME}-{}", *PKG_VERSION)
Ok(warp::reply::with_status(val, http::StatusCode::OK))
} }

View file

@ -1,3 +1,4 @@
use axum::extract::ws::Message;
use serde::Serialize; use serde::Serialize;
use serde_json::{json, Value as Json}; use serde_json::{json, Value as Json};
use std::borrow::Cow; use std::borrow::Cow;
@ -7,7 +8,6 @@ use surrealdb::dbs::Notification;
use surrealdb::sql; use surrealdb::sql;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tracing::instrument; use tracing::instrument;
use warp::ws::Message;
#[derive(Clone)] #[derive(Clone)]
pub enum Output { pub enum Output {
@ -87,19 +87,19 @@ impl Response {
let message = match out { let message = match out {
Output::Json => { Output::Json => {
let res = serde_json::to_string(&self.simplify()).unwrap(); let res = serde_json::to_string(&self.simplify()).unwrap();
Message::text(res) Message::Text(res)
} }
Output::Cbor => { Output::Cbor => {
let res = serde_cbor::to_vec(&self.simplify()).unwrap(); let res = serde_cbor::to_vec(&self.simplify()).unwrap();
Message::binary(res) Message::Binary(res)
} }
Output::Pack => { Output::Pack => {
let res = serde_pack::to_vec(&self.simplify()).unwrap(); let res = serde_pack::to_vec(&self.simplify()).unwrap();
Message::binary(res) Message::Binary(res)
} }
Output::Full => { Output::Full => {
let res = surrealdb::sql::serde::serialize(&self).unwrap(); let res = surrealdb::sql::serde::serialize(&self).unwrap();
Message::binary(res) Message::Binary(res)
} }
}; };
let _ = chn.send(message).await; let _ = chn.send(message).await;

View file

@ -0,0 +1,112 @@
pub(super) mod tower_layer;
use once_cell::sync::Lazy;
use opentelemetry::{
metrics::{Histogram, Meter, MeterProvider, ObservableUpDownCounter, Unit},
runtime,
sdk::{
export::metrics::aggregation,
metrics::{
controllers::{self, BasicController},
processors, selectors,
},
},
Context,
};
use opentelemetry_otlp::MetricsExporterBuilder;
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
// Histogram buckets in milliseconds
static HTTP_DURATION_MS_HISTOGRAM_BUCKETS: &[f64] = &[
	5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0,
	2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0,
];

// Size multipliers, in bytes
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;

// Histogram buckets for request/response sizes, expressed in bytes
const HTTP_SIZE_HISTOGRAM_BUCKETS: &[f64] = &[
	1.0 * KB, // 1 KB
	2.0 * KB, // 2 KB
	5.0 * KB, // 5 KB
	10.0 * KB, // 10 KB
	100.0 * KB, // 100 KB
	500.0 * KB, // 500 KB
	1.0 * MB, // 1 MB
	2.5 * MB, // 2.5 MB
	5.0 * MB, // 5 MB
	10.0 * MB, // 10 MB
	25.0 * MB, // 25 MB
	50.0 * MB, // 50 MB
	100.0 * MB, // 100 MB
];
static METER_PROVIDER_HTTP_DURATION: Lazy<BasicController> = Lazy::new(|| {
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
.unwrap();
let builder = controllers::basic(processors::factory(
selectors::simple::histogram(HTTP_DURATION_MS_HISTOGRAM_BUCKETS),
aggregation::cumulative_temporality_selector(),
))
.with_exporter(exporter)
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
let controller = builder.build();
controller.start(&Context::current(), runtime::Tokio).unwrap();
controller
});
/// Metrics controller for the request/response size instruments.
static METER_PROVIDER_HTTP_SIZE: Lazy<BasicController> = Lazy::new(|| {
	// Export over OTLP/gRPC with cumulative temporality.
	let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
		.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
		.unwrap();

	// Aggregate sizes into the byte-valued histogram buckets.
	let controller = controllers::basic(processors::factory(
		selectors::simple::histogram(HTTP_SIZE_HISTOGRAM_BUCKETS),
		aggregation::cumulative_temporality_selector(),
	))
	.with_exporter(exporter)
	.with_resource(OTEL_DEFAULT_RESOURCE.clone())
	.build();

	// Start the controller on the Tokio runtime so it pushes metrics periodically.
	controller.start(&Context::current(), runtime::Tokio).unwrap();
	controller
});
// Meter backing the duration histogram and the active-requests counter.
static HTTP_DURATION_METER: Lazy<Meter> =
	Lazy::new(|| METER_PROVIDER_HTTP_DURATION.meter("http_duration"));
// Meter backing the request/response size histograms.
static HTTP_SIZE_METER: Lazy<Meter> = Lazy::new(|| METER_PROVIDER_HTTP_SIZE.meter("http_size"));
/// Histogram recording the total HTTP server request duration in milliseconds.
pub static HTTP_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
	let builder = HTTP_DURATION_METER
		.u64_histogram("http.server.duration")
		.with_description("The HTTP server duration in milliseconds.")
		.with_unit(Unit::new("ms"));
	builder.init()
});
/// Up/down counter tracking the number of in-flight HTTP requests.
/// Observed via callbacks registered by the tower layer (+1 on start, -1 on finish).
pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy<ObservableUpDownCounter<i64>> = Lazy::new(|| {
	HTTP_DURATION_METER
		.i64_observable_up_down_counter("http.server.active_requests")
		.with_description("The number of active HTTP requests.")
		.init()
});
/// Histogram recording the size of HTTP request bodies, in bytes.
pub static HTTP_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
	HTTP_SIZE_METER
		.u64_histogram("http.server.request.size")
		.with_description("Measures the size of HTTP request messages.")
		// The recorded values (body size hints) and the histogram buckets are in
		// bytes, so declare the UCUM byte unit "By" (was incorrectly "mb").
		.with_unit(Unit::new("By"))
		.init()
});
/// Histogram recording the size of HTTP response bodies, in bytes.
pub static HTTP_SERVER_RESPONSE_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
	HTTP_SIZE_METER
		.u64_histogram("http.server.response.size")
		.with_description("Measures the size of HTTP response messages.")
		// The recorded values (body size hints) and the histogram buckets are in
		// bytes, so declare the UCUM byte unit "By" (was incorrectly "mb").
		.with_unit(Unit::new("By"))
		.init()
});

View file

@ -0,0 +1,310 @@
use axum::extract::MatchedPath;
use opentelemetry::{metrics::MetricsError, Context as TelemetryContext, KeyValue};
use pin_project_lite::pin_project;
use std::{
cell::Cell,
fmt,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
};
use futures::Future;
use http::{Request, Response, StatusCode, Version};
use tower::{Layer, Service};
use super::{
HTTP_DURATION_METER, HTTP_SERVER_ACTIVE_REQUESTS, HTTP_SERVER_DURATION,
HTTP_SERVER_REQUEST_SIZE, HTTP_SERVER_RESPONSE_SIZE,
};
/// A `tower` layer that wraps a service in [`HttpMetrics`], recording
/// OpenTelemetry metrics for every HTTP call handled by the inner service.
#[derive(Clone, Default)]
pub struct HttpMetricsLayer;

impl<S> Layer<S> for HttpMetricsLayer {
	type Service = HttpMetrics<S>;

	fn layer(&self, inner: S) -> Self::Service {
		HttpMetrics {
			inner,
		}
	}
}
/// Middleware service that records HTTP metrics around the wrapped service.
#[derive(Clone)]
pub struct HttpMetrics<S> {
	// The wrapped inner service that actually handles the request.
	inner: S,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HttpMetrics<S>
where
	S: Service<Request<ReqBody>, Response = Response<ResBody>>,
	ReqBody: http_body::Body,
	ResBody: http_body::Body,
	S::Error: fmt::Display + 'static,
{
	type Response = Response<ResBody>;
	type Error = S::Error;
	type Future = HttpCallMetricsFuture<S::Future>;

	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
		// Readiness is delegated untouched to the wrapped service.
		self.inner.poll_ready(cx)
	}

	fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
		// Capture the request attributes (method, route, size hint, ...) before
		// the request is consumed by the inner service.
		let tracker = HttpCallMetricTracker::new(&request);
		HttpCallMetricsFuture::new(self.inner.call(request), tracker)
	}
}
pin_project! {
	/// Future returned by [`HttpMetrics`]: drives the inner service's future
	/// while carrying the metric tracker for this call.
	pub struct HttpCallMetricsFuture<F> {
		#[pin]
		inner: F,
		// Not pinned: only accessed by reference; emits metrics from its Drop impl.
		tracker: HttpCallMetricTracker,
	}
}
impl<F> HttpCallMetricsFuture<F> {
	/// Pair the inner service future with the metric tracker for this request.
	fn new(inner: F, tracker: HttpCallMetricTracker) -> Self {
		Self {
			inner,
			tracker,
		}
	}
}
impl<Fut, ResBody, E> Future for HttpCallMetricsFuture<Fut>
where
	Fut: Future<Output = Result<Response<ResBody>, E>>,
	ResBody: http_body::Body,
	E: std::fmt::Display + 'static,
{
	type Output = Result<Response<ResBody>, E>;

	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = self.project();
		// Only run the request-start bookkeeping on the very first poll. The
		// previous code re-ran it on every poll, registering an extra
		// active-requests callback each time the inner future returned Pending,
		// which inflated the counter.
		match this.tracker.state.replace(ResultState::Started) {
			ResultState::None => {
				if let Err(err) = on_request_start(this.tracker) {
					error!("Failed to setup metrics when request started: {}", err);
					// Consider this request not tracked: reset the state to None, so
					// that the drop handler does not decrease the counter.
					this.tracker.set_state(ResultState::None);
				}
			}
			// Not the first poll: put the previous state back untouched.
			prev => this.tracker.set_state(prev),
		}
		let response = futures_util::ready!(this.inner.poll(cx));
		let result = match response {
			Ok(reply) => {
				// Record the response details; they are consumed by the tracker's
				// Drop impl when it emits the end-of-request metrics.
				this.tracker.set_state(ResultState::Result(
					reply.status(),
					reply.version(),
					reply.body().size_hint().exact(),
				));
				Ok(reply)
			}
			Err(e) => {
				this.tracker.set_state(ResultState::Failed);
				Err(e)
			}
		};
		Poll::Ready(result)
	}
}
/// Collects everything needed to emit the metrics for a single HTTP call:
/// request attributes captured up-front, response details filled in when the
/// call completes. The metrics themselves are emitted from its `Drop` impl.
pub struct HttpCallMetricTracker {
	// Formatted like "HTTP/1.1" (the `Debug` output of `http::Version`).
	version: String,
	method: hyper::Method,
	scheme: Option<http::uri::Scheme>,
	host: Option<String>,
	// The matched Axum route template, when the request hit a known route.
	route: Option<String>,
	// Interior mutability: updated through `&self` while the future is polled.
	state: Cell<ResultState>,
	status_code: Option<StatusCode>,
	// Body sizes are only known when the body advertises an exact size hint.
	request_size: Option<u64>,
	response_size: Option<u64>,
	start: Instant,
	finish: Option<Instant>,
}
/// How far a tracked request got; read back (and reset) in the tracker's `Drop`.
pub enum ResultState {
	/// The result was already processed (or the request is not being tracked).
	None,
	/// Request was started.
	Started,
	/// The result failed with an error.
	Failed,
	/// The result is an actual HTTP response: status, protocol version, and the
	/// response body size when it advertises an exact size hint.
	Result(StatusCode, Version, Option<u64>),
}
impl HttpCallMetricTracker {
	/// Snapshot the request attributes needed for metrics before the request
	/// is handed to the inner service.
	fn new<B>(request: &Request<B>) -> Self
	where
		B: http_body::Body,
	{
		Self {
			version: format!("{:?}", request.version()),
			method: request.method().clone(),
			scheme: request.uri().scheme().cloned(),
			host: request.uri().host().map(|s| s.to_string()),
			// The route template is only present once Axum has matched the request.
			route: request.extensions().get::<MatchedPath>().map(|v| v.as_str().to_string()),
			state: Cell::new(ResultState::None),
			status_code: None,
			// Only known when the body advertises an exact size hint.
			request_size: request.body().size_hint().exact(),
			response_size: None,
			start: Instant::now(),
			finish: None,
		}
	}

	fn set_state(&self, state: ResultState) {
		self.state.set(state)
	}

	/// Elapsed time between the start of the request and its completion, or
	/// until now if the request has not finished yet.
	pub fn duration(&self) -> Duration {
		// Lazy form avoids calling `Instant::now()` when `finish` is already set.
		self.finish.unwrap_or_else(Instant::now) - self.start
	}

	// Follows the OpenTelemetry semantic conventions for HTTP metrics defined here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md
	// (Renamed from the misspelled `olel_common_attrs`; private, so no external callers.)
	fn otel_common_attrs(&self) -> Vec<KeyValue> {
		let mut res = vec![
			KeyValue::new("http.request.method", self.method.as_str().to_owned()),
			KeyValue::new("network.protocol.name", "http".to_owned()),
		];
		if let Some(scheme) = &self.scheme {
			res.push(KeyValue::new("url.scheme", scheme.as_str().to_owned()));
		}
		if let Some(host) = &self.host {
			res.push(KeyValue::new("server.address", host.to_owned()));
		}
		res
	}

	/// Attributes for the active-requests up/down counter.
	pub(super) fn active_req_attrs(&self) -> Vec<KeyValue> {
		self.otel_common_attrs()
	}

	/// Attributes for the duration histogram: the common set plus status code,
	/// protocol version and matched route.
	pub(super) fn request_duration_attrs(&self) -> Vec<KeyValue> {
		let mut res = self.otel_common_attrs();
		res.push(KeyValue::new(
			"http.response.status_code",
			// "000" marks requests that failed before producing a response.
			self.status_code.map(|v| v.as_str().to_owned()).unwrap_or_else(|| "000".to_owned()),
		));
		if let Some(v) = self.version.strip_prefix("HTTP/") {
			res.push(KeyValue::new("network.protocol.version", v.to_owned()));
		}
		if let Some(target) = &self.route {
			res.push(KeyValue::new("http.route", target.to_owned()));
		}
		res
	}

	pub(super) fn request_size_attrs(&self) -> Vec<KeyValue> {
		self.request_duration_attrs()
	}

	pub(super) fn response_size_attrs(&self) -> Vec<KeyValue> {
		self.request_duration_attrs()
	}
}
impl Drop for HttpCallMetricTracker {
	/// Emits all end-of-request metrics exactly once, whenever the tracker is
	/// dropped (i.e. when the call future completes or is cancelled).
	fn drop(&mut self) {
		// Take the state out, leaving `None` so this logic cannot run twice.
		match self.state.replace(ResultState::None) {
			ResultState::None => {
				// Request was not tracked, so no need to decrease the counter.
				return;
			}
			ResultState::Started => {
				// If the response was never processed, we can't get a valid status code
			}
			ResultState::Failed => {
				// If there's an error processing the request and we don't have a response, we can't get a valid status code
			}
			ResultState::Result(s, v, size) => {
				self.status_code = Some(s);
				self.version = format!("{:?}", v);
				self.response_size = size;
			}
		};
		self.finish = Some(Instant::now());
		// Record duration/sizes and decrement the active-requests counter.
		if let Err(err) = on_request_finish(self) {
			error!(target: "surrealdb::telemetry", "Failed to setup metrics when request finished: {}", err);
		}
	}
}
/// Called when a tracked request starts being processed.
pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
	// Setup the active_requests observer
	observe_active_request_start(tracker)
}
/// Called from the tracker's `Drop` impl once the request has finished.
pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
	// Setup the active_requests observer
	observe_active_request_finish(tracker)?;

	// Record the duration of the request.
	record_request_duration(tracker);

	// Record the request size if known
	if let Some(size) = tracker.request_size {
		record_request_size(tracker, size)
	}

	// Record the response size if known
	if let Some(size) = tracker.response_size {
		record_response_size(tracker, size)
	}

	Ok(())
}
/// Register an observer reporting +1 active request with this call's attributes.
fn observe_active_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
	let attrs = tracker.active_req_attrs();

	// Setup the callback to observe the active requests.
	// NOTE(review): a new callback is registered per request and never explicitly
	// unregistered — confirm the SDK releases these, otherwise they accumulate.
	HTTP_DURATION_METER
		.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, 1, &attrs))
}
/// Register an observer reporting -1 active request, balancing the +1 emitted
/// when the request started.
fn observe_active_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
	let attrs = tracker.active_req_attrs();

	// Setup the callback to observe the active requests.
	HTTP_DURATION_METER
		.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, -1, &attrs))
}
/// Record the total duration of the request on the duration histogram.
fn record_request_duration(tracker: &HttpCallMetricTracker) {
	// Durations are reported in milliseconds, matching the histogram buckets.
	let elapsed_ms = tracker.duration().as_millis() as u64;
	let attrs = tracker.request_duration_attrs();
	let ctx = TelemetryContext::current();
	HTTP_SERVER_DURATION.record(&ctx, elapsed_ms, &attrs);
}
/// Record the request body size (in bytes) on the request-size histogram.
pub fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) {
	let ctx = TelemetryContext::current();
	let attrs = tracker.request_size_attrs();
	HTTP_SERVER_REQUEST_SIZE.record(&ctx, size, &attrs);
}
/// Record the response body size (in bytes) on the response-size histogram.
pub fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) {
	let ctx = TelemetryContext::current();
	let attrs = tracker.response_size_attrs();
	HTTP_SERVER_RESPONSE_SIZE.record(&ctx, size, &attrs);
}

View file

@ -0,0 +1,3 @@
// Telemetry metrics: currently only HTTP-layer metrics are implemented.
pub mod http;
pub use self::http::tower_layer::HttpMetricsLayer;

View file

@ -1,7 +1,16 @@
mod logger; mod logs;
mod tracers; pub mod metrics;
mod traces;
use std::time::Duration;
use crate::cli::validator::parser::env_filter::CustomEnvFilter; use crate::cli::validator::parser::env_filter::CustomEnvFilter;
use once_cell::sync::Lazy;
use opentelemetry::sdk::resource::{
EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector,
};
use opentelemetry::sdk::Resource;
use opentelemetry::KeyValue;
use tracing::Subscriber; use tracing::Subscriber;
use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
@ -9,6 +18,28 @@ use tracing_subscriber::util::SubscriberInitExt;
#[cfg(feature = "has-storage")] #[cfg(feature = "has-storage")]
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
let res = Resource::from_detectors(
Duration::from_secs(5),
vec![
// set service.name from env OTEL_SERVICE_NAME > env OTEL_RESOURCE_ATTRIBUTES > option_env! CARGO_BIN_NAME > unknown_service
Box::new(SdkProvidedResourceDetector),
// detect res from env OTEL_RESOURCE_ATTRIBUTES (resources string like key1=value1,key2=value2,...)
Box::new(EnvResourceDetector::new()),
// set telemetry.sdk.{name, language, version}
Box::new(TelemetryResourceDetector),
],
);
// If no external service.name is set, set it to surrealdb
if res.get("service.name".into()).unwrap_or("".into()).as_str() == "unknown_service" {
debug!("No service.name detected, use 'surrealdb'");
res.merge(&Resource::new([KeyValue::new("service.name", "surrealdb")]))
} else {
res
}
});
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
pub struct Builder { pub struct Builder {
log_level: Option<String>, log_level: Option<String>,
@ -32,7 +63,8 @@ impl Builder {
self.filter = Some(CustomEnvFilter(filter)); self.filter = Some(CustomEnvFilter(filter));
self self
} }
/// Build a dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> { pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
let registry = tracing_subscriber::registry(); let registry = tracing_subscriber::registry();
let registry = registry.with(self.filter.map(|filter| { let registry = registry.with(self.filter.map(|filter| {
@ -44,11 +76,12 @@ impl Builder {
.with_filter(filter.0) .with_filter(filter.0)
.boxed() .boxed()
})); }));
let registry = registry.with(self.log_level.map(logger::new)); let registry = registry.with(self.log_level.map(logs::new));
let registry = registry.with(tracers::new()); let registry = registry.with(traces::new());
Box::new(registry) Box::new(registry)
} }
/// Build a dispatcher and set it as global
/// tracing pipeline
pub fn init(self) { pub fn init(self) {
self.build().init() self.build().init()
} }
@ -60,10 +93,12 @@ mod tests {
use tracing::{span, Level}; use tracing::{span, Level};
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use crate::telemetry;
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn test_otlp_tracer() { async fn test_otlp_tracer() {
println!("Starting mock otlp server..."); println!("Starting mock otlp server...");
let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await; let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await;
{ {
let otlp_endpoint = format!("http://{}", addr); let otlp_endpoint = format!("http://{}", addr);
@ -73,7 +108,7 @@ mod tests {
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
], ],
|| { || {
let _enter = super::builder().build().set_default(); let _enter = telemetry::builder().build().set_default();
println!("Sending span..."); println!("Sending span...");
@ -90,16 +125,8 @@ mod tests {
println!("Waiting for request..."); println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request"); let req = req_rx.recv().await.expect("missing export request");
let first_span = req let first_span =
.resource_spans req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
.first()
.unwrap()
.instrumentation_library_spans
.first()
.unwrap()
.spans
.first()
.unwrap();
assert_eq!("test-surreal-span", first_span.name); assert_eq!("test-surreal-span", first_span.name);
let first_event = first_span.events.first().unwrap(); let first_event = first_span.events.first().unwrap();
assert_eq!("test-surreal-event", first_event.name); assert_eq!("test-surreal-event", first_event.name);
@ -108,7 +135,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn test_tracing_filter() { async fn test_tracing_filter() {
println!("Starting mock otlp server..."); println!("Starting mock otlp server...");
let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await; let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await;
{ {
let otlp_endpoint = format!("http://{}", addr); let otlp_endpoint = format!("http://{}", addr);
@ -119,7 +146,7 @@ mod tests {
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
], ],
|| { || {
let _enter = super::builder().build().set_default(); let _enter = telemetry::builder().build().set_default();
println!("Sending spans..."); println!("Sending spans...");
@ -144,14 +171,7 @@ mod tests {
println!("Waiting for request..."); println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request"); let req = req_rx.recv().await.expect("missing export request");
let spans = &req let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;
.resource_spans
.first()
.unwrap()
.instrumentation_library_spans
.first()
.unwrap()
.spans;
assert_eq!(1, spans.len()); assert_eq!(1, spans.len());
assert_eq!("debug", spans.first().unwrap().name); assert_eq!("debug", spans.first().unwrap().name);

View file

@ -59,7 +59,9 @@ pub mod tests {
request: tonic::Request<ExportTraceServiceRequest>, request: tonic::Request<ExportTraceServiceRequest>,
) -> Result<tonic::Response<ExportTraceServiceResponse>, tonic::Status> { ) -> Result<tonic::Response<ExportTraceServiceResponse>, tonic::Status> {
self.tx.lock().unwrap().try_send(request.into_inner()).expect("Channel full"); self.tx.lock().unwrap().try_send(request.into_inner()).expect("Channel full");
Ok(tonic::Response::new(ExportTraceServiceResponse {})) Ok(tonic::Response::new(ExportTraceServiceResponse {
partial_success: None,
}))
} }
} }

View file

@ -1,10 +1,11 @@
use opentelemetry::sdk::{trace::Tracer, Resource}; use opentelemetry::sdk::trace::Tracer;
use opentelemetry::trace::TraceError; use opentelemetry::trace::TraceError;
use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use tracing::{Level, Subscriber}; use tracing::{Level, Subscriber};
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{EnvFilter, Layer};
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER"; const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER";
pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync> pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
@ -15,12 +16,12 @@ where
} }
fn tracer() -> Result<Tracer, TraceError> { fn tracer() -> Result<Tracer, TraceError> {
let resource = Resource::new(vec![KeyValue::new("service.name", "surrealdb")]);
opentelemetry_otlp::new_pipeline() opentelemetry_otlp::new_pipeline()
.tracing() .tracing()
.with_exporter(opentelemetry_otlp::new_exporter().tonic().with_env()) .with_exporter(opentelemetry_otlp::new_exporter().tonic().with_env())
.with_trace_config(opentelemetry::sdk::trace::config().with_resource(resource)) .with_trace_config(
opentelemetry::sdk::trace::config().with_resource(OTEL_DEFAULT_RESOURCE.clone()),
)
.install_batch(opentelemetry::runtime::Tokio) .install_batch(opentelemetry::runtime::Tokio)
} }

View file

@ -1,354 +0,0 @@
mod cli_integration {
// cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use rand::{thread_rng, Rng};
use serial_test::serial;
use std::fs;
use std::path::Path;
use std::process::{Command, Stdio};
/// Child is a (maybe running) CLI process. It can be killed by dropping it
struct Child {
inner: Option<std::process::Child>,
}
	impl Child {
		/// Send some thing to the child's stdin
		fn input(mut self, input: &str) -> Self {
			let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
			use std::io::Write;
			stdin.write_all(input.as_bytes()).unwrap();
			self
		}

		// Kill the process immediately; its output can still be read afterwards.
		fn kill(mut self) -> Self {
			self.inner.as_mut().unwrap().kill().unwrap();
			self
		}

		/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
		/// returns successfully, Err otherwise.
		fn output(mut self) -> Result<String, String> {
			// `take()` so the Drop impl does not try to kill an already-waited child.
			let output = self.inner.take().unwrap().wait_with_output().unwrap();

			let mut buf = String::from_utf8(output.stdout).unwrap();
			buf.push_str(&String::from_utf8(output.stderr).unwrap());

			if output.status.success() {
				Ok(buf)
			} else {
				Err(buf)
			}
		}
	}
	impl Drop for Child {
		fn drop(&mut self) {
			// Best-effort kill; the process may already have exited.
			if let Some(inner) = self.inner.as_mut() {
				let _ = inner.kill();
			}
		}
	}
	// Spawn the surreal binary under test with the given CLI arguments,
	// optionally inside `current_dir`, with all standard streams piped.
	fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
		// Locate the binary next to the test executable (it lives one level up
		// when the test runs from target/<profile>/deps).
		let mut path = std::env::current_exe().unwrap();
		assert!(path.pop());
		if path.ends_with("deps") {
			assert!(path.pop());
		}

		// Note: Cargo automatically builds this binary for integration tests.
		path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));

		let mut cmd = Command::new(path);
		if let Some(dir) = current_dir {
			cmd.current_dir(&dir);
		}
		// Pipe everything so tests can feed stdin and capture stdout/stderr.
		cmd.stdin(Stdio::piped());
		cmd.stdout(Stdio::piped());
		cmd.stderr(Stdio::piped());
		cmd.args(args.split_ascii_whitespace());
		Child {
			inner: Some(cmd.spawn().unwrap()),
		}
	}
	/// Run the CLI with the given args
	fn run(args: &str) -> Child {
		run_internal::<String>(args, None)
	}

	/// Run the CLI with the given args inside a temporary directory
	fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
		run_internal(args, Some(current_dir))
	}

	// Build a path for a scratch file inside Cargo's build output directory.
	fn tmp_file(name: &str) -> String {
		let path = Path::new(env!("OUT_DIR")).join(name);
		path.to_string_lossy().into_owned()
	}
#[test]
#[serial]
fn version() {
assert!(run("version").output().is_ok());
}
#[test]
#[serial]
fn help() {
assert!(run("help").output().is_ok());
}
#[test]
#[serial]
fn nonexistent_subcommand() {
assert!(run("nonexistent").output().is_err());
}
#[test]
#[serial]
fn nonexistent_option() {
assert!(run("version --turbo").output().is_err());
}
#[test]
#[serial]
fn start() {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let pass = rng.gen::<u64>().to_string();
let start_args =
format!("start --bind {addr} --user root --pass {pass} memory --no-banner --log info");
println!("starting server with args: {start_args}");
let _server = run(&start_args);
std::thread::sleep(std::time::Duration::from_millis(5000));
assert!(run(&format!("isready --conn http://{addr}")).output().is_ok());
// Create a record
{
let args =
format!("sql --conn http://{addr} --user root --pass {pass} --ns N --db D --multi");
assert_eq!(
run(&args).input("CREATE thing:one;\n").output(),
Ok("[{ id: thing:one }]\n\n".to_owned()),
"failed to send sql: {args}"
);
}
// Export to stdout
{
let args =
format!("export --conn http://{addr} --user root --pass {pass} --ns N --db D -");
let output = run(&args).output().expect("failed to run stdout export: {args}");
assert!(output.contains("DEFINE TABLE thing SCHEMALESS PERMISSIONS NONE;"));
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
}
// Export to file
let exported = {
let exported = tmp_file("exported.surql");
let args = format!(
"export --conn http://{addr} --user root --pass {pass} --ns N --db D {exported}"
);
run(&args).output().expect("failed to run file export: {args}");
exported
};
// Import the exported file
{
let args = format!(
"import --conn http://{addr} --user root --pass {pass} --ns N --db D2 {exported}"
);
run(&args).output().expect("failed to run import: {args}");
}
// Query from the import (pretty-printed this time)
{
let args = format!(
"sql --conn http://{addr} --user root --pass {pass} --ns N --db D2 --pretty"
);
assert_eq!(
run(&args).input("SELECT * FROM thing;\n").output(),
Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()),
"failed to send sql: {args}"
);
}
// Unfinished backup CLI
{
let file = tmp_file("backup.db");
let args = format!("backup --user root --pass {pass} http://{addr} {file}");
run(&args).output().expect("failed to run backup: {args}");
// TODO: Once backups are functional, update this test.
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
}
// Multi-statement (and multi-line) query including error(s) over WS
{
let args = format!(
"sql --conn ws://{addr} --user root --pass {pass} --ns N3 --db D3 --multi --pretty"
);
let output = run(&args)
.input(
r#"CREATE thing:success; \
CREATE thing:fail SET bad=rand('evil'); \
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
CREATE thing:also_success;
"#,
)
.output()
.unwrap();
assert!(output.contains("thing:success"), "missing success in {output}");
assert!(output.contains("rgument"), "missing argument error in {output}");
assert!(
output.contains("time") && output.contains("out"),
"missing timeout error in {output}"
);
assert!(output.contains("thing:also_success"), "missing also_success in {output}")
}
// Multi-statement (and multi-line) transaction including error(s) over WS
{
let args = format!(
"sql --conn ws://{addr} --user root --pass {pass} --ns N4 --db D4 --multi --pretty"
);
let output = run(&args)
.input(
r#"BEGIN; \
CREATE thing:success; \
CREATE thing:fail SET bad=rand('evil'); \
SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
CREATE thing:also_success; \
COMMIT;
"#,
)
.output()
.unwrap();
assert_eq!(
output.lines().filter(|s| s.contains("transaction")).count(),
3,
"missing failed txn errors in {output:?}"
);
assert!(output.contains("rgument"), "missing argument error in {output}");
}
// Pass neither ns nor db
{
let args = format!("sql --conn http://{addr} --user root --pass {pass}");
let output = run(&args)
.input("USE NS N5 DB D5; CREATE thing:one;\n")
.output()
.expect("neither ns nor db");
assert!(output.contains("thing:one"), "missing thing:one in {output}");
}
// Pass only ns
{
let args = format!("sql --conn http://{addr} --user root --pass {pass} --ns N5");
let output = run(&args)
.input("USE DB D5; SELECT * FROM thing:one;\n")
.output()
.expect("only ns");
assert!(output.contains("thing:one"), "missing thing:one in {output}");
}
// Pass only db and expect an error
{
let args = format!("sql --conn http://{addr} --user root --pass {pass} --db D5");
run(&args).output().expect_err("only db");
}
}
#[test]
#[serial]
fn start_tls() {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let pass = rng.gen::<u128>().to_string();
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
let key_path = tmp_file("key.pem");
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
let start_args = format!(
"start --bind {addr} --user root --pass {pass} memory --log info --web-crt {crt_path} --web-key {key_path}"
);
println!("starting server with args: {start_args}");
let server = run(&start_args);
std::thread::sleep(std::time::Duration::from_millis(750));
let output = server.kill().output().unwrap_err();
assert!(output.contains("Started web server"), "couldn't start web server: {output}");
}
#[test]
#[serial]
fn validate_found_no_files() {
let temp_dir = assert_fs::TempDir::new().unwrap();
temp_dir.child("file.txt").touch().unwrap();
assert!(run_in_dir("validate", &temp_dir).output().is_err());
}
#[test]
#[serial]
fn validate_succeed_for_valid_surql_files() {
let temp_dir = assert_fs::TempDir::new().unwrap();
let statement_file = temp_dir.child("statement.surql");
statement_file.touch().unwrap();
statement_file.write_str("CREATE thing:success;").unwrap();
assert!(run_in_dir("validate", &temp_dir).output().is_ok());
}
#[test]
#[serial]
fn validate_failed_due_to_invalid_glob_pattern() {
let temp_dir = assert_fs::TempDir::new().unwrap();
const WRONG_GLOB_PATTERN: &str = "**/*{.txt";
let args = format!("validate \"{}\"", WRONG_GLOB_PATTERN);
assert!(run_in_dir(&args, &temp_dir).output().is_err());
}
#[test]
#[serial]
fn validate_failed_due_to_invalid_surql_files_syntax() {
let temp_dir = assert_fs::TempDir::new().unwrap();
let statement_file = temp_dir.child("statement.surql");
statement_file.touch().unwrap();
statement_file.write_str("CREATE $thing WHERE value = '';").unwrap();
assert!(run_in_dir("validate", &temp_dir).output().is_err());
}
}

273
tests/cli_integration.rs Normal file
View file

@ -0,0 +1,273 @@
// cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture
mod common;
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use serial_test::serial;
use std::fs;
use common::{PASS, USER};
#[test]
#[serial]
fn version() {
	// The `version` subcommand must exit successfully.
	let result = common::run("version").output();
	assert!(result.is_ok());
}

#[test]
#[serial]
fn help() {
	// The `help` subcommand must exit successfully.
	let result = common::run("help").output();
	assert!(result.is_ok());
}

#[test]
#[serial]
fn nonexistent_subcommand() {
	// Unknown subcommands must fail rather than being silently ignored.
	let result = common::run("nonexistent").output();
	assert!(result.is_err());
}

#[test]
#[serial]
fn nonexistent_option() {
	// Unknown options must fail rather than being silently ignored.
	let result = common::run("version --turbo").output();
	assert!(result.is_err());
}
// Exercise the main CLI subcommands (sql, export, import, backup) end-to-end
// against a freshly started server.
#[tokio::test]
#[serial]
async fn all_commands() {
	// NOTE(review): assuming the flags are (tls, wait_for_ready) — confirm against common::start_server.
	let (addr, _server) = common::start_server(false, true).await.unwrap();
	let creds = format!("--user {USER} --pass {PASS}");

	// Create a record
	{
		let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
		assert_eq!(
			common::run(&args).input("CREATE thing:one;\n").output(),
			Ok("[{ id: thing:one }]\n\n".to_owned()),
			"failed to send sql: {args}"
		);
	}

	// Export to stdout
	{
		let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
		let output = common::run(&args).output().expect("failed to run stdout export: {args}");
		assert!(output.contains("DEFINE TABLE thing SCHEMALESS PERMISSIONS NONE;"));
		assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
	}

	// Export to file
	let exported = {
		let exported = common::tmp_file("exported.surql");
		let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
		common::run(&args).output().expect("failed to run file export: {args}");
		exported
	};

	// Import the exported file
	{
		let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
		common::run(&args).output().expect("failed to run import: {args}");
	}

	// Query from the import (pretty-printed this time)
	{
		let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
		assert_eq!(
			common::run(&args).input("SELECT * FROM thing;\n").output(),
			Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()),
			"failed to send sql: {args}"
		);
	}

	// Unfinished backup CLI
	{
		let file = common::tmp_file("backup.db");
		let args = format!("backup {creds} http://{addr} {file}");
		common::run(&args).output().expect("failed to run backup: {args}");

		// TODO: Once backups are functional, update this test.
		assert_eq!(fs::read_to_string(file).unwrap(), "Save");
	}

	// Multi-statement (and multi-line) query including error(s) over WS
	{
		let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
		let output = common::run(&args)
			.input(
				r#"CREATE thing:success; \
			CREATE thing:fail SET bad=rand('evil'); \
			SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
			CREATE thing:also_success;
			"#,
			)
			.output()
			.unwrap();
		assert!(output.contains("thing:success"), "missing success in {output}");
		assert!(output.contains("rgument"), "missing argument error in {output}");
		assert!(
			output.contains("time") && output.contains("out"),
			"missing timeout error in {output}"
		);
		assert!(output.contains("thing:also_success"), "missing also_success in {output}")
	}

	// Multi-statement (and multi-line) transaction including error(s) over WS
	{
		let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
		let output = common::run(&args)
			.input(
				r#"BEGIN; \
			CREATE thing:success; \
			CREATE thing:fail SET bad=rand('evil'); \
			SELECT * FROM sleep(10ms) TIMEOUT 1ms; \
			CREATE thing:also_success; \
			COMMIT;
			"#,
			)
			.output()
			.unwrap();
		assert_eq!(
			output.lines().filter(|s| s.contains("transaction")).count(),
			3,
			"missing failed txn errors in {output:?}"
		);
		assert!(output.contains("rgument"), "missing argument error in {output}");
	}

	// Pass neither ns nor db
	{
		let args = format!("sql --conn http://{addr} {creds}");
		let output = common::run(&args)
			.input("USE NS N5 DB D5; CREATE thing:one;\n")
			.output()
			.expect("neither ns nor db");
		assert!(output.contains("thing:one"), "missing thing:one in {output}");
	}

	// Pass only ns
	{
		let args = format!("sql --conn http://{addr} {creds} --ns N5");
		let output = common::run(&args)
			.input("USE DB D5; SELECT * FROM thing:one;\n")
			.output()
			.expect("only ns");
		assert!(output.contains("thing:one"), "missing thing:one in {output}");
	}

	// Pass only db and expect an error
	{
		let args = format!("sql --conn http://{addr} {creds} --db D5");
		common::run(&args).output().expect_err("only db");
	}
}
#[tokio::test]
#[serial]
async fn start_tls() {
	// NOTE(review): assuming the flags are (tls = true, wait_for_ready = false) — confirm against common::start_server.
	let (_, server) = common::start_server(true, false).await.unwrap();

	// Give the server time to log its startup message before killing it.
	std::thread::sleep(std::time::Duration::from_millis(2000));
	let output = server.kill().output().err().unwrap();

	// Test the crt/key args but the keys are self signed so don't actually connect.
	assert!(output.contains("Started web server"), "couldn't start web server: {output}");
}
// Verify that a root (KV-level) user can use every authenticated CLI command.
#[tokio::test]
#[serial]
async fn with_root_auth() {
	let (addr, _server) = common::start_server(false, true).await.unwrap();
	let creds = format!("--user {USER} --pass {PASS}");
	let sql_args = format!("sql --conn http://{addr} --multi --pretty");

	// Can query /sql over HTTP
	{
		let args = format!("{sql_args} {creds}");
		let input = "INFO FOR ROOT;";
		let output = common::run(&args).input(input).output();
		assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
	}

	// Can query /sql over WS
	{
		let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
		let input = "INFO FOR ROOT;";
		let output = common::run(&args).input(input).output();
		assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
	}

	// KV user can do exports
	let exported = {
		let exported = common::tmp_file("exported.surql");
		let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
		common::run(&args).output().unwrap_or_else(|_| panic!("failed to run export: {args}"));
		exported
	};

	// KV user can do imports
	{
		let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
		common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
	}

	// KV user can do backups
	{
		let file = common::tmp_file("backup.db");
		let args = format!("backup {creds} http://{addr} {file}");
		common::run(&args).output().unwrap_or_else(|_| panic!("failed to run backup: {args}"));

		// TODO: Once backups are functional, update this test.
		assert_eq!(fs::read_to_string(file).unwrap(), "Save");
	}
}
#[test]
#[serial]
// `validate` must fail when the directory holds no .surql files at all.
fn validate_found_no_files() {
	let dir = assert_fs::TempDir::new().unwrap();
	dir.child("file.txt").touch().unwrap();

	let result = common::run_in_dir("validate", &dir).output();
	assert!(result.is_err());
}
#[test]
#[serial]
// `validate` succeeds when every .surql file in the directory parses.
fn validate_succeed_for_valid_surql_files() {
	let dir = assert_fs::TempDir::new().unwrap();
	let file = dir.child("statement.surql");
	file.touch().unwrap();
	file.write_str("CREATE thing:success;").unwrap();

	let result = common::run_in_dir("validate", &dir).output();
	assert!(result.is_ok());
}
#[test]
#[serial]
// A malformed glob pattern (unclosed brace) must make `validate` fail.
fn validate_failed_due_to_invalid_glob_pattern() {
	let dir = assert_fs::TempDir::new().unwrap();
	let args = "validate \"**/*{.txt\"";
	let result = common::run_in_dir(args, &dir).output();
	assert!(result.is_err());
}
#[test]
#[serial]
// A .surql file with a syntax error must make `validate` fail.
fn validate_failed_due_to_invalid_surql_files_syntax() {
	let dir = assert_fs::TempDir::new().unwrap();
	let file = dir.child("statement.surql");
	file.touch().unwrap();
	file.write_str("CREATE $thing WHERE value = '';").unwrap();

	let result = common::run_in_dir("validate", &dir).output();
	assert!(result.is_err());
}

140
tests/common/mod.rs Normal file
View file

@ -0,0 +1,140 @@
#![allow(dead_code)]
use rand::{thread_rng, Rng};
use std::error::Error;
use std::fs;
use std::path::Path;
use std::process::{Command, Stdio};
use tokio::time;
/// Root username passed to `surreal start` and used by the test clients.
pub const USER: &str = "root";
/// Root password passed to `surreal start` and used by the test clients.
pub const PASS: &str = "root";
/// Child is a (maybe running) CLI process. It can be killed by dropping it
pub struct Child {
	// The spawned process; becomes `None` once `output()` takes ownership,
	// so `Drop` knows not to kill an already-consumed child.
	inner: Option<std::process::Child>,
}
impl Child {
	/// Send something to the child's stdin
	pub fn input(mut self, input: &str) -> Self {
		use std::io::Write;
		let child = self.inner.as_mut().unwrap();
		child.stdin.as_mut().unwrap().write_all(input.as_bytes()).unwrap();
		self
	}

	/// Terminate the child process immediately.
	pub fn kill(mut self) -> Self {
		let child = self.inner.as_mut().unwrap();
		child.kill().unwrap();
		self
	}

	/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
	/// returns successfully, Err otherwise.
	pub fn output(mut self) -> Result<String, String> {
		let result = self.inner.take().unwrap().wait_with_output().unwrap();

		let mut combined = String::from_utf8(result.stdout).unwrap();
		combined.push_str(&String::from_utf8(result.stderr).unwrap());

		match result.status.success() {
			true => Ok(combined),
			false => Err(combined),
		}
	}
}
impl Drop for Child {
	fn drop(&mut self) {
		// Best-effort cleanup: the process may have already exited or been
		// consumed by `output()`, so ignore both cases.
		if let Some(child) = &mut self.inner {
			let _ = child.kill();
		}
	}
}
/// Spawn the CLI binary with the given whitespace-separated arguments,
/// optionally inside `current_dir`, with stdin/stdout/stderr all piped.
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
	// The binary lives next to the test executable; test executables are
	// placed in a `deps` subdirectory, so strip that too when present.
	// Note: Cargo automatically builds this binary for integration tests.
	let mut binary = std::env::current_exe().unwrap();
	assert!(binary.pop());
	if binary.ends_with("deps") {
		assert!(binary.pop());
	}
	binary.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));

	let mut cmd = Command::new(binary);
	if let Some(dir) = current_dir {
		cmd.current_dir(&dir);
	}
	cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped());
	cmd.args(args.split_ascii_whitespace());

	Child {
		inner: Some(cmd.spawn().unwrap()),
	}
}
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
	// No working-directory override; the type parameter is irrelevant here.
	let no_dir: Option<String> = None;
	run_internal(args, no_dir)
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
	let dir = Some(current_dir);
	run_internal(args, dir)
}
/// Build a path for `name` inside the build's OUT_DIR and return it as a String.
pub fn tmp_file(name: &str) -> String {
	let dir = Path::new(env!("OUT_DIR"));
	dir.join(name).to_string_lossy().into_owned()
}
/// Start a `surreal start` server on a random local port.
///
/// * `tls` - generate a self-signed certificate and pass `--web-crt`/`--web-key`.
/// * `wait_is_ready` - poll `isready` (every 500ms, up to 10 times) before returning.
///
/// Returns the bound address and the server process, or an error when
/// readiness was requested but the server never answered `isready`.
pub async fn start_server(
	tls: bool,
	wait_is_ready: bool,
) -> Result<(String, Child), Box<dyn Error>> {
	// Random port in a fixed range to avoid clashes between runs.
	// NOTE(review): collisions are still possible; presumably acceptable
	// because these tests run serially — confirm.
	let mut rng = thread_rng();
	let port: u16 = rng.gen_range(13000..14000);
	let addr = format!("127.0.0.1:{port}");

	let mut extra_args = String::default();
	if tls {
		// Test the crt/key args but the keys are self signed so don't actually connect.
		let crt_path = tmp_file("crt.crt");
		let key_path = tmp_file("key.pem");
		let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
		fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
		fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
		extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
	}

	let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}");

	println!("starting server with args: {start_args}");

	let server = run(&start_args);

	if !wait_is_ready {
		return Ok((addr, server));
	}

	// Wait 5 seconds for the server to start
	let mut interval = time::interval(time::Duration::from_millis(500));
	println!("Waiting for server to start...");
	for _i in 0..10 {
		interval.tick().await;

		// Probe readiness with the CLI's own `isready` subcommand.
		if run(&format!("isready --conn http://{addr}")).output().is_ok() {
			println!("Server ready!");
			return Ok((addr, server));
		}
	}

	// Never became ready: kill the server and surface its output for debugging.
	let server_out = server.kill().output().err().unwrap();
	println!("server output: {server_out}");
	Err("server failed to start".into())
}

1302
tests/http_integration.rs Normal file

File diff suppressed because it is too large Load diff