From 53702c247ab03baf5ddab108902f927deadb6c99 Mon Sep 17 00:00:00 2001 From: Salvador Girones Gil Date: Wed, 19 Jul 2023 16:35:56 +0200 Subject: [PATCH] [metrics] HTTP Layer + move to Axum (#2237) --- .github/workflows/ci.yml | 24 +- Cargo.lock | 301 ++-- Cargo.toml | 18 +- dev/docker/compose.yaml | 69 + dev/docker/grafana-dashboards.yaml | 15 + dev/docker/grafana-datasource.yaml | 30 + dev/docker/grafana.ini | 6 + dev/docker/otel-collector.yaml | 36 + dev/docker/prometheus.yaml | 17 + dev/docker/tempo.yaml | 23 + doc/TELEMETRY.md | 16 + src/cli/backup.rs | 2 +- src/cli/export.rs | 2 +- src/cli/import.rs | 2 +- src/cli/isready.rs | 2 +- src/cli/sql.rs | 2 +- src/cli/start.rs | 2 +- src/cli/upgrade.rs | 2 +- src/cli/validator/parser/env_filter.rs | 4 + src/cli/version.rs | 2 +- src/err/mod.rs | 59 +- src/iam/mod.rs | 2 - src/iam/verify.rs | 81 +- src/main.rs | 2 +- src/net/auth.rs | 106 ++ src/net/client_ip.rs | 135 +- src/net/export.rs | 36 +- src/net/fail.rs | 126 -- src/net/head.rs | 41 - src/net/headers.rs | 95 ++ src/net/health.rs | 20 +- src/net/import.rs | 60 +- src/net/index.rs | 10 - src/net/input.rs | 4 +- src/net/key.rs | 426 +++--- src/net/log.rs | 30 - src/net/mod.rs | 209 ++- src/net/output.rs | 32 +- src/net/rpc.rs | 59 +- src/net/session.rs | 57 - src/net/signals.rs | 16 +- src/net/signin.rs | 71 +- src/net/signup.rs | 69 +- src/net/sql.rs | 96 +- src/net/status.rs | 6 - src/net/sync.rs | 30 +- src/net/tracer.rs | 235 +++ src/net/version.rs | 20 +- src/rpc/res.rs | 10 +- src/{o11y/logger.rs => telemetry/logs/mod.rs} | 0 src/telemetry/metrics/http/mod.rs | 112 ++ src/telemetry/metrics/http/tower_layer.rs | 310 ++++ src/telemetry/metrics/mod.rs | 3 + src/{o11y => telemetry}/mod.rs | 76 +- src/{o11y/tracers => telemetry/traces}/mod.rs | 4 +- .../tracers => telemetry/traces}/otlp.rs | 11 +- tests/cli.rs | 354 ----- tests/cli_integration.rs | 273 ++++ tests/common/mod.rs | 140 ++ tests/http_integration.rs | 1302 +++++++++++++++++ 60 files changed, 3873 insertions(+), 1430 deletions(-) create mode 100644 dev/docker/compose.yaml create mode 100644 dev/docker/grafana-dashboards.yaml create mode 100644 dev/docker/grafana-datasource.yaml create mode 100644 dev/docker/grafana.ini create mode 100644 dev/docker/otel-collector.yaml create mode 100644 dev/docker/prometheus.yaml create mode 100644 dev/docker/tempo.yaml create mode 100644 doc/TELEMETRY.md create mode 100644 src/net/auth.rs delete mode 100644 src/net/fail.rs delete mode 100644 src/net/head.rs create mode 100644 src/net/headers.rs delete mode 100644 src/net/index.rs delete mode 100644 src/net/log.rs delete mode 100644 src/net/session.rs delete mode 100644 src/net/status.rs create mode 100644 src/net/tracer.rs rename src/{o11y/logger.rs => telemetry/logs/mod.rs} (100%) create mode 100644 src/telemetry/metrics/http/mod.rs create mode 100644 src/telemetry/metrics/http/tower_layer.rs create mode 100644 src/telemetry/metrics/mod.rs rename src/{o11y => telemetry}/mod.rs (64%) rename src/{o11y/tracers => telemetry/traces}/mod.rs (96%) rename src/{o11y/tracers => telemetry/traces}/otlp.rs (80%) delete mode 100644 tests/cli.rs create mode 100644 tests/cli_integration.rs create mode 100644 tests/common/mod.rs create mode 100644 tests/http_integration.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 252e32e3..c272f615 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -118,7 +118,29 @@ jobs: sudo apt-get -y install protobuf-compiler libprotobuf-dev - name: Run cargo test - run: cargo test 
--locked --no-default-features --features storage-mem --workspace --test cli + run: cargo test --locked --no-default-features --features storage-mem --workspace --test cli_integration + + http-server: + name: Test HTTP server + runs-on: ubuntu-20.04 + steps: + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Checkout sources + uses: actions/checkout@v3 + + - name: Setup cache + uses: Swatinem/rust-cache@v2 + + - name: Install dependencies + run: | + sudo apt-get -y update + sudo apt-get -y install protobuf-compiler libprotobuf-dev + + - name: Run cargo test + run: cargo test --locked --no-default-features --features storage-mem --workspace --test http_integration test: name: Test workspace diff --git a/Cargo.lock b/Cargo.lock index 8a0260de..b36e93bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -405,20 +405,6 @@ dependencies = [ "futures-core", ] -[[package]] -name = "async-compression" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a" -dependencies = [ - "brotli", - "flate2", - "futures-core", - "memchr", - "pin-project-lite", - "tokio", -] - [[package]] name = "async-executor" version = "1.5.1" @@ -543,9 +529,11 @@ checksum = "a6a1de45611fdb535bfde7b7de4fd54f4fd2b17b1737c0a59b69bf9b92074b8c" dependencies = [ "async-trait", "axum-core", + "base64 0.21.2", "bitflags 1.3.2", "bytes", "futures-util", + "headers", "http", "http-body", "hyper", @@ -560,11 +548,25 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.19.0", "tower", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "axum-client-ip" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8e81eacc93f36480825da5f46a33b5fb2246ed024eacc9e8933425b80c5807" +dependencies = [ + "axum", + "forwarded-header-value", + "serde", ] [[package]] @@ -582,6 +584,7 @@ dependencies = [ "rustversion", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -595,6 +598,64 @@ dependencies = [ "tokio", ] +[[package]] +name = "axum-extra" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "febf23ab04509bd7672e6abe76bd8277af31b679e89fa5ffc6087dc289a448a3" +dependencies = [ + "axum", + "axum-core", + "axum-macros", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "mime", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_html_form", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-macros" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bb524613be645939e280b7279f7b017f98cf7f5ef084ec374df373530e73277" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.26", +] + +[[package]] +name = "axum-server" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063" +dependencies = [ + "arc-swap", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "rustls 0.21.5", + "rustls-pemfile", + "tokio", + "tokio-rustls 0.24.1", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.68" @@ -616,6 +677,12 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" + [[package]] name = "base64" version = "0.21.2" @@ -1349,6 +1416,12 @@ dependencies = [ "parking_lot_core 0.9.8", ] +[[package]] +name = "data-encoding" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" + [[package]] name = "debugid" version = "0.8.0" @@ -1654,6 +1727,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror", +] + [[package]] name = "foundationdb" version = "0.8.0" @@ -2170,6 +2253,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -2741,24 +2830,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "multer" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" -dependencies = [ - "bytes", - "encoding_rs", - "futures-util", - "http", - "httparse", - "log", - "memchr", - "mime", - "spin 0.9.8", - "version_check", -] - [[package]] name = "multimap" version = "0.8.3" @@ -2844,6 +2915,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2982,9 +3059,9 @@ dependencies = [ [[package]] name = "opentelemetry" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" +checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f" dependencies = [ "opentelemetry_api", "opentelemetry_sdk", @@ -2992,9 +3069,9 @@ dependencies = [ [[package]] name = "opentelemetry-otlp" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1c928609d087790fc936a1067bdc310ae702bdf3b090c3f281b713622c8bbde" +checksum = "8af72d59a4484654ea8eb183fea5ae4eb6a41d7ac3e3bae5f4d2a282a3a7d3ca" dependencies = [ "async-trait", "futures 0.3.28", @@ -3010,39 +3087,38 @@ dependencies = [ [[package]] name = "opentelemetry-proto" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61a2f56df5574508dd86aaca016c917489e589ece4141df1b5e349af8d66c28" +checksum = "045f8eea8c0fa19f7d48e7bc3128a39c2e5c533d5c61298c548dfefc1064474c" dependencies = [ "futures 0.3.28", "futures-util", "opentelemetry", "prost 0.11.9", "tonic", - "tonic-build", ] [[package]] name = "opentelemetry_api" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" +checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2" dependencies = [ "fnv", "futures-channel", "futures-util", "indexmap 1.9.3", - "js-sys", "once_cell", "pin-project-lite", "thiserror", + "urlencoding", ] [[package]] name = "opentelemetry_sdk" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" +checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1" dependencies = [ "async-trait", "crossbeam-channel", @@ -4241,12 +4317,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.1.0" @@ -4358,6 +4428,19 @@ dependencies = [ "syn 2.0.26", ] +[[package]] +name = "serde_html_form" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7d9b8af723e4199801a643c622daa7aae8cf4a772dc2b3efcb3a95add6cb91a" +dependencies = [ + "form_urlencoded", + "indexmap 2.0.0", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.103" @@ -4607,12 +4690,18 @@ version = "1.0.0-beta.9" dependencies = [ "argon2", "assert_fs", + "axum", + "axum-client-ip", + "axum-extra", + "axum-server", "base64 0.21.2", "bytes", "clap 4.3.12", "futures 0.3.28", + "futures-util", "glob", "http", + "http-body", "hyper", "ipnet", "nix", @@ -4620,6 +4709,7 @@ dependencies = [ "opentelemetry", "opentelemetry-otlp", "opentelemetry-proto", + "pin-project-lite", "rand 0.8.5", "rcgen", "reqwest", @@ -4637,12 +4727,13 @@ dependencies = [ "tokio-stream", "tokio-util", "tonic", + "tower", + "tower-http", "tracing", "tracing-opentelemetry", "tracing-subscriber", "urlencoding", "uuid", - "warp", ] [[package]] @@ -4715,7 +4806,7 @@ dependencies = [ "thiserror", "time 0.3.23", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.18.0", "tokio-util", "tracing", "tracing-subscriber", @@ -5141,11 +5232,23 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls 0.23.4", - "tungstenite", + "tungstenite 0.18.0", "webpki", "webpki-roots", ] +[[package]] +name = "tokio-tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.19.0", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -5219,19 +5322,6 @@ dependencies = [ "tracing-futures", ] -[[package]] -name = "tonic-build" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf5e9b9c0f7e0a7c027dcfaba7b2c60816c7049171f679d99ee2ff65d0de8c4" -dependencies = [ - "prettyplease 0.1.25", - "proc-macro2", - "prost-build 0.11.9", - "quote", - "syn 1.0.109", -] - [[package]] name = "tower" version = "0.4.13" @@ -5252,6 +5342,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c" +dependencies = [ + "base64 0.20.0", + "bitflags 2.3.3", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "mime", + 
"pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "tracing", + "uuid", +] + [[package]] name = "tower-layer" version = "0.3.2" @@ -5321,9 +5434,9 @@ dependencies = [ [[package]] name = "tracing-opentelemetry" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" +checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600" dependencies = [ "once_cell", "opentelemetry", @@ -5396,6 +5509,25 @@ dependencies = [ "webpki", ] +[[package]] +name = "tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.16.0" @@ -5544,39 +5676,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "warp" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69" -dependencies = [ - "async-compression", - "bytes", - "futures-channel", - "futures-util", - "headers", - "http", - "hyper", - "log", - "mime", - "mime_guess", - "multer", - "percent-encoding", - "pin-project", - "rustls-pemfile", - "scoped-tls", - "serde", - "serde_json", - "serde_urlencoded", - "tokio", - "tokio-rustls 0.23.4", - "tokio-stream", - "tokio-tungstenite", - "tokio-util", - "tower-service", - "tracing", -] - [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index d7276a41..19f7ac6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,17 +34,24 @@ strip = false [dependencies] argon2 = "0.5.1" +axum = { version = "0.6.18", features = ["tracing", "ws", "headers"] } +axum-client-ip = "0.4.1" +axum-extra = { version = "0.7.4", features = ["typed-routing"] } +axum-server = { version = "0.5.1", features = ["tls-rustls"] } base64 = "0.21.2" bytes = "1.4.0" clap = { version = "4.3.12", features = ["env", "derive", "wrap_help", "unicode"] } futures = "0.3.28" +futures-util = "0.3.28" glob = "0.3.1" http = "0.2.9" +http-body = "0.4.5" hyper = "0.14.27" ipnet = "2.8.0" once_cell = "1.18.0" -opentelemetry = { version = "0.18", features = ["rt-tokio"] } -opentelemetry-otlp = "0.11.0" +opentelemetry = { version = "0.19", features = ["rt-tokio"] } +opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] } +pin-project-lite = "0.2.9" rand = "0.8.5" reqwest = { version = "0.11.18", features = ["blocking"] } rustyline = { version = "11.0.0", features = ["derive"] } @@ -57,19 +64,20 @@ tempfile = "3.6.0" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["macros", "signal"] } tokio-util = { version = "0.7.8", features = ["io"] } +tower = "0.4.13" +tower-http = { version = "0.4.1", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] } tracing = "0.1" -tracing-opentelemetry = "0.18.0" +tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } urlencoding = "2.1.2" uuid = { version = "1.4.0", features = ["serde", "js", "v4", "v7"] } -warp = { version = "0.3.5", features = ["compression", "tls", "websocket"] } [target.'cfg(unix)'.dependencies] nix = "0.26.2" [dev-dependencies] 
assert_fs = "1.0.13" -opentelemetry-proto = { version = "0.1.0", features = ["gen-tonic", "traces", "build-server"] } +opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] } rcgen = "0.10.0" serial_test = "2.0.0" temp-env = "0.3.4" diff --git a/dev/docker/compose.yaml b/dev/docker/compose.yaml new file mode 100644 index 00000000..5bdc9f0a --- /dev/null +++ b/dev/docker/compose.yaml @@ -0,0 +1,69 @@ +--- +version: "3" +services: + grafana: + image: "grafana/grafana-oss:latest" + expose: + - "3000" + ports: + - "3000:3000" + volumes: + - "grafana:/var/lib/grafana" + - "./grafana.ini:/etc/grafana/grafana.ini" + - "./grafana-datasource.yaml:/etc/grafana/provisioning/datasources/grafana-datasource.yaml" + - "./grafana-dashboards.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboards.yaml" + - "./dashboards:/dashboards" + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/3001; exit $$?;' + interval: 1s + timeout: 5s + retries: 5 + prometheus: + image: "prom/prometheus:latest" + command: + - "--config.file=/etc/prometheus/prometheus.yaml" + - "--storage.tsdb.path=/prometheus" + - "--web.console.libraries=/usr/share/prometheus/console_libraries" + - "--web.console.templates=/usr/share/prometheus/consoles" + - "--web.listen-address=0.0.0.0:9090" + - "--web.enable-remote-write-receiver" + - "--storage.tsdb.retention.time=1d" + expose: + - "9090" + ports: + - "9090:9090" + volumes: + - "prometheus:/prometheus" + - "./prometheus.yaml:/etc/prometheus/prometheus.yaml" + + tempo: + image: grafana/tempo:latest + command: [ "-config.file=/etc/tempo.yaml" ] + volumes: + - ./tempo.yaml:/etc/tempo.yaml + - tempo:/tmp/tempo + ports: + - "3200" # tempo + - "4317" # otlp grpc + + + otel-collector: + image: "otel/opentelemetry-collector-contrib" + command: + - "--config=/etc/otel-collector.yaml" + expose: + - "4317" + ports: + - "4317:4317" # otlp grpc + - "9090" # for prometheus + volumes: ["./otel-collector.yaml:/etc/otel-collector.yaml"] + +volumes: + grafana: + external: false + prometheus: + external: false + tempo: + external: false diff --git a/dev/docker/grafana-dashboards.yaml b/dev/docker/grafana-dashboards.yaml new file mode 100644 index 00000000..3133367b --- /dev/null +++ b/dev/docker/grafana-dashboards.yaml @@ -0,0 +1,15 @@ +apiVersion: 1 + +providers: + - name: 'surrealdb-grafana' + orgId: 1 + folder: '' + folderUid: '' + type: file + disableDeletion: false + updateIntervalSeconds: 1 + allowUiUpdates: true + options: + path: /dashboards + foldersFromFilesStructure: false + \ No newline at end of file diff --git a/dev/docker/grafana-datasource.yaml b/dev/docker/grafana-datasource.yaml new file mode 100644 index 00000000..a4b08679 --- /dev/null +++ b/dev/docker/grafana-datasource.yaml @@ -0,0 +1,30 @@ +apiVersion: 1 +deleteDatasources: + - name: Prometheus + - name: Tempo +datasources: + - name: Prometheus + type: prometheus + access: proxy + url: http://prometheus:9090 + withCredentials: false + isDefault: true + tlsAuth: false + tlsAuthWithCACert: false + version: 1 + editable: true + - name: Tempo + type: tempo + access: proxy + orgId: 1 + url: http://tempo:3200 + basicAuth: false + isDefault: false + version: 1 + editable: false + apiVersion: 1 + uid: tempo + jsonData: + httpMethod: GET + serviceMap: + datasourceUid: prometheus diff --git a/dev/docker/grafana.ini b/dev/docker/grafana.ini new file mode 100644 index 00000000..593167c0 --- /dev/null +++ b/dev/docker/grafana.ini @@ -0,0 +1,6 @@ 
+[server]
+http_addr = 0.0.0.0
+http_port = 3000
+
+[users]
+default_theme = light
diff --git a/dev/docker/otel-collector.yaml b/dev/docker/otel-collector.yaml
new file mode 100644
index 00000000..6de8ab60
--- /dev/null
+++ b/dev/docker/otel-collector.yaml
@@ -0,0 +1,36 @@
+receivers:
+  otlp:
+    protocols:
+      grpc:
+
+exporters:
+  otlp:
+    endpoint: 'tempo:4317'
+    tls:
+      insecure: true
+
+  prometheus:
+    endpoint: ':9090'
+    send_timestamps: true
+    metric_expiration: 60m
+    resource_to_telemetry_conversion:
+      enabled: true
+
+  logging: # add to a pipeline for debugging
+    loglevel: debug
+
+# processors:
+#   batch:
+#     timeout: 1s
+#   span:
+#     name:
+#       from_attributes: ["name"]
+
+service:
+  pipelines:
+    traces:
+      receivers: [otlp]
+      exporters: [otlp, logging]
+    metrics:
+      receivers: [otlp]
+      exporters: [prometheus]
diff --git a/dev/docker/prometheus.yaml b/dev/docker/prometheus.yaml
new file mode 100644
index 00000000..099d52ad
--- /dev/null
+++ b/dev/docker/prometheus.yaml
@@ -0,0 +1,17 @@
+global:
+  scrape_interval: 5s
+  evaluation_interval: 10s
+
+scrape_configs:
+  - job_name: prometheus
+    static_configs:
+      - targets: ["prometheus:9500"]
+
+  - job_name: 'tempo'
+    static_configs:
+      - targets: ["tempo:3200"]
+
+  - job_name: otel-collector
+    static_configs:
+      # Scrape the SurrealDB metrics sent to the OpenTelemetry Collector
+      - targets: ["otel-collector:9090"]
diff --git a/dev/docker/tempo.yaml b/dev/docker/tempo.yaml
new file mode 100644
index 00000000..4ac55481
--- /dev/null
+++ b/dev/docker/tempo.yaml
@@ -0,0 +1,23 @@
+server:
+  http_listen_port: 3200
+
+distributor:
+  receivers:
+    otlp:
+      protocols:
+        grpc:
+
+ingester:
+  max_block_duration: 5m  # cut the headblock when this much time passes. this is being set for dev purposes and should probably be left alone normally
+
+compactor:
+  compaction:
+    block_retention: 1h  # overall Tempo trace retention. set for dev purposes
+
+storage:
+  trace:
+    backend: local
+    wal:
+      path: /tmp/tempo/wal
+    local:
+      path: /tmp/tempo/blocks
diff --git a/doc/TELEMETRY.md b/doc/TELEMETRY.md
new file mode 100644
index 00000000..82e4d2ce
--- /dev/null
+++ b/doc/TELEMETRY.md
@@ -0,0 +1,16 @@
+# Telemetry
+
+SurrealDB uses the `tracing` and `opentelemetry` libraries to instrument the code.
+
+Both metrics and traces are pushed to an OTEL-compatible receiver.
+
+For local development, you can start the observability stack defined in `dev/docker`. It spins up an instance of the OpenTelemetry Collector, Grafana, Prometheus, and Tempo:
+
+```
+$ docker-compose -f dev/docker/compose.yaml up -d
+$ SURREAL_TRACING_TRACER=otlp OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317" surreal start
+```
+
+You can now use the SurrealDB server and view the telemetry data by opening http://localhost:3000 in your browser.
+
+To log in to Grafana, use the default username `admin` and password `admin`.
diff --git a/src/cli/backup.rs b/src/cli/backup.rs index f5e07ddb..a464d10d 100644 --- a/src/cli/backup.rs +++ b/src/cli/backup.rs @@ -37,7 +37,7 @@ pub async fn init( }: BackupCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); // Process the source->destination response let into_local = into.ends_with(".db"); diff --git a/src/cli/export.rs b/src/cli/export.rs index 608f2121..7d3c7774 100644 --- a/src/cli/export.rs +++ b/src/cli/export.rs @@ -38,7 +38,7 @@ pub async fn init( }: ExportCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); let root = Root { username: &username, diff --git a/src/cli/import.rs b/src/cli/import.rs index 0e130581..b12392f7 100644 --- a/src/cli/import.rs +++ b/src/cli/import.rs @@ -36,7 +36,7 @@ pub async fn init( }: ImportCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); let root = Root { username: &username, diff --git a/src/cli/isready.rs b/src/cli/isready.rs index 8c681711..4ee517e6 100644 --- a/src/cli/isready.rs +++ b/src/cli/isready.rs @@ -17,7 +17,7 @@ pub async fn init( }: IsReadyCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); // Connect to the database engine connect(endpoint).await?; println!("OK"); diff --git a/src/cli/sql.rs b/src/cli/sql.rs index 1bbf716c..20afde9f 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -49,7 +49,7 @@ pub async fn init( }: SqlCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("warn").init(); + crate::telemetry::builder().with_log_level("warn").init(); let root = Root { username: &username, diff --git a/src/cli/start.rs b/src/cli/start.rs index 0eae8745..5862acb1 100644 --- a/src/cli/start.rs +++ b/src/cli/start.rs @@ -101,7 +101,7 @@ pub async fn init( }: StartCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_filter(log).init(); + crate::telemetry::builder().with_filter(log).init(); // Check if a banner should be outputted if !no_banner { diff --git a/src/cli/upgrade.rs b/src/cli/upgrade.rs index bb66b902..cb7ccddc 100644 --- a/src/cli/upgrade.rs +++ b/src/cli/upgrade.rs @@ -47,7 +47,7 @@ impl UpgradeCommandArguments { pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); // Upgrading overwrites the existing executable let exe = std::env::current_exe()?; diff --git a/src/cli/validator/parser/env_filter.rs b/src/cli/validator/parser/env_filter.rs index 9a2953ba..47240a38 100644 --- a/src/cli/validator/parser/env_filter.rs +++ b/src/cli/validator/parser/env_filter.rs @@ -31,6 +31,10 @@ impl TypedValueParser for CustomEnvFilterParser { arg: Option<&clap::Arg>, value: &std::ffi::OsStr, ) -> Result { + if let Ok(dirs) = std::env::var("RUST_LOG") { + return Ok(CustomEnvFilter(EnvFilter::builder().parse_lossy(dirs))); + } + let 
inner = NonEmptyStringValueParser::new(); let v = inner.parse_ref(cmd, arg, value)?; let filter = (match v.as_str() { diff --git a/src/cli/version.rs b/src/cli/version.rs index d89e3ff1..8571f785 100644 --- a/src/cli/version.rs +++ b/src/cli/version.rs @@ -18,7 +18,7 @@ pub async fn init( }: VersionCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging - crate::o11y::builder().with_log_level("error").init(); + crate::telemetry::builder().with_log_level("error").init(); // Print server version if endpoint supplied else CLI version if let Some(e) = endpoint { // Print remote server version diff --git a/src/err/mod.rs b/src/err/mod.rs index d0007009..78193cb1 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -1,4 +1,7 @@ +use axum::response::{IntoResponse, Response}; +use axum::Json; use base64::DecodeError as Base64Error; +use http::StatusCode; use reqwest::Error as ReqwestError; use serde::Serialize; use serde_cbor::error::Error as CborError; @@ -6,6 +9,7 @@ use serde_json::error::Error as JsonError; use serde_pack::encode::Error as PackError; use std::io::Error as IoError; use std::string::FromUtf8Error as Utf8Error; +use surrealdb::error::Db as SurrealDbError; use surrealdb::Error as SurrealError; use thiserror::Error; @@ -51,8 +55,6 @@ pub enum Error { Remote(#[from] ReqwestError), } -impl warp::reject::Reject for Error {} - impl From for String { fn from(e: Error) -> String { e.to_string() @@ -85,3 +87,56 @@ impl Serialize for Error { serializer.serialize_str(self.to_string().as_str()) } } + +#[derive(Serialize)] +pub(super) struct Message { + code: u16, + details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + information: Option, +} + +impl IntoResponse for Error { + fn into_response(self) -> Response { + match self { + err @ Error::InvalidAuth | err @ Error::Db(SurrealError::Db(SurrealDbError::InvalidAuth)) => ( + StatusCode::UNAUTHORIZED, + Json(Message { + code: 401, + details: Some("Authentication failed".to_string()), + description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()), + information: Some(err.to_string()), + }) + ), + Error::InvalidType => ( + StatusCode::UNSUPPORTED_MEDIA_TYPE, + Json(Message { + code: 415, + details: Some("Unsupported media type".to_string()), + description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()), + information: None, + }), + ), + Error::InvalidStorage => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(Message { + code: 500, + details: Some("Health check failed".to_string()), + description: Some("The database health check for this instance failed. There was an issue with the underlying storage engine.".to_string()), + information: Some(self.to_string()), + }), + ), + _ => ( + StatusCode::BAD_REQUEST, + Json(Message { + code: 400, + details: Some("Request problems detected".to_string()), + description: Some("There is a problem with your request. 
Refer to the documentation for further information.".to_string()), + information: Some(self.to_string()), + }), + ), + }.into_response() + } +} diff --git a/src/iam/mod.rs b/src/iam/mod.rs index 09831b1c..5c6b801b 100644 --- a/src/iam/mod.rs +++ b/src/iam/mod.rs @@ -3,8 +3,6 @@ pub mod verify; use crate::cli::CF; use crate::err::Error; -pub const BASIC: &str = "Basic "; - pub async fn init() -> Result<(), Error> { // Get local copy of options let opt = CF.get().unwrap(); diff --git a/src/iam/verify.rs b/src/iam/verify.rs index 027c8c63..1bae0278 100644 --- a/src/iam/verify.rs +++ b/src/iam/verify.rs @@ -1,76 +1,63 @@ use crate::cli::CF; use crate::dbs::DB; use crate::err::Error; -use crate::iam::BASIC; use argon2::password_hash::{PasswordHash, PasswordVerifier}; use argon2::Argon2; use std::sync::Arc; use surrealdb::dbs::Auth; use surrealdb::dbs::Session; -use surrealdb::iam::base::{Engine, BASE64}; -pub async fn basic(session: &mut Session, auth: String) -> Result<(), Error> { - // Log the authentication type - trace!("Attempting basic authentication"); - // Retrieve just the auth data - let auth = auth.trim_start_matches(BASIC).trim(); +pub async fn basic(session: &mut Session, user: &str, pass: &str) -> Result<(), Error> { // Get a database reference let kvs = DB.get().unwrap(); // Get the config options let opts = CF.get().unwrap(); - // Decode the encoded auth data - let auth = BASE64.decode(auth)?; - // Convert the auth data to String - let auth = String::from_utf8(auth)?; - // Split the auth data into user and pass - if let Some((user, pass)) = auth.split_once(':') { - // Check that the details are not empty - if user.is_empty() || pass.is_empty() { - return Err(Error::InvalidAuth); + + // Check that the details are not empty + if user.is_empty() || pass.is_empty() { + return Err(Error::InvalidAuth); + } + // Check if this is root authentication + if let Some(root) = &opts.pass { + if user == opts.user && pass == root { + // Log the authentication type + debug!("Authenticated as super user"); + // Store the authentication data + session.au = Arc::new(Auth::Kv); + return Ok(()); } - // Check if this is root authentication - if let Some(root) = &opts.pass { - if user == opts.user && pass == root { - // Log the authentication type - debug!("Authenticated as super user"); + } + // Check if this is NS authentication + if let Some(ns) = &session.ns { + // Create a new readonly transaction + let mut tx = kvs.transaction(false, false).await?; + // Check if the supplied NS Login exists + if let Ok(nl) = tx.get_nl(ns, user).await { + // Compute the hash and verify the password + let hash = PasswordHash::new(&nl.hash).unwrap(); + if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() { + // Log the successful namespace authentication + debug!("Authenticated as namespace user: {}", user); // Store the authentication data - session.au = Arc::new(Auth::Kv); + session.au = Arc::new(Auth::Ns(ns.to_owned())); return Ok(()); } - } - // Check if this is NS authentication - if let Some(ns) = &session.ns { - // Create a new readonly transaction - let mut tx = kvs.transaction(false, false).await?; - // Check if the supplied NS Login exists - if let Ok(nl) = tx.get_nl(ns, user).await { + }; + // Check if this is DB authentication + if let Some(db) = &session.db { + // Check if the supplied DB Login exists + if let Ok(dl) = tx.get_dl(ns, db, user).await { // Compute the hash and verify the password - let hash = PasswordHash::new(&nl.hash).unwrap(); + let hash = 
PasswordHash::new(&dl.hash).unwrap(); if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() { // Log the successful namespace authentication debug!("Authenticated as namespace user: {}", user); // Store the authentication data - session.au = Arc::new(Auth::Ns(ns.to_owned())); + session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned())); return Ok(()); } }; - // Check if this is DB authentication - if let Some(db) = &session.db { - // Check if the supplied DB Login exists - if let Ok(dl) = tx.get_dl(ns, db, user).await { - // Compute the hash and verify the password - let hash = PasswordHash::new(&dl.hash).unwrap(); - if Argon2::default().verify_password(pass.as_ref(), &hash).is_ok() { - // Log the successful namespace authentication - debug!("Authenticated as database user: {}", user); - // Store the authentication data - session.au = Arc::new(Auth::Db(ns.to_owned(), db.to_owned())); - return Ok(()); - } - }; - } } } - // There was an auth error Err(Error::InvalidAuth) } diff --git a/src/main.rs b/src/main.rs index 53e4b83f..1991fb94 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,9 +27,9 @@ mod err; mod iam; #[cfg(feature = "has-storage")] mod net; -mod o11y; #[cfg(feature = "has-storage")] mod rpc; +mod telemetry; use std::future::Future; use std::process::ExitCode; diff --git a/src/net/auth.rs b/src/net/auth.rs new file mode 100644 index 00000000..8efbd52b --- /dev/null +++ b/src/net/auth.rs @@ -0,0 +1,106 @@ +use axum::{ + body::{boxed, Body, BoxBody}, + headers::{ + authorization::{Basic, Bearer}, + Authorization, Origin, + }, + Extension, RequestPartsExt, TypedHeader, +}; +use futures_util::future::BoxFuture; +use http::{request::Parts, StatusCode}; +use hyper::{Request, Response}; +use surrealdb::{dbs::Session, iam::verify::token}; +use tower_http::auth::AsyncAuthorizeRequest; + +use crate::{dbs::DB, err::Error, iam::verify::basic}; + +use super::{client_ip::ExtractClientIP, AppState}; + +/// +/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait. +/// It is used to authorize requests to SurrealDB using Basic or Token authentication. 
+/// +/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer: +/// +/// ```rust +/// use tower_http::auth::RequireAuthorizationLayer; +/// use surrealdb::net::SurrealAuth; +/// use axum::Router; +/// +/// let auth = RequireAuthorizationLayer::new(SurrealAuth); +/// +/// let app = Router::new() +/// .route("/version", get(|| async { "0.1.0" })) +/// .layer(auth); +/// ``` +#[derive(Clone, Copy)] +pub(super) struct SurrealAuth; + +impl AsyncAuthorizeRequest for SurrealAuth +where + B: Send + Sync + 'static, +{ + type RequestBody = B; + type ResponseBody = BoxBody; + type Future = BoxFuture<'static, Result, Response>>; + + fn authorize(&mut self, request: Request) -> Self::Future { + Box::pin(async { + let (mut parts, body) = request.into_parts(); + match check_auth(&mut parts).await { + Ok(sess) => { + parts.extensions.insert(sess); + Ok(Request::from_parts(parts, body)) + } + Err(err) => { + let unauthorized_response = Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(boxed(Body::from(err.to_string()))) + .unwrap(); + Err(unauthorized_response) + } + } + }) + } +} + +async fn check_auth(parts: &mut Parts) -> Result { + let kvs = DB.get().unwrap(); + + let or = if let Ok(or) = parts.extract::>().await { + if !or.is_null() { + Some(or.to_string()) + } else { + None + } + } else { + None + }; + + let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + + let Extension(state) = parts.extract::>().await.map_err(|err| { + tracing::error!("Error extracting the app state: {:?}", err); + Error::InvalidAuth + })?; + let ExtractClientIP(ip) = + parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None)); + + // Create session + #[rustfmt::skip] + let mut session = Session { ip, or, id, ns, db, ..Default::default() }; + + // If Basic authentication data was supplied + if let Ok(au) = parts.extract::>>().await { + basic(&mut session, au.username(), au.password()).await?; + }; + + // If Token authentication data was supplied + if let Ok(au) = parts.extract::>>().await { + token(kvs, &mut session, au.token().into()).await?; + }; + + Ok(session) +} diff --git a/src/net/client_ip.rs b/src/net/client_ip.rs index 00e9fecb..aa9a3803 100644 --- a/src/net/client_ip.rs +++ b/src/net/client_ip.rs @@ -1,10 +1,21 @@ -use crate::cli::CF; +use axum::async_trait; +use axum::extract::ConnectInfo; +use axum::extract::FromRef; +use axum::extract::FromRequestParts; +use axum::middleware::Next; +use axum::response::Response; +use axum::Extension; +use axum::RequestPartsExt; use clap::ValueEnum; -use std::net::IpAddr; +use http::request::Parts; +use http::Request; +use http::StatusCode; use std::net::SocketAddr; -use warp::Filter; + +use super::AppState; // TODO: Support Forwarded, X-Forwarded-For headers. 
+// Get inspiration from https://github.com/imbolc/axum-client-ip or simply use it #[derive(ValueEnum, Clone, Copy, Debug)] pub enum ClientIp { /// Don't use client IP @@ -25,31 +36,95 @@ pub enum ClientIp { XRealIp, } -/// Creates an string represenation of the client's IP address -pub fn build() -> impl Filter,), Error = warp::Rejection> + Clone { - // Get configured client IP source - let client_ip = CF.get().unwrap().client_ip; - // Enable on any path - let conf = warp::any(); - // Add raw remote IP address - let conf = - conf.and(warp::filters::addr::remote().and_then(move |s: Option| async move { - match client_ip { - ClientIp::None => Ok(None), - ClientIp::Socket => Ok(s.map(|s| s.ip())), - // Move on to parsing selected IP header. - _ => Err(warp::reject::reject()), - } - })); - // Add selected IP header - let conf = conf.or(warp::header::optional::(match client_ip { - ClientIp::CfConectingIp => "Cf-Connecting-IP", - ClientIp::FlyClientIp => "Fly-Client-IP", - ClientIp::TrueClientIP => "True-Client-IP", - ClientIp::XRealIp => "X-Real-IP", - // none and socket are already handled so this will never be used - _ => "unreachable", - })); - // Join the two filters - conf.unify().map(|ip: Option| ip.map(|ip| ip.to_string())) +impl std::fmt::Display for ClientIp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ClientIp::None => write!(f, "None"), + ClientIp::Socket => write!(f, "Socket"), + ClientIp::CfConectingIp => write!(f, "CF-Connecting-IP"), + ClientIp::FlyClientIp => write!(f, "Fly-Client-IP"), + ClientIp::TrueClientIP => write!(f, "True-Client-IP"), + ClientIp::XRealIp => write!(f, "X-Real-IP"), + } + } +} + +impl ClientIp { + fn is_header(&self) -> bool { + match self { + ClientIp::None => false, + ClientIp::Socket => false, + ClientIp::CfConectingIp => true, + ClientIp::FlyClientIp => true, + ClientIp::TrueClientIP => true, + ClientIp::XRealIp => true, + } + } +} + +pub(super) struct ExtractClientIP(pub Option); + +#[async_trait] +impl FromRequestParts for ExtractClientIP +where + AppState: FromRef, + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let app_state = AppState::from_ref(state); + + let res = match app_state.client_ip { + ClientIp::None => ExtractClientIP(None), + ClientIp::Socket => { + if let Ok(ConnectInfo(addr)) = + ConnectInfo::::from_request_parts(parts, state).await + { + ExtractClientIP(Some(addr.ip().to_string())) + } else { + ExtractClientIP(None) + } + } + // Get the IP from the corresponding header + var if var.is_header() => { + if let Some(ip) = parts.headers.get(var.to_string()) { + ip.to_str().map(|s| ExtractClientIP(Some(s.to_string()))).unwrap_or_else( + |err| { + debug!("Invalid header value for {}: {}", var, err); + ExtractClientIP(None) + }, + ) + } else { + ExtractClientIP(None) + } + } + _ => { + warn!("Unexpected ClientIp variant: {:?}", app_state.client_ip); + ExtractClientIP(None) + } + }; + + Ok(res) + } +} + +pub(super) async fn client_ip_middleware( + request: Request, + next: Next, +) -> Result +where + B: Send, +{ + let (mut parts, body) = request.into_parts(); + + if let Ok(Extension(state)) = parts.extract::>().await { + if let Ok(client_ip) = parts.extract_with_state::(&state).await { + parts.extensions.insert(client_ip); + } + } else { + trace!("No AppState found, skipping client_ip_middleware"); + } + + Ok(next.run(Request::from_parts(parts, body)).await) } diff --git a/src/net/export.rs 
b/src/net/export.rs index 460106d9..a4d6a7d6 100644 --- a/src/net/export.rs +++ b/src/net/export.rs @@ -1,21 +1,25 @@ use crate::dbs::DB; -use crate::err::Error; -use crate::net::session; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use axum::{response::Response, Extension}; use bytes::Bytes; +use http::StatusCode; +use http_body::Body as HttpBody; use hyper::body::Body; use surrealdb::dbs::Session; -use warp::Filter; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("export") - .and(warp::path::end()) - .and(warp::get()) - .and(session::build()) - .and_then(handler) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new().route("/export", get(handler)) } -async fn handler(session: Session) -> Result { +async fn handler( + Extension(session): Extension, +) -> Result { // Check the permissions match session.au.is_db() { true => { @@ -24,12 +28,12 @@ async fn handler(session: Session) -> Result // Extract the NS header value let nsv = match session.ns { Some(ns) => ns, - None => return Err(warp::reject::custom(Error::NoNsHeader)), + None => return Err((StatusCode::BAD_REQUEST, "No namespace provided")), }; // Extract the DB header value let dbv = match session.db { Some(db) => db, - None => return Err(warp::reject::custom(Error::NoDbHeader)), + None => return Err((StatusCode::BAD_REQUEST, "No database provided")), }; // Create a chunked response let (mut chn, bdy) = Body::channel(); @@ -44,9 +48,9 @@ async fn handler(session: Session) -> Result } }); // Return the chunked body - Ok(warp::reply::Response::new(bdy)) + Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap()) } - // There was an error with permissions - _ => Err(warp::reject::custom(Error::InvalidAuth)), + // The user does not have the correct permissions + _ => Err((StatusCode::FORBIDDEN, "Invalid permissions")), } } diff --git a/src/net/fail.rs b/src/net/fail.rs deleted file mode 100644 index 8c34c87a..00000000 --- a/src/net/fail.rs +++ /dev/null @@ -1,126 +0,0 @@ -use crate::err::Error; -use serde::Serialize; -use warp::http::StatusCode; - -#[derive(Serialize)] -struct Message { - code: u16, - details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - information: Option, -} - -pub async fn recover(err: warp::Rejection) -> Result { - if let Some(err) = err.find::() { - match err { - Error::InvalidAuth => Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 403, - details: Some("Authentication failed".to_string()), - description: Some("Your authentication details are invalid. Reauthenticate using valid authentication parameters.".to_string()), - information: Some(err.to_string()), - }), - StatusCode::FORBIDDEN, - )), - Error::InvalidType => Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 415, - details: Some("Unsupported media type".to_string()), - description: Some("The request needs to adhere to certain constraints. Refer to the documentation for supported content types.".to_string()), - information: None, - }), - StatusCode::UNSUPPORTED_MEDIA_TYPE, - )), - Error::InvalidStorage => Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 500, - details: Some("Health check failed".to_string()), - description: Some("The database health check for this instance failed. 
There was an issue with the underlying storage engine.".to_string()), - information: Some(err.to_string()), - }), - StatusCode::INTERNAL_SERVER_ERROR, - )), - _ => Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 400, - details: Some("Request problems detected".to_string()), - description: Some("There is a problem with your request. Refer to the documentation for further information.".to_string()), - information: Some(err.to_string()), - }), - StatusCode::BAD_REQUEST, - )) - } - } else if err.is_not_found() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 404, - details: Some("Requested resource not found".to_string()), - description: Some("The requested resource does not exist. Check that you have entered the url correctly.".to_string()), - information: None, - }), - StatusCode::NOT_FOUND, - )) - } else if err.find::().is_some() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 412, - details: Some("Request problems detected".to_string()), - description: Some("The request appears to be missing a required header. Refer to the documentation for request requirements.".to_string()), - information: None, - }), - StatusCode::PRECONDITION_FAILED, - )) - } else if err.find::().is_some() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 413, - details: Some("Payload too large".to_string()), - description: Some("The request has exceeded the maximum payload size. Refer to the documentation for the request limitations.".to_string()), - information: None, - }), - StatusCode::PAYLOAD_TOO_LARGE, - )) - } else if err.find::().is_some() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 501, - details: Some("Not implemented".to_string()), - description: Some("The server either does not recognize the query, or it lacks the ability to fulfill the request.".to_string()), - information: None, - }), - StatusCode::NOT_IMPLEMENTED, - )) - } else if err.find::().is_some() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 501, - details: Some("Not implemented".to_string()), - description: Some("The server either does not recognize a request header, or it lacks the ability to fulfill the request.".to_string()), - information: None, - }), - StatusCode::NOT_IMPLEMENTED, - )) - } else if err.find::().is_some() { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 405, - details: Some("Requested method not allowed".to_string()), - description: Some("The requested http method is not allowed for this resource. Refer to the documentation for allowed methods.".to_string()), - information: None, - }), - StatusCode::METHOD_NOT_ALLOWED, - )) - } else { - Ok(warp::reply::with_status( - warp::reply::json(&Message { - code: 500, - details: Some("Internal server error".to_string()), - description: Some("There was a problem with our servers, and we have been notified. 
Refer to the documentation for further information".to_string()), - information: None, - }), - StatusCode::INTERNAL_SERVER_ERROR, - )) - } -} diff --git a/src/net/head.rs b/src/net/head.rs deleted file mode 100644 index 13e6e7fa..00000000 --- a/src/net/head.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::cnf::PKG_NAME; -use crate::cnf::PKG_VERSION; -use surrealdb::cnf::SERVER_NAME; - -const ID: &str = "ID"; -const NS: &str = "NS"; -const DB: &str = "DB"; -const SERVER: &str = "Server"; -const VERSION: &str = "Version"; - -pub fn version() -> warp::filters::reply::WithHeader { - let val = format!("{PKG_NAME}-{}", *PKG_VERSION); - warp::reply::with::header(VERSION, val) -} - -pub fn server() -> warp::filters::reply::WithHeader { - warp::reply::with::header(SERVER, SERVER_NAME) -} - -pub fn cors() -> warp::filters::cors::Builder { - warp::cors() - .max_age(86400) - .allow_any_origin() - .allow_methods(vec![ - http::Method::GET, - http::Method::PUT, - http::Method::POST, - http::Method::PATCH, - http::Method::DELETE, - http::Method::OPTIONS, - ]) - .allow_headers(vec![ - http::header::ACCEPT, - http::header::AUTHORIZATION, - http::header::CONTENT_TYPE, - http::header::ORIGIN, - NS.parse().unwrap(), - DB.parse().unwrap(), - ID.parse().unwrap(), - ]) -} diff --git a/src/net/headers.rs b/src/net/headers.rs new file mode 100644 index 00000000..d37877a5 --- /dev/null +++ b/src/net/headers.rs @@ -0,0 +1,95 @@ +use crate::cnf::PKG_NAME; +use crate::cnf::PKG_VERSION; +use axum::headers; +use axum::headers::Header; +use http::HeaderName; +use http::HeaderValue; +use surrealdb::cnf::SERVER_NAME; +use tower_http::set_header::SetResponseHeaderLayer; + +pub(super) const ID: &str = "ID"; +pub(super) const NS: &str = "NS"; +pub(super) const DB: &str = "DB"; +const SERVER: &str = "server"; +const VERSION: &str = "version"; + +pub fn add_version_header() -> SetResponseHeaderLayer { + let val = format!("{PKG_NAME}-{}", *PKG_VERSION); + SetResponseHeaderLayer::if_not_present( + HeaderName::from_static(VERSION), + HeaderValue::try_from(val).unwrap(), + ) +} + +pub fn add_server_header() -> SetResponseHeaderLayer { + SetResponseHeaderLayer::if_not_present( + HeaderName::from_static(SERVER), + HeaderValue::try_from(SERVER_NAME).unwrap(), + ) +} + +/// Typed header implementation for the `Accept` header. +pub enum Accept { + TextPlain, + ApplicationJson, + ApplicationCbor, + ApplicationPack, + ApplicationOctetStream, + Surrealdb, +} + +impl std::fmt::Display for Accept { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Accept::TextPlain => write!(f, "text/plain"), + Accept::ApplicationJson => write!(f, "application/json"), + Accept::ApplicationCbor => write!(f, "application/cbor"), + Accept::ApplicationPack => write!(f, "application/pack"), + Accept::ApplicationOctetStream => write!(f, "application/octet-stream"), + Accept::Surrealdb => write!(f, "application/surrealdb"), + } + } +} + +impl Header for Accept { + fn name() -> &'static HeaderName { + &http::header::ACCEPT + } + + fn decode<'i, I>(values: &mut I) -> Result + where + I: Iterator, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + + match value.to_str().map_err(|_| headers::Error::invalid())? 
{ + "text/plain" => Ok(Accept::TextPlain), + "application/json" => Ok(Accept::ApplicationJson), + "application/cbor" => Ok(Accept::ApplicationCbor), + "application/pack" => Ok(Accept::ApplicationPack), + "application/octet-stream" => Ok(Accept::ApplicationOctetStream), + "application/surrealdb" => Ok(Accept::Surrealdb), + // TODO: Support more (all?) mime-types + _ => Err(headers::Error::invalid()), + } + } + + fn encode(&self, values: &mut E) + where + E: Extend, + { + values.extend(std::iter::once(self.into())); + } +} + +impl From for HeaderValue { + fn from(value: Accept) -> Self { + HeaderValue::from(&value) + } +} + +impl From<&Accept> for HeaderValue { + fn from(value: &Accept) -> Self { + HeaderValue::from_str(value.to_string().as_str()).unwrap() + } +} diff --git a/src/net/health.rs b/src/net/health.rs index e706ec30..9285fca3 100644 --- a/src/net/health.rs +++ b/src/net/health.rs @@ -1,25 +1,31 @@ use crate::dbs::DB; use crate::err::Error; -use warp::Filter; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use http_body::Body as HttpBody; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("health").and(warp::path::end()).and(warp::get()).and_then(handler) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new().route("/health", get(handler)) } -async fn handler() -> Result { +async fn handler() -> impl IntoResponse { // Get the datastore reference let db = DB.get().unwrap(); // Attempt to open a transaction match db.transaction(false, false).await { // The transaction failed to start - Err(_) => Err(warp::reject::custom(Error::InvalidStorage)), + Err(_) => Err(Error::InvalidStorage), // The transaction was successful Ok(mut tx) => { // Cancel the transaction let _ = tx.cancel().await; // Return the response - Ok(warp::reply()) + Ok(()) } } } diff --git a/src/net/import.rs b/src/net/import.rs index 80deb9de..93a76eea 100644 --- a/src/net/import.rs +++ b/src/net/import.rs @@ -2,31 +2,39 @@ use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; -use crate::net::session; +use axum::extract::DefaultBodyLimit; +use axum::response::IntoResponse; +use axum::routing::post; +use axum::Extension; +use axum::Router; +use axum::TypedHeader; use bytes::Bytes; +use http_body::Body as HttpBody; use surrealdb::dbs::Session; -use warp::http; -use warp::Filter; +use tower_http::limit::RequestBodyLimitLayer; -const MAX: u64 = 1024 * 1024 * 1024 * 4; // 4 GiB +use super::headers::Accept; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("import") - .and(warp::path::end()) - .and(warp::post()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(session::build()) - .and_then(handler) +const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB + +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + .route("/import", post(handler)) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) } async fn handler( - output: String, + Extension(session): Extension, + maybe_output: Option>, sql: Bytes, - session: Session, -) -> Result { +) -> Result { // Check the permissions match session.au.is_db() { true => { @@ -36,22 +44,22 
@@ async fn handler( let sql = bytes_to_utf8(&sql)?; // Execute the sql query in the database match db.execute(sql, &session, None).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // Return nothing - "application/octet-stream" => Ok(output::none()), + Some(Accept::ApplicationOctetStream) => Ok(output::none()), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - _ => Err(warp::reject::custom(Error::InvalidAuth)), + _ => Err(Error::InvalidAuth), } } diff --git a/src/net/index.rs b/src/net/index.rs deleted file mode 100644 index 583ca689..00000000 --- a/src/net/index.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::cnf; -use warp::http::Uri; -use warp::Filter; - -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path::end() - .and(warp::get()) - .map(|| warp::redirect::temporary(Uri::from_static(cnf::APP_ENDPOINT))) -} diff --git a/src/net/input.rs b/src/net/input.rs index a05ae433..bc0398aa 100644 --- a/src/net/input.rs +++ b/src/net/input.rs @@ -1,6 +1,6 @@ use crate::err::Error; use bytes::Bytes; -pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, warp::Rejection> { - std::str::from_utf8(bytes).map_err(|_| warp::reject::custom(Error::Request)) +pub(crate) fn bytes_to_utf8(bytes: &Bytes) -> Result<&str, Error> { + std::str::from_utf8(bytes).map_err(|_| Error::Request) } diff --git a/src/net/key.rs b/src/net/key.rs index cf905c62..10d32bcc 100644 --- a/src/net/key.rs +++ b/src/net/key.rs @@ -2,145 +2,62 @@ use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; -use crate::net::params::{Param, Params}; -use crate::net::session; +use crate::net::params::Params; +use axum::extract::{DefaultBodyLimit, Path, Query}; +use axum::response::IntoResponse; +use axum::routing::options; +use axum::{Extension, Router, TypedHeader}; use bytes::Bytes; +use http_body::Body as HttpBody; use serde::Deserialize; use std::str; use surrealdb::dbs::Session; use surrealdb::sql::Value; -use warp::path; -use warp::Filter; +use tower_http::limit::RequestBodyLimitLayer; -const MAX: u64 = 1024 * 16; // 16 KiB +use super::headers::Accept; + +const MAX: usize = 1024 * 16; // 16 KiB #[derive(Default, Deserialize, Debug, Clone)] -struct Query { +struct QueryOptions { pub limit: Option, pub start: Option, } -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - // ------------------------------ - // Routes for OPTIONS - // ------------------------------ - - let base = warp::path("key"); - // Set opts method - let opts = base.and(warp::options()).map(warp::reply); - - // ------------------------------ - // Routes for a table - // ------------------------------ - - 
// Set select method - let select = warp::any() - .and(warp::get()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param).and(warp::path::end())) - .and(warp::query()) - .and(session::build()) - .and_then(select_all); - // Set create method - let create = warp::any() - .and(warp::post()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(create_all); - // Set update method - let update = warp::any() - .and(warp::put()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(update_all); - // Set modify method - let modify = warp::any() - .and(warp::patch()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(modify_all); - // Set delete method - let delete = warp::any() - .and(warp::delete()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param).and(warp::path::end())) - .and(warp::query()) - .and(session::build()) - .and_then(delete_all); - // Specify route - let all = select.or(create).or(update).or(modify).or(delete); - - // ------------------------------ - // Routes for a thing - // ------------------------------ - - // Set select method - let select = warp::any() - .and(warp::get()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param / Param).and(warp::path::end())) - .and(session::build()) - .and_then(select_one); - // Set create method - let create = warp::any() - .and(warp::post()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(create_one); - // Set update method - let update = warp::any() - .and(warp::put()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(update_one); - // Set modify method - let modify = warp::any() - .and(warp::patch()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param / Param).and(warp::path::end())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(modify_one); - // Set delete method - let delete = warp::any() - .and(warp::delete()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(path!("key" / Param / Param).and(warp::path::end())) - .and(warp::query()) - .and(session::build()) - .and_then(delete_one); - // Specify route - let one = select.or(create).or(update).or(modify).or(delete); - - // ------------------------------ - // All routes - // ------------------------------ - - // Specify route - opts.or(all).or(one) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + 
.route( + "/key/:table", + options(|| async {}) + .get(select_all) + .post(create_all) + .put(update_all) + .patch(modify_all) + .delete(delete_all), + ) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) + .merge( + Router::new() + .route( + "/key/:table/:key", + options(|| async {}) + .get(select_one) + .post(create_one) + .put(update_one) + .patch(modify_one) + .delete(delete_one), + ) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)), + ) } // ------------------------------ @@ -148,11 +65,11 @@ pub fn config() -> impl Filter Result { + Extension(session): Extension, + maybe_output: Option>, + Path(table): Path, + Query(query): Query, +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Specify the request statement @@ -167,28 +84,28 @@ async fn select_all( }; // Execute the query and return the result match db.execute(sql.as_str(), &session, Some(vars)).await { - Ok(ref res) => match output.as_ref() { + Ok(ref res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } async fn create_all( - output: String, - table: Param, + Extension(session): Extension, + maybe_output: Option>, + Path(table): Path, + Query(params): Query, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -206,31 +123,31 @@ async fn create_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn update_all( - output: String, - table: Param, + Extension(session): Extension, + maybe_output: Option>, + 
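Each router in this patch disables axum's built-in request body limit and then applies tower-http's `RequestBodyLimitLayer` with an endpoint-specific `MAX` (4 GiB for imports, 16 KiB for the key routes, and so on), so only one limit is in force. A minimal sketch of that combination, assuming axum 0.6 and tower-http 0.4; the `/echo` route is a stand-in.

```rust
use axum::{extract::DefaultBodyLimit, routing::post, Router};
use bytes::Bytes;
use tower_http::limit::RequestBodyLimitLayer;

const MAX: usize = 1024 * 16; // 16 KiB, as used for the /key routes

async fn echo(body: Bytes) -> Bytes {
    body
}

fn router() -> Router {
    Router::new()
        .route("/echo", post(echo))
        // Turn off axum's own default limit so it does not apply on top of
        // the tower-http layer below...
        .route_layer(DefaultBodyLimit::disable())
        // ...and let RequestBodyLimitLayer enforce the endpoint-specific cap.
        .layer(RequestBodyLimitLayer::new(MAX))
}
```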
Path(table): Path, + Query(params): Query, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -248,31 +165,31 @@ async fn update_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn modify_all( - output: String, - table: Param, + Extension(session): Extension, + maybe_output: Option>, + Path(table): Path, + Query(params): Query, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -290,30 +207,30 @@ async fn modify_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn delete_all( - output: String, - table: Param, - params: Params, - session: Session, -) -> Result { + Extension(session): Extension, + maybe_output: Option>, + Path(table): Path, + Query(params): Query, +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Specify the request statement @@ -325,18 +242,18 @@ async fn delete_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - 
"application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } @@ -345,11 +262,10 @@ async fn delete_all( // ------------------------------ async fn select_one( - output: String, - table: Param, - id: Param, - session: Session, -) -> Result { + Extension(session): Extension, + maybe_output: Option>, + Path((table, id)): Path<(String, String)>, +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Specify the request statement @@ -366,29 +282,28 @@ async fn select_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } async fn create_one( - output: String, - table: Param, - id: Param, + Extension(session): Extension, + maybe_output: Option>, + Query(params): Query, + Path((table, id)): Path<(String, String)>, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -412,32 +327,31 @@ async fn create_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), 
}, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn update_one( - output: String, - table: Param, - id: Param, + Extension(session): Extension, + maybe_output: Option>, + Query(params): Query, + Path((table, id)): Path<(String, String)>, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -461,32 +375,31 @@ async fn update_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn modify_one( - output: String, - table: Param, - id: Param, + Extension(session): Extension, + maybe_output: Option>, + Query(params): Query, + Path((table, id)): Path<(String, String)>, body: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the HTTP request body @@ -510,31 +423,29 @@ async fn modify_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } - Err(_) => Err(warp::reject::custom(Error::Request)), + Err(_) => Err(Error::Request), } } async fn delete_one( - output: String, - table: Param, - id: Param, - params: Params, - session: Session, -) -> Result { + Extension(session): Extension, + maybe_output: Option>, + Path((table, id)): Path<(String, 
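The key endpoints distinguish table-level from record-level operations purely through the path shape: `/key/:table` binds one segment, `/key/:table/:key` binds a tuple, and the pagination options come from the query string. A small illustration of those extractors, with field types approximated from the `QueryOptions` struct above; the `Pagination` name is illustrative.

```rust
use axum::extract::{Path, Query};
use axum::{routing::get, Router};
use serde::Deserialize;

#[derive(Default, Deserialize)]
struct Pagination {
    limit: Option<String>,
    start: Option<String>,
}

// One path segment, e.g. GET /key/:table?limit=10
async fn list(Path(table): Path<String>, Query(page): Query<Pagination>) -> String {
    format!("table={table} limit={:?} start={:?}", page.limit, page.start)
}

// Two path segments, e.g. GET /key/:table/:key
async fn fetch(Path((table, key)): Path<(String, String)>) -> String {
    format!("table={table} key={key}")
}

fn router() -> Router {
    Router::new()
        .route("/key/:table", get(list))
        .route("/key/:table/:key", get(fetch))
}
```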
String)>, +) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Specify the request statement @@ -548,21 +459,20 @@ async fn delete_one( let vars = map! { String::from("table") => Value::from(table), String::from("id") => rid, - => params.parse() }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match output.as_ref() { + Ok(res) => match maybe_output.as_deref() { // Simple serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } diff --git a/src/net/log.rs b/src/net/log.rs deleted file mode 100644 index 1af10910..00000000 --- a/src/net/log.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::fmt; -use tracing::Level; - -struct OptFmt(Option); - -impl fmt::Display for OptFmt { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(ref t) = self.0 { - fmt::Display::fmt(t, f) - } else { - f.write_str("-") - } - } -} - -pub fn write() -> warp::filters::log::Log { - warp::log::custom(|info| { - event!( - Level::INFO, - "{} {} {} {:?} {} \"{}\" {:?}", - OptFmt(info.remote_addr()), - info.method(), - info.path(), - info.version(), - info.status().as_u16(), - OptFmt(info.user_agent()), - info.elapsed(), - ); - }) -} diff --git a/src/net/mod.rs b/src/net/mod.rs index 16c50f77..6e5698ac 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,105 +1,164 @@ +mod auth; pub mod client_ip; mod export; -mod fail; -mod head; +mod headers; mod health; mod import; -mod index; mod input; mod key; -mod log; mod output; mod params; mod rpc; -mod session; mod signals; mod signin; mod signup; mod sql; -mod status; mod sync; +mod tracer; mod version; +use axum::response::Redirect; +use axum::routing::get; +use axum::{middleware, Router}; +use axum_server::Handle; +use http::header; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tower::ServiceBuilder; +use tower_http::add_extension::AddExtensionLayer; +use tower_http::auth::AsyncRequireAuthorizationLayer; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::request_id::MakeRequestUuid; +use tower_http::sensitive_headers::{ + SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer, +}; +use tower_http::trace::TraceLayer; +use tower_http::ServiceBuilderExt; + use crate::cli::CF; +use crate::cnf; use crate::err::Error; -use warp::Filter; +use crate::net::signals::graceful_shutdown; +use crate::telemetry::metrics::HttpMetricsLayer; +use axum_server::tls_rustls::RustlsConfig; + +const LOG: &str = "surrealdb::net"; + +/// +/// AppState is used to share data between routes. 
+/// +#[derive(Clone)] +struct AppState { + client_ip: client_ip::ClientIp, +} pub async fn init() -> Result<(), Error> { - // Setup web routes - let net = index::config() - // Version endpoint - .or(version::config()) - // Status endpoint - .or(status::config()) - // Health endpoint - .or(health::config()) - // Signup endpoint - .or(signup::config()) - // Signin endpoint - .or(signin::config()) - // Export endpoint - .or(export::config()) - // Import endpoint - .or(import::config()) - // Backup endpoint - .or(sync::config()) - // RPC query endpoint - .or(rpc::config()) - // SQL query endpoint - .or(sql::config()) - // API query endpoint - .or(key::config()) - // Catch all errors - .recover(fail::recover) - // End routes setup - ; - // Specify a generic version header - let net = net.with(head::version()); - // Specify a generic server header - let net = net.with(head::server()); - // Set cors headers on all requests - let net = net.with(head::cors()); - // Log all requests to the console - let net = net.with(log::write()); - // Trace requests - let net = net.with(warp::trace::request()); - // Get local copy of options let opt = CF.get().unwrap(); + let app_state = AppState { + client_ip: opt.client_ip, + }; + + // Specify headers to be obfuscated from all requests/responses + let headers: Arc<[_]> = Arc::new([ + header::AUTHORIZATION, + header::PROXY_AUTHORIZATION, + header::COOKIE, + header::SET_COOKIE, + ]); + + // Build the middleware to our service. + let service = ServiceBuilder::new() + .catch_panic() + .set_x_request_id(MakeRequestUuid) + .propagate_x_request_id() + .layer(AddExtensionLayer::new(app_state)) + .layer(middleware::from_fn(client_ip::client_ip_middleware)) + .layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone(&headers))) + .layer( + TraceLayer::new_for_http() + .make_span_with(tracer::HttpTraceLayerHooks) + .on_request(tracer::HttpTraceLayerHooks) + .on_response(tracer::HttpTraceLayerHooks) + .on_failure(tracer::HttpTraceLayerHooks), + ) + .layer(HttpMetricsLayer) + .layer(SetSensitiveResponseHeadersLayer::from_shared(headers)) + .layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth)) + .layer(headers::add_server_header()) + .layer(headers::add_version_header()) + .layer( + CorsLayer::new() + .allow_methods([ + http::Method::GET, + http::Method::PUT, + http::Method::POST, + http::Method::PATCH, + http::Method::DELETE, + http::Method::OPTIONS, + ]) + .allow_headers([ + http::header::ACCEPT, + http::header::AUTHORIZATION, + http::header::CONTENT_TYPE, + http::header::ORIGIN, + headers::NS.parse().unwrap(), + headers::DB.parse().unwrap(), + headers::ID.parse().unwrap(), + ]) + // allow requests from any origin + .allow_origin(Any) + .max_age(Duration::from_secs(86400)), + ); + + let axum_app = Router::new() + // Redirect until we provide a UI + .route("/", get(|| async { Redirect::temporary(cnf::APP_ENDPOINT) })) + .route("/status", get(|| async {})) + .merge(health::router()) + .merge(export::router()) + .merge(import::router()) + .merge(rpc::router()) + .merge(version::router()) + .merge(sync::router()) + .merge(sql::router()) + .merge(signin::router()) + .merge(signup::router()) + .merge(key::router()) + .layer(service); + info!("Starting web server on {}", &opt.bind); - if let (Some(c), Some(k)) = (&opt.crt, &opt.key) { - // Bind the server to the desired port - let (adr, srv) = warp::serve(net) - .tls() - .cert_path(c) - .key_path(k) - .bind_with_graceful_shutdown(opt.bind, async move { - // Capture the shutdown signals and log that the graceful 
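The new `AppState` is shared with every request through `AddExtensionLayer`, which is how the client-IP middleware and the authentication hook later read the configured client IP strategy. A small sketch of the same mechanism with an illustrative state type, assuming axum 0.6; the real state only carries `client_ip`.

```rust
use axum::{routing::get, Extension, Router};
use tower::ServiceBuilder;
use tower_http::add_extension::AddExtensionLayer;

// Illustrative stand-in for AppState.
#[derive(Clone)]
struct DemoState {
    greeting: &'static str,
}

// Handlers (and middleware, via `parts.extract::<Extension<_>>()`) can read
// whatever AddExtensionLayer placed into the request extensions.
async fn hello(Extension(state): Extension<DemoState>) -> String {
    format!("{}, world", state.greeting)
}

fn app() -> Router {
    Router::new().route("/hello", get(hello)).layer(
        ServiceBuilder::new().layer(AddExtensionLayer::new(DemoState {
            greeting: "hello",
        })),
    )
}
```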
shutdown has started - let result = signals::listen().await.expect("Failed to listen to shutdown signal"); - info!("{} received. Start graceful shutdown...", result); - }); - // Log the server startup status - info!("Started web server on {}", &adr); - // Run the server forever - srv.await; - // Log the server shutdown event - info!("Shutdown complete. Bye!") + // Setup the graceful shutdown with no timeout + let handle = Handle::new(); + graceful_shutdown(handle.clone(), None); + + if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) { + // configure certificate and private key used by https + let tls = RustlsConfig::from_pem_file(cert, key).await.unwrap(); + + let server = axum_server::bind_rustls(opt.bind, tls); + + info!(target: LOG, "Started web server on {}", &opt.bind); + + server + .handle(handle) + .serve(axum_app.into_make_service_with_connect_info::()) + .await?; } else { - // Bind the server to the desired port - let (adr, srv) = warp::serve(net).bind_with_graceful_shutdown(opt.bind, async move { - // Capture the shutdown signals and log that the graceful shutdown has started - let result = signals::listen().await.expect("Failed to listen to shutdown signal"); - info!("{} received. Start graceful shutdown...", result); - }); - // Log the server startup status - info!("Started web server on {}", &adr); - // Run the server forever - srv.await; - // Log the server shutdown event - info!("Shutdown complete. Bye!") + let server = axum_server::bind(opt.bind); + + info!(target: LOG, "Started web server on {}", &opt.bind); + + server + .handle(handle) + .serve(axum_app.into_make_service_with_connect_info::()) + .await?; }; + info!(target: LOG, "Web server stopped. Bye!"); + Ok(()) } diff --git a/src/net/output.rs b/src/net/output.rs index 7d6a0c89..6346d32e 100644 --- a/src/net/output.rs +++ b/src/net/output.rs @@ -1,9 +1,12 @@ +use axum::response::{IntoResponse, Response}; use http::header::{HeaderValue, CONTENT_TYPE}; use http::StatusCode; use serde::Serialize; use serde_json::Value as Json; use surrealdb::sql; +use super::headers::Accept; + pub enum Output { None, Fail, @@ -67,38 +70,23 @@ pub fn simplify(v: T) -> Json { sql::to_value(v).unwrap().into() } -impl warp::Reply for Output { - fn into_response(self) -> warp::reply::Response { +impl IntoResponse for Output { + fn into_response(self) -> Response { match self { Output::Text(v) => { - let mut res = warp::reply::Response::new(v.into()); - let con = HeaderValue::from_static("text/plain"); - res.headers_mut().insert(CONTENT_TYPE, con); - res + ([(CONTENT_TYPE, HeaderValue::from(Accept::TextPlain))], v).into_response() } Output::Json(v) => { - let mut res = warp::reply::Response::new(v.into()); - let con = HeaderValue::from_static("application/json"); - res.headers_mut().insert(CONTENT_TYPE, con); - res + ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationJson))], v).into_response() } Output::Cbor(v) => { - let mut res = warp::reply::Response::new(v.into()); - let con = HeaderValue::from_static("application/cbor"); - res.headers_mut().insert(CONTENT_TYPE, con); - res + ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationCbor))], v).into_response() } Output::Pack(v) => { - let mut res = warp::reply::Response::new(v.into()); - let con = HeaderValue::from_static("application/pack"); - res.headers_mut().insert(CONTENT_TYPE, con); - res + ([(CONTENT_TYPE, HeaderValue::from(Accept::ApplicationPack))], v).into_response() } Output::Full(v) => { - let mut res = warp::reply::Response::new(v.into()); - let con = 
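Shutdown is now driven by an `axum_server::Handle` instead of warp's `bind_with_graceful_shutdown`, and TLS is just a different bind call on the same router. A reduced sketch of that serve path, assuming axum-server 0.5; the bind address, certificate paths, ctrl-c trigger and 10-second grace period are placeholders rather than the values used by the server.

```rust
use std::net::SocketAddr;
use std::time::Duration;

use axum::{routing::get, Router};
use axum_server::tls_rustls::RustlsConfig;
use axum_server::Handle;

async fn serve() -> std::io::Result<()> {
    let app = Router::new().route("/status", get(|| async {}));
    let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();

    // The Handle lets a background task trigger a graceful shutdown later,
    // optionally bounded by a deadline for in-flight requests.
    let handle = Handle::new();
    let shutdown = handle.clone();
    tokio::spawn(async move {
        tokio::signal::ctrl_c().await.expect("failed to listen for ctrl-c");
        shutdown.graceful_shutdown(Some(Duration::from_secs(10)));
    });

    // TLS is optional: with a certificate we bind with rustls, otherwise plain TCP.
    match RustlsConfig::from_pem_file("cert.pem", "key.pem").await.ok() {
        Some(tls) => {
            axum_server::bind_rustls(addr, tls)
                .handle(handle)
                .serve(app.into_make_service())
                .await
        }
        None => axum_server::bind(addr).handle(handle).serve(app.into_make_service()).await,
    }
}
```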
HeaderValue::from_static("application/surrealdb"); - res.headers_mut().insert(CONTENT_TYPE, con); - res + ([(CONTENT_TYPE, HeaderValue::from(Accept::Surrealdb))], v).into_response() } Output::None => StatusCode::OK.into_response(), Output::Fail => StatusCode::INTERNAL_SERVER_ERROR.into_response(), diff --git a/src/net/rpc.rs b/src/net/rpc.rs index 8e875ad0..c8ec85a4 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -5,13 +5,16 @@ use crate::cnf::PKG_VERSION; use crate::cnf::WEBSOCKET_PING_FREQUENCY; use crate::dbs::DB; use crate::err::Error; -use crate::net::session; use crate::rpc::args::Take; use crate::rpc::paths::{ID, METHOD, PARAMS}; use crate::rpc::res; use crate::rpc::res::Failure; use crate::rpc::res::Output; +use axum::routing::get; +use axum::Extension; +use axum::Router; use futures::{SinkExt, StreamExt}; +use http_body::Body as HttpBody; use once_cell::sync::Lazy; use std::collections::BTreeMap; use std::collections::HashMap; @@ -27,8 +30,11 @@ use surrealdb::sql::Value; use tokio::sync::RwLock; use tracing::instrument; use uuid::Uuid; -use warp::ws::{Message, WebSocket, Ws}; -use warp::Filter; + +use axum::{ + extract::ws::{Message, WebSocket, WebSocketUpgrade}, + response::IntoResponse, +}; // Mapping of WebSocketID to WebSocket type WebSockets = RwLock>>; @@ -38,17 +44,22 @@ type LiveQueries = RwLock>; static WEBSOCKETS: Lazy = Lazy::new(WebSockets::default); static LIVE_QUERIES: Lazy = Lazy::new(LiveQueries::default); -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("rpc") - .and(warp::path::end()) - .and(warp::ws()) - .and(session::build()) - .map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session))) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new().route("/rpc", get(handler)) } -async fn socket(ws: WebSocket, session: Session) { - let rpc = Rpc::new(session); +async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension) -> impl IntoResponse { + // finalize the upgrade process by returning upgrade callback. + // we can customize the callback by sending additional info such as address. 
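The `IntoResponse` implementation for `Output` above leans on axum's tuple form, where an array of header pairs plus a body collapses into a complete response. A one-function illustration of that form, assuming axum 0.6; `plain_text` is not a function in this patch.

```rust
use axum::response::{IntoResponse, Response};
use http::header::{HeaderValue, CONTENT_TYPE};

// A header array plus a body builds the Response in one expression,
// the same shape Output::into_response uses for each content type.
fn plain_text(body: String) -> Response {
    ([(CONTENT_TYPE, HeaderValue::from_static("text/plain"))], body).into_response()
}
```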
+ ws.on_upgrade(move |socket| handle_socket(socket, sess)) +} + +async fn handle_socket(ws: WebSocket, sess: Session) { + let rpc = Rpc::new(sess); Rpc::serve(rpc, ws).await } @@ -89,7 +100,7 @@ impl Rpc { let png = chn.clone(); // The WebSocket has connected Rpc::connected(rpc.clone(), chn.clone()).await; - // Send messages to the client + // Send Ping messages to the client tokio::task::spawn(async move { // Create the interval ticker let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY); @@ -98,7 +109,7 @@ impl Rpc { // Wait for the timer interval.tick().await; // Create the ping message - let msg = Message::ping(vec![]); + let msg = Message::Ping(vec![]); // Send the message to the client if png.send(msg).await.is_err() { // Exit out of the loop @@ -146,20 +157,18 @@ impl Rpc { while let Some(msg) = wrx.next().await { match msg { // We've received a message from the client + // Ping is automatically handled by the WebSocket library Ok(msg) => match msg { - msg if msg.is_ping() => { - let _ = chn.send(Message::pong(vec![])).await; - } - msg if msg.is_text() => { + Message::Text(_) => { tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone())); } - msg if msg.is_binary() => { + Message::Binary(_) => { tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone())); } - msg if msg.is_close() => { + Message::Close(_) => { break; } - msg if msg.is_pong() => { + Message::Pong(_) => { continue; } _ => { @@ -214,16 +223,14 @@ impl Rpc { // Parse the request let req = match msg { // This is a binary message - m if m.is_binary() => { + Message::Binary(val) => { // Use binary output out = Output::Full; // Deserialize the input - Value::from(m.into_bytes()) + Value::from(val) } // This is a text message - m if m.is_text() => { - // This won't panic due to the check above - let val = m.to_str().unwrap(); + Message::Text(ref val) => { // Parse the SurrealQL object match surrealdb::sql::value(val) { // The SurrealQL message parsed ok diff --git a/src/net/session.rs b/src/net/session.rs deleted file mode 100644 index e22270e3..00000000 --- a/src/net/session.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::dbs::DB; -use crate::err::Error; -use crate::iam::verify::basic; -use crate::iam::BASIC; -use crate::net::client_ip; -use surrealdb::dbs::Session; -use surrealdb::iam::verify::token; -use surrealdb::iam::TOKEN; -use warp::Filter; - -pub fn build() -> impl Filter + Clone { - // Enable on any path - let conf = warp::any(); - // Add remote ip address - let conf = conf.and(client_ip::build()); - // Add authorization header - let conf = conf.and(warp::header::optional::("authorization")); - // Add http origin header - let conf = conf.and(warp::header::optional::("origin")); - // Add session id header - let conf = conf.and(warp::header::optional::("id")); - // Add namespace header - let conf = conf.and(warp::header::optional::("ns")); - // Add database header - let conf = conf.and(warp::header::optional::("db")); - // Process all headers - conf.and_then(process) -} - -async fn process( - ip: Option, - au: Option, - or: Option, - id: Option, - ns: Option, - db: Option, -) -> Result { - let kvs = DB.get().unwrap(); - // Create session - #[rustfmt::skip] - let mut session = Session { ip, or, id, ns, db, ..Default::default() }; - // Parse the authentication header - match au { - // Basic authentication data was supplied - Some(auth) if auth.starts_with(BASIC) => basic(&mut session, auth).await, - // Token authentication data was supplied - Some(auth) if auth.starts_with(TOKEN) => { - token(kvs, &mut 
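With axum, `Message` is a plain enum, so the warp-style `is_text()` / `is_ping()` guards in the RPC loop become match arms and incoming pings are answered by the library itself. A minimal echo socket showing the same shape, assuming axum 0.6; the `/echo` route and handler names are illustrative.

```rust
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;

async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse {
    // Finish the upgrade and hand the socket to the session loop.
    ws.on_upgrade(echo)
}

async fn echo(mut socket: WebSocket) {
    while let Some(Ok(msg)) = socket.recv().await {
        match msg {
            Message::Text(text) => {
                if socket.send(Message::Text(text)).await.is_err() {
                    break;
                }
            }
            Message::Binary(bytes) => {
                if socket.send(Message::Binary(bytes)).await.is_err() {
                    break;
                }
            }
            Message::Close(_) => break,
            // Ping/Pong frames need no handling of our own here.
            _ => continue,
        }
    }
}

fn router() -> Router {
    Router::new().route("/echo", get(ws_handler))
}
```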
session, auth).await.map_err(Error::from) - } - // Wrong authentication data was supplied - Some(_) => Err(Error::InvalidAuth), - // No authentication data was supplied - None => Ok(()), - }?; - // Pass the authenticated session through - Ok(session) -} diff --git a/src/net/signals.rs b/src/net/signals.rs index 6121b082..1c7084a1 100644 --- a/src/net/signals.rs +++ b/src/net/signals.rs @@ -1,5 +1,19 @@ +use std::time::Duration; + +use axum_server::Handle; + use crate::err::Error; +/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received. +pub fn graceful_shutdown(handle: Handle, dur: Option) { + tokio::spawn(async move { + let result = listen().await.expect("Failed to listen to shutdown signal"); + info!(target: super::LOG, "{} received. Start graceful shutdown...", result); + + handle.graceful_shutdown(dur) + }); +} + #[cfg(unix)] pub async fn listen() -> Result { // Import the OS signals @@ -11,7 +25,7 @@ pub async fn listen() -> Result { let mut sigterm = signal(SignalKind::terminate())?; // Listen and wait for the system signals tokio::select! { - // Wait for a SIGQUIT signal + // Wait for a SIGHUP signal _ = sighup.recv() => { Ok(String::from("SIGHUP")) } diff --git a/src/net/signin.rs b/src/net/signin.rs index a5ccc1ba..317b813b 100644 --- a/src/net/signin.rs +++ b/src/net/signin.rs @@ -2,16 +2,24 @@ use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; -use crate::net::session; use crate::net::CF; +use axum::extract::DefaultBodyLimit; +use axum::response::IntoResponse; +use axum::routing::options; +use axum::Extension; +use axum::Router; +use axum::TypedHeader; use bytes::Bytes; +use http_body::Body as HttpBody; use serde::Serialize; use surrealdb::dbs::Session; use surrealdb::opt::auth::Root; use surrealdb::sql::Value; -use warp::Filter; +use tower_http::limit::RequestBodyLimitLayer; -const MAX: u64 = 1024; // 1 KiB +use super::headers::Accept; + +const MAX: usize = 1024; // 1 KiB #[derive(Serialize)] struct Success { @@ -30,29 +38,24 @@ impl Success { } } -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - // Set base path - let base = warp::path("signin").and(warp::path::end()); - // Set opts method - let opts = base.and(warp::options()).map(warp::reply); - // Set post method - let post = base - .and(warp::post()) - .and(warp::header::optional::(http::header::ACCEPT.as_str())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(session::build()) - .and_then(handler); - // Specify route - opts.or(post) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + .route("/signin", options(|| async {}).post(handler)) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) } async fn handler( - output: Option, + Extension(mut session): Extension, + maybe_output: Option>, body: Bytes, - mut session: Session, -) -> Result { +) -> Result { // Get a database reference let kvs = DB.get().unwrap(); // Get the config options @@ -72,25 +75,25 @@ async fn handler( .map_err(Error::from) { // Authentication was successful - Ok(v) => match output.as_deref() { + Ok(v) => match maybe_output.as_deref() { + // Simple serialization + Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))), + 
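`graceful_shutdown` spawns a task that waits on `listen()` and then asks the `Handle` to drain connections. A trimmed-down version of the unix signal listener it relies on, using tokio's signal support; error handling and the exact signal set are simplified.

```rust
use tokio::signal::unix::{signal, SignalKind};

// Whichever signal fires first resolves the select! and names itself,
// mirroring the listen() function used by graceful_shutdown().
async fn wait_for_shutdown() -> std::io::Result<String> {
    let mut sighup = signal(SignalKind::hangup())?;
    let mut sigint = signal(SignalKind::interrupt())?;
    let mut sigterm = signal(SignalKind::terminate())?;

    let name = tokio::select! {
        _ = sighup.recv() => "SIGHUP",
        _ = sigint.recv() => "SIGINT",
        _ = sigterm.recv() => "SIGTERM",
    };
    Ok(name.to_string())
}
```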
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))), + // Internal serialization + Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))), + // Text serialization + Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())), // Return nothing None => Ok(output::none()), - // Text serialization - Some("text/plain") => Ok(output::text(v.unwrap_or_default())), - // Simple serialization - Some("application/json") => Ok(output::json(&Success::new(v))), - Some("application/cbor") => Ok(output::cbor(&Success::new(v))), - Some("application/pack") => Ok(output::pack(&Success::new(v))), - // Internal serialization - Some("application/surrealdb") => Ok(output::full(&Success::new(v))), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error with authentication - Err(e) => Err(warp::reject::custom(e)), + Err(err) => Err(err), } } // The provided value was not an object - _ => Err(warp::reject::custom(Error::Request)), + _ => Err(Error::Request), } } diff --git a/src/net/signup.rs b/src/net/signup.rs index f3bf54d5..ae992263 100644 --- a/src/net/signup.rs +++ b/src/net/signup.rs @@ -2,14 +2,20 @@ use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; -use crate::net::session; +use axum::extract::DefaultBodyLimit; +use axum::response::IntoResponse; +use axum::routing::options; +use axum::{Extension, Router, TypedHeader}; use bytes::Bytes; +use http_body::Body as HttpBody; use serde::Serialize; use surrealdb::dbs::Session; use surrealdb::sql::Value; -use warp::Filter; +use tower_http::limit::RequestBodyLimitLayer; -const MAX: u64 = 1024; // 1 KiB +use super::headers::Accept; + +const MAX: usize = 1024; // 1 KiB #[derive(Serialize)] struct Success { @@ -28,29 +34,24 @@ impl Success { } } -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - // Set base path - let base = warp::path("signup").and(warp::path::end()); - // Set opts method - let opts = base.and(warp::options()).map(warp::reply); - // Set post method - let post = base - .and(warp::post()) - .and(warp::header::optional::(http::header::ACCEPT.as_str())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(session::build()) - .and_then(handler); - // Specify route - opts.or(post) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + .route("/signup", options(|| async {}).post(handler)) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) } async fn handler( - output: Option, + Extension(mut session): Extension, + maybe_output: Option>, body: Bytes, - mut session: Session, -) -> Result { +) -> Result { // Get a database reference let kvs = DB.get().unwrap(); // Convert the HTTP body into text @@ -62,25 +63,25 @@ async fn handler( match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from) { // Authentication was successful - Ok(v) => match output.as_deref() { + Ok(v) => match maybe_output.as_deref() { + // Simple serialization + Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))), + Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))), + // Internal serialization + Some(Accept::Surrealdb) => 
Ok(output::full(&Success::new(v))), + // Text serialization + Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())), // Return nothing None => Ok(output::none()), - // Text serialization - Some("text/plain") => Ok(output::text(v.unwrap_or_default())), - // Simple serialization - Some("application/json") => Ok(output::json(&Success::new(v))), - Some("application/cbor") => Ok(output::cbor(&Success::new(v))), - Some("application/pack") => Ok(output::pack(&Success::new(v))), - // Internal serialization - Some("application/surrealdb") => Ok(output::full(&Success::new(v))), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error with authentication - Err(e) => Err(warp::reject::custom(e)), + Err(err) => Err(err), } } // The provided value was not an object - _ => Err(warp::reject::custom(Error::Request)), + _ => Err(Error::Request), } } diff --git a/src/net/sql.rs b/src/net/sql.rs index 9ff44960..50ed614f 100644 --- a/src/net/sql.rs +++ b/src/net/sql.rs @@ -3,74 +3,80 @@ use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; use crate::net::params::Params; -use crate::net::session; +use axum::extract::ws::Message; +use axum::extract::ws::WebSocket; +use axum::extract::DefaultBodyLimit; +use axum::extract::Query; +use axum::extract::WebSocketUpgrade; +use axum::response::IntoResponse; +use axum::routing::options; +use axum::Extension; +use axum::Router; +use axum::TypedHeader; use bytes::Bytes; use futures::{SinkExt, StreamExt}; +use http_body::Body as HttpBody; use surrealdb::dbs::Session; -use warp::ws::{Message, WebSocket, Ws}; -use warp::Filter; +use tower_http::limit::RequestBodyLimitLayer; -const MAX: u64 = 1024 * 1024; // 1 MiB +use super::headers::Accept; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - // Set base path - let base = warp::path("sql").and(warp::path::end()); - // Set opts method - let opts = base.and(warp::options()).map(warp::reply); - // Set post method - let post = base - .and(warp::post()) - .and(warp::header::(http::header::ACCEPT.as_str())) - .and(warp::body::content_length_limit(MAX)) - .and(warp::body::bytes()) - .and(warp::query()) - .and(session::build()) - .and_then(handler); - // Set sock method - let sock = base - .and(warp::ws()) - .and(session::build()) - .map(|ws: Ws, session: Session| ws.on_upgrade(move |ws| socket(ws, session))); - // Specify route - opts.or(post).or(sock) +const MAX: usize = 1024 * 1024; // 1 MiB + +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + .route("/sql", options(|| async {}).get(ws_handler).post(post_handler)) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) } -async fn handler( - output: String, +async fn post_handler( + Extension(session): Extension, + output: Option>, + params: Query, sql: Bytes, - params: Params, - session: Session, -) -> Result { +) -> Result { // Get a database reference let db = DB.get().unwrap(); // Convert the received sql query let sql = bytes_to_utf8(&sql)?; // Execute the received sql query - match db.execute(sql, &session, params.parse().into()).await { - // Convert the response to JSON - Ok(res) => match output.as_ref() { + match db.execute(sql, &session, params.0.parse().into()).await { + Ok(res) => match output.as_deref() { // Simple 
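Several endpoints here hang more than one method off a single path by starting from `options(...)` and chaining the rest onto the returned method router, which is how `/sql` serves both a WebSocket upgrade on GET and raw queries on POST. A small illustration of that chaining, assuming axum 0.6; the route and handlers are stand-ins.

```rust
use axum::routing::options;
use axum::Router;

async fn preflight() {}
async fn read() -> &'static str {
    "read"
}
async fn write() -> &'static str {
    "write"
}

// One path, several methods: OPTIONS is a no-op, GET and POST get real handlers.
fn router() -> Router {
    Router::new().route("/thing", options(preflight).get(read).post(write))
}
```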
serialization - "application/json" => Ok(output::json(&output::simplify(res))), - "application/cbor" => Ok(output::cbor(&output::simplify(res))), - "application/pack" => Ok(output::pack(&output::simplify(res))), + Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), + Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), + Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization - "application/surrealdb" => Ok(output::full(&res)), + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested - _ => Err(warp::reject::custom(Error::InvalidType)), + _ => Err(Error::InvalidType), }, // There was an error when executing the query - Err(err) => Err(warp::reject::custom(Error::from(err))), + Err(err) => Err(Error::from(err)), } } -async fn socket(ws: WebSocket, session: Session) { +async fn ws_handler( + ws: WebSocketUpgrade, + Extension(sess): Extension, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_socket(socket, sess)) +} + +async fn handle_socket(ws: WebSocket, session: Session) { // Split the WebSocket connection let (mut tx, mut rx) = ws.split(); // Wait to receive the next message while let Some(res) = rx.next().await { if let Ok(msg) = res { - if let Ok(sql) = msg.to_str() { + if let Ok(sql) = msg.to_text() { // Get a database reference let db = DB.get().unwrap(); // Execute the received sql query @@ -78,12 +84,12 @@ async fn socket(ws: WebSocket, session: Session) { // Convert the response to JSON Ok(v) => match serde_json::to_string(&v) { // Send the JSON response to the client - Ok(v) => tx.send(Message::text(v)).await, + Ok(v) => tx.send(Message::Text(v)).await, // There was an error converting to JSON - Err(e) => tx.send(Message::text(Error::from(e))).await, + Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await, }, // There was an error when executing the query - Err(e) => tx.send(Message::text(Error::from(e))).await, + Err(e) => tx.send(Message::Text(Error::from(e).to_string())).await, }; } } diff --git a/src/net/status.rs b/src/net/status.rs deleted file mode 100644 index 29f8a467..00000000 --- a/src/net/status.rs +++ /dev/null @@ -1,6 +0,0 @@ -use warp::Filter; - -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("status").and(warp::path::end()).and(warp::get()).map(warp::reply) -} diff --git a/src/net/sync.rs b/src/net/sync.rs index 5215f9d0..6fc09c25 100644 --- a/src/net/sync.rs +++ b/src/net/sync.rs @@ -1,22 +1,20 @@ -use warp::http; -use warp::Filter; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use http_body::Body as HttpBody; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - // Set base path - let base = warp::path("sync").and(warp::path::end()); - // Set save method - let save = base.and(warp::get()).and_then(save); - // Set load method - let load = base.and(warp::post()).and_then(load); - // Specify route - save.or(load) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new().route("/sync", get(save).post(load)) } -pub async fn load() -> Result { - Ok(warp::reply::with_status("Load", http::StatusCode::OK)) +async fn load() -> impl IntoResponse { + "Load" } -pub async fn save() -> Result { - Ok(warp::reply::with_status("Save", http::StatusCode::OK)) +async fn save() -> impl IntoResponse { + "Save" } diff --git a/src/net/tracer.rs 
b/src/net/tracer.rs new file mode 100644 index 00000000..0842271b --- /dev/null +++ b/src/net/tracer.rs @@ -0,0 +1,235 @@ +use std::{fmt, time::Duration}; + +use axum::{ + body::{boxed, Body, BoxBody}, + extract::MatchedPath, + headers::{ + authorization::{Basic, Bearer}, + Authorization, Origin, + }, + Extension, RequestPartsExt, TypedHeader, +}; +use futures_util::future::BoxFuture; +use http::{header, request::Parts, StatusCode}; +use hyper::{Request, Response}; +use surrealdb::{dbs::Session, iam::verify::token}; +use tower_http::{ + auth::AsyncAuthorizeRequest, + request_id::RequestId, + trace::{MakeSpan, OnFailure, OnRequest, OnResponse}, +}; +use tracing::{field, Level, Span}; + +use crate::{dbs::DB, err::Error, iam::verify::basic}; + +use super::{client_ip::ExtractClientIP, AppState}; + +/// +/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait. +/// It is used to authorize requests to SurrealDB using Basic or Token authentication. +/// +/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer: +/// +/// ```rust +/// use tower_http::auth::RequireAuthorizationLayer; +/// use surrealdb::net::SurrealAuth; +/// use axum::Router; +/// +/// let auth = RequireAuthorizationLayer::new(SurrealAuth); +/// +/// let app = Router::new() +/// .route("/version", get(|| async { "0.1.0" })) +/// .layer(auth); +/// ``` +#[derive(Clone, Copy)] +pub(super) struct SurrealAuth; + +impl AsyncAuthorizeRequest for SurrealAuth +where + B: Send + Sync + 'static, +{ + type RequestBody = B; + type ResponseBody = BoxBody; + type Future = BoxFuture<'static, Result, Response>>; + + fn authorize(&mut self, request: Request) -> Self::Future { + Box::pin(async { + let (mut parts, body) = request.into_parts(); + match check_auth(&mut parts).await { + Ok(sess) => { + parts.extensions.insert(sess); + Ok(Request::from_parts(parts, body)) + } + Err(err) => { + let unauthorized_response = Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(boxed(Body::from(err.to_string()))) + .unwrap(); + Err(unauthorized_response) + } + } + }) + } +} + +async fn check_auth(parts: &mut Parts) -> Result { + let kvs = DB.get().unwrap(); + + let or = if let Ok(or) = parts.extract::>().await { + if !or.is_null() { + Some(or.to_string()) + } else { + None + } + } else { + None + }; + + let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader + + let Extension(state) = parts.extract::>().await.map_err(|err| { + tracing::error!("Error extracting the app state: {:?}", err); + Error::InvalidAuth + })?; + let ExtractClientIP(ip) = + parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None)); + + // Create session + #[rustfmt::skip] + let mut session = Session { ip, or, id, ns, db, ..Default::default() }; + + // If Basic authentication data was supplied + if let Ok(au) = parts.extract::>>().await { + basic(&mut session, au.username(), au.password()).await + } else if let Ok(au) = parts.extract::>>().await { + token(kvs, &mut session, au.token().into()).await.map_err(|e| e.into()) + } else { + Err(Error::InvalidAuth) + }?; + + Ok(session) +} + +/// +/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer. 
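Because `check_auth` stores the verified `Session` in the request extensions, every route behind `AsyncRequireAuthorizationLayer::new(SurrealAuth)` can simply declare an `Extension<Session>` extractor, as the handlers earlier in this patch do. A small sketch of such a consumer, assuming axum 0.6; the `/whoami` route is illustrative.

```rust
use axum::{routing::get, Extension, Router};
use surrealdb::dbs::Session;

// The session placed into the extensions by the auth layer is picked up here.
async fn whoami(Extension(session): Extension<Session>) -> String {
    // ns and db are the namespace/database headers captured by check_auth().
    format!("ns={:?} db={:?}", session.ns, session.db)
}

fn router() -> Router {
    Router::new().route("/whoami", get(whoami))
}
```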
+/// +/// Example: +/// +/// ```rust +/// use tower_http::trace::TraceLayer; +/// use surrealdb::net::HttpTraceLayerHooks; +/// use axum::Router; +/// +/// let trace = TraceLayer::new_for_http().on_request(HttpTraceLayerHooks::default()); +/// +/// let app = Router::new() +/// .route("/version", get(|| async { "0.1.0" })) +/// .layer(trace); +/// ``` + +#[derive(Default, Clone)] +pub(crate) struct HttpTraceLayerHooks; + +impl MakeSpan for HttpTraceLayerHooks { + fn make_span(&mut self, req: &Request) -> Span { + // The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md + let span = tracing::info_span!( + target: "surreal::http", + "request", + otel.name = field::Empty, + otel.kind = "server", + http.route = field::Empty, + http.request.method = req.method().as_str(), + http.request.body.size = field::Empty, + url.path = req.uri().path(), + url.query = field::Empty, + url.scheme = field::Empty, + http.request.id = field::Empty, + user_agent.original = field::Empty, + network.protocol.name = "http", + network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"), + client.address = field::Empty, + client.port = field::Empty, + client.socket.address = field::Empty, + server.address = field::Empty, + server.port = field::Empty, + // set on the response hook + http.latency.ms = field::Empty, + http.response.status_code = field::Empty, + http.response.body.size = field::Empty, + // set on the failure hook + error = field::Empty, + error_message = field::Empty, + ); + + req.uri().query().map(|v| span.record("url.query", v)); + req.uri().scheme().map(|v| span.record("url.scheme", v.as_str())); + req.uri().host().map(|v| span.record("server.address", v)); + req.uri().port_u16().map(|v| span.record("server.port", v)); + + req.headers() + .get(header::CONTENT_LENGTH) + .map(|v| v.to_str().map(|v| span.record("http.request.body.size", v))); + req.headers() + .get(header::USER_AGENT) + .map(|v| v.to_str().map(|v| span.record("user_agent.original", v))); + + if let Some(path) = req.extensions().get::() { + span.record("otel.name", format!("{} {}", req.method(), path.as_str())); + span.record("http.route", path.as_str()); + } else { + span.record("otel.name", format!("{} -", req.method())); + }; + + if let Some(req_id) = req.extensions().get::() { + match req_id.header_value().to_str() { + Err(err) => tracing::error!(error = %err, "failed to parse request id"), + Ok(request_id) => { + span.record("http.request.id", request_id); + } + } + } + + if let Some(client_ip) = req.extensions().get::() { + if let Some(ref client_ip) = client_ip.0 { + span.record("client.address", client_ip); + } + } + + span + } +} + +impl OnRequest for HttpTraceLayerHooks { + fn on_request(&mut self, _: &Request, _: &Span) { + tracing::event!(Level::INFO, "started processing request"); + } +} + +impl OnResponse for HttpTraceLayerHooks { + fn on_response(self, response: &Response, latency: Duration, span: &Span) { + if let Some(size) = response.headers().get(header::CONTENT_LENGTH) { + span.record("http.response.body.size", size.to_str().unwrap()); + } + span.record("http.response.status_code", response.status().as_u16()); + + // Server errors are handled by the OnFailure hook + if !response.status().is_server_error() { + span.record("http.latency.ms", latency.as_millis()); + tracing::event!(Level::INFO, "finished processing request"); + } + } +} + +impl OnFailure for HttpTraceLayerHooks +where + 
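The span created by `make_span` declares most attributes as `field::Empty` so that the response and failure hooks can fill them in later on the same span. A condensed sketch of that declare-then-record pattern with the tracing crate; the field set is reduced to three entries.

```rust
use tracing::{field, info_span, Level};

fn demo_span(path: &str, status: u16, latency_ms: u64) {
    // Fields must be declared up front (here as Empty) to be recordable later.
    let span = info_span!(
        "request",
        url.path = path,
        http.response.status_code = field::Empty,
        http.latency.ms = field::Empty,
    );

    // ...request processing happens here...

    // Later hooks record onto the same span once the outcome is known.
    span.record("http.response.status_code", status);
    span.record("http.latency.ms", latency_ms);
    let _enter = span.enter();
    tracing::event!(Level::INFO, "finished processing request");
}
```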
FailureClass: fmt::Display, +{ + fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) { + span.record("error_message", &error.to_string()); + span.record("http.latency.ms", latency.as_millis()); + tracing::event!(Level::ERROR, error = error.to_string(), "response failed"); + } +} diff --git a/src/net/version.rs b/src/net/version.rs index f8749950..e364fe0b 100644 --- a/src/net/version.rs +++ b/src/net/version.rs @@ -1,14 +1,18 @@ use crate::cnf::PKG_NAME; use crate::cnf::PKG_VERSION; -use warp::http; -use warp::Filter; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use http_body::Body as HttpBody; -#[allow(opaque_hidden_inferred_bound)] -pub fn config() -> impl Filter + Clone { - warp::path("version").and(warp::path::end()).and(warp::get()).and_then(handler) +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new().route("/version", get(handler)) } -pub async fn handler() -> Result { - let val = format!("{PKG_NAME}-{}", *PKG_VERSION); - Ok(warp::reply::with_status(val, http::StatusCode::OK)) +async fn handler() -> impl IntoResponse { + format!("{PKG_NAME}-{}", *PKG_VERSION) } diff --git a/src/rpc/res.rs b/src/rpc/res.rs index 0cb84d30..6f8c2f51 100644 --- a/src/rpc/res.rs +++ b/src/rpc/res.rs @@ -1,3 +1,4 @@ +use axum::extract::ws::Message; use serde::Serialize; use serde_json::{json, Value as Json}; use std::borrow::Cow; @@ -7,7 +8,6 @@ use surrealdb::dbs::Notification; use surrealdb::sql; use surrealdb::sql::Value; use tracing::instrument; -use warp::ws::Message; #[derive(Clone)] pub enum Output { @@ -87,19 +87,19 @@ impl Response { let message = match out { Output::Json => { let res = serde_json::to_string(&self.simplify()).unwrap(); - Message::text(res) + Message::Text(res) } Output::Cbor => { let res = serde_cbor::to_vec(&self.simplify()).unwrap(); - Message::binary(res) + Message::Binary(res) } Output::Pack => { let res = serde_pack::to_vec(&self.simplify()).unwrap(); - Message::binary(res) + Message::Binary(res) } Output::Full => { let res = surrealdb::sql::serde::serialize(&self).unwrap(); - Message::binary(res) + Message::Binary(res) } }; let _ = chn.send(message).await; diff --git a/src/o11y/logger.rs b/src/telemetry/logs/mod.rs similarity index 100% rename from src/o11y/logger.rs rename to src/telemetry/logs/mod.rs diff --git a/src/telemetry/metrics/http/mod.rs b/src/telemetry/metrics/http/mod.rs new file mode 100644 index 00000000..079f8e08 --- /dev/null +++ b/src/telemetry/metrics/http/mod.rs @@ -0,0 +1,112 @@ +pub(super) mod tower_layer; + +use once_cell::sync::Lazy; +use opentelemetry::{ + metrics::{Histogram, Meter, MeterProvider, ObservableUpDownCounter, Unit}, + runtime, + sdk::{ + export::metrics::aggregation, + metrics::{ + controllers::{self, BasicController}, + processors, selectors, + }, + }, + Context, +}; +use opentelemetry_otlp::MetricsExporterBuilder; + +use crate::telemetry::OTEL_DEFAULT_RESOURCE; + +// Histogram buckets in milliseconds +static HTTP_DURATION_MS_HISTOGRAM_BUCKETS: &[f64] = &[ + 5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0, + 2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0, +]; + +const KB: f64 = 1024.0; +const MB: f64 = 1024.0 * KB; + +const HTTP_SIZE_HISTOGRAM_BUCKETS: &[f64] = &[ + 1.0 * KB, // 1 KB + 2.0 * KB, // 2 KB + 5.0 * KB, // 5 KB + 10.0 * KB, // 10 KB + 100.0 * KB, // 100 KB + 500.0 * KB, // 500 KB + 1.0 * MB, // 1 MB + 2.5 * MB, // 2 MB + 5.0 * MB, // 5 MB + 
10.0 * MB, // 10 MB + 25.0 * MB, // 25 MB + 50.0 * MB, // 50 MB + 100.0 * MB, // 100 MB +]; + +static METER_PROVIDER_HTTP_DURATION: Lazy = Lazy::new(|| { + let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic()) + .build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector())) + .unwrap(); + + let builder = controllers::basic(processors::factory( + selectors::simple::histogram(HTTP_DURATION_MS_HISTOGRAM_BUCKETS), + aggregation::cumulative_temporality_selector(), + )) + .with_exporter(exporter) + .with_resource(OTEL_DEFAULT_RESOURCE.clone()); + + let controller = builder.build(); + controller.start(&Context::current(), runtime::Tokio).unwrap(); + controller +}); + +static METER_PROVIDER_HTTP_SIZE: Lazy = Lazy::new(|| { + let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic()) + .build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector())) + .unwrap(); + + let builder = controllers::basic(processors::factory( + selectors::simple::histogram(HTTP_SIZE_HISTOGRAM_BUCKETS), + aggregation::cumulative_temporality_selector(), + )) + .with_exporter(exporter) + .with_resource(OTEL_DEFAULT_RESOURCE.clone()); + + let controller = builder.build(); + controller.start(&Context::current(), runtime::Tokio).unwrap(); + controller +}); + +static HTTP_DURATION_METER: Lazy = + Lazy::new(|| METER_PROVIDER_HTTP_DURATION.meter("http_duration")); +static HTTP_SIZE_METER: Lazy = Lazy::new(|| METER_PROVIDER_HTTP_SIZE.meter("http_size")); + +pub static HTTP_SERVER_DURATION: Lazy> = Lazy::new(|| { + HTTP_DURATION_METER + .u64_histogram("http.server.duration") + .with_description("The HTTP server duration in milliseconds.") + .with_unit(Unit::new("ms")) + .init() +}); + +pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy> = Lazy::new(|| { + HTTP_DURATION_METER + .i64_observable_up_down_counter("http.server.active_requests") + .with_description("The number of active HTTP requests.") + .init() +}); + +pub static HTTP_SERVER_REQUEST_SIZE: Lazy> = Lazy::new(|| { + HTTP_SIZE_METER + .u64_histogram("http.server.request.size") + .with_description("Measures the size of HTTP request messages.") + .with_unit(Unit::new("mb")) + .init() +}); + +pub static HTTP_SERVER_RESPONSE_SIZE: Lazy> = Lazy::new(|| { + HTTP_SIZE_METER + .u64_histogram("http.server.response.size") + .with_description("Measures the size of HTTP response messages.") + .with_unit(Unit::new("mb")) + .init() +}); diff --git a/src/telemetry/metrics/http/tower_layer.rs b/src/telemetry/metrics/http/tower_layer.rs new file mode 100644 index 00000000..16ec44aa --- /dev/null +++ b/src/telemetry/metrics/http/tower_layer.rs @@ -0,0 +1,310 @@ +use axum::extract::MatchedPath; +use opentelemetry::{metrics::MetricsError, Context as TelemetryContext, KeyValue}; +use pin_project_lite::pin_project; +use std::{ + cell::Cell, + fmt, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use futures::Future; +use http::{Request, Response, StatusCode, Version}; +use tower::{Layer, Service}; + +use super::{ + HTTP_DURATION_METER, HTTP_SERVER_ACTIVE_REQUESTS, HTTP_SERVER_DURATION, + HTTP_SERVER_REQUEST_SIZE, HTTP_SERVER_RESPONSE_SIZE, +}; + +#[derive(Clone, Default)] +pub struct HttpMetricsLayer; + +impl Layer for HttpMetricsLayer { + type Service = HttpMetrics; + + fn layer(&self, inner: S) -> Self::Service { + HttpMetrics { + inner, + } + } +} + +#[derive(Clone)] +pub struct HttpMetrics { + inner: S, +} + +impl Service> for HttpMetrics +where + S: Service, Response = 
Response>, + ReqBody: http_body::Body, + ResBody: http_body::Body, + S::Error: fmt::Display + 'static, +{ + type Response = Response; + type Error = S::Error; + type Future = HttpCallMetricsFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let tracker = HttpCallMetricTracker::new(&request); + + HttpCallMetricsFuture::new(self.inner.call(request), tracker) + } +} + +pin_project! { + pub struct HttpCallMetricsFuture { + #[pin] + inner: F, + tracker: HttpCallMetricTracker, + } +} + +impl HttpCallMetricsFuture { + fn new(inner: F, tracker: HttpCallMetricTracker) -> Self { + Self { + inner, + tracker, + } + } +} + +impl Future for HttpCallMetricsFuture +where + Fut: Future, E>>, + ResBody: http_body::Body, + E: std::fmt::Display + 'static, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + this.tracker.set_state(ResultState::Started); + + if let Err(err) = on_request_start(this.tracker) { + error!("Failed to setup metrics when request started: {}", err); + // Consider this request not tracked: reset the state to None, so that the drop handler does not decrease the counter. + this.tracker.set_state(ResultState::None); + }; + + let response = futures_util::ready!(this.inner.poll(cx)); + + let result = match response { + Ok(reply) => { + this.tracker.set_state(ResultState::Result( + reply.status(), + reply.version(), + reply.body().size_hint().exact(), + )); + Ok(reply) + } + Err(e) => { + this.tracker.set_state(ResultState::Failed); + Err(e) + } + }; + Poll::Ready(result) + } +} + +pub struct HttpCallMetricTracker { + version: String, + method: hyper::Method, + scheme: Option, + host: Option, + route: Option, + state: Cell, + status_code: Option, + request_size: Option, + response_size: Option, + start: Instant, + finish: Option, +} + +pub enum ResultState { + /// The result was already processed. + None, + /// Request was started. + Started, + /// The result failed with an error. + Failed, + /// The result is an actual HTTP response. 
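+ /// Holds the response status code, the HTTP version and, when the body reports an exact
+ /// size hint, the response body size in bytes.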
+ Result(StatusCode, Version, Option), +} + +impl HttpCallMetricTracker { + fn new(request: &Request) -> Self + where + B: http_body::Body, + { + Self { + version: format!("{:?}", request.version()), + method: request.method().clone(), + scheme: request.uri().scheme().cloned(), + host: request.uri().host().map(|s| s.to_string()), + route: request.extensions().get::().map(|v| v.as_str().to_string()), + state: Cell::new(ResultState::None), + status_code: None, + request_size: request.body().size_hint().exact(), + response_size: None, + start: Instant::now(), + finish: None, + } + } + + fn set_state(&self, state: ResultState) { + self.state.set(state) + } + + pub fn duration(&self) -> Duration { + self.finish.unwrap_or(Instant::now()) - self.start + } + + // Follows the OpenTelemetry semantic conventions for HTTP metrics define here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md + fn olel_common_attrs(&self) -> Vec { + let mut res = vec![ + KeyValue::new("http.request.method", self.method.as_str().to_owned()), + KeyValue::new("network.protocol.name", "http".to_owned()), + ]; + + if let Some(scheme) = &self.scheme { + res.push(KeyValue::new("url.scheme", scheme.as_str().to_owned())); + } + + if let Some(host) = &self.host { + res.push(KeyValue::new("server.address", host.to_owned())); + } + + res + } + + pub(super) fn active_req_attrs(&self) -> Vec { + self.olel_common_attrs() + } + + pub(super) fn request_duration_attrs(&self) -> Vec { + let mut res = self.olel_common_attrs(); + + res.push(KeyValue::new( + "http.response.status_code", + self.status_code.map(|v| v.as_str().to_owned()).unwrap_or("000".to_owned()), + )); + + if let Some(v) = self.version.strip_prefix("HTTP/") { + res.push(KeyValue::new("network.protocol.version", v.to_owned())); + } + + if let Some(target) = &self.route { + res.push(KeyValue::new("http.route", target.to_owned())); + } + + res + } + + pub(super) fn request_size_attrs(&self) -> Vec { + self.request_duration_attrs() + } + + pub(super) fn response_size_attrs(&self) -> Vec { + self.request_duration_attrs() + } +} + +impl Drop for HttpCallMetricTracker { + fn drop(&mut self) { + match self.state.replace(ResultState::None) { + ResultState::None => { + // Request was not tracked, so no need to decrease the counter. + return; + } + ResultState::Started => { + // If the response was never processed, we can't get a valid status code + } + ResultState::Failed => { + // If there's an error processing the request and we don't have a response, we can't get a valid status code + } + ResultState::Result(s, v, size) => { + self.status_code = Some(s); + self.version = format!("{:?}", v); + self.response_size = size; + } + }; + + self.finish = Some(Instant::now()); + + if let Err(err) = on_request_finish(self) { + error!(target: "surrealdb::telemetry", "Failed to setup metrics when request finished: {}", err); + } + } +} + +pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + // Setup the active_requests observer + observe_active_request_start(tracker) +} + +pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + // Setup the active_requests observer + observe_active_request_finish(tracker)?; + + // Record the duration of the request. 
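+ // The value is exported through the `http.server.duration` histogram (in milliseconds),
+ // tagged with the attributes built by `request_duration_attrs()`.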
+ record_request_duration(tracker); + + // Record the request size if known + if let Some(size) = tracker.request_size { + record_request_size(tracker, size) + } + + // Record the response size if known + if let Some(size) = tracker.response_size { + record_response_size(tracker, size) + } + + Ok(()) +} + +fn observe_active_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + let attrs = tracker.active_req_attrs(); + // Setup the callback to observe the active requests. + HTTP_DURATION_METER + .register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, 1, &attrs)) +} + +fn observe_active_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + let attrs = tracker.active_req_attrs(); + // Setup the callback to observe the active requests. + HTTP_DURATION_METER + .register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, -1, &attrs)) +} + +fn record_request_duration(tracker: &HttpCallMetricTracker) { + // Record the duration of the request. + HTTP_SERVER_DURATION.record( + &TelemetryContext::current(), + tracker.duration().as_millis() as u64, + &tracker.request_duration_attrs(), + ); +} + +pub fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) { + HTTP_SERVER_REQUEST_SIZE.record( + &TelemetryContext::current(), + size, + &tracker.request_size_attrs(), + ); +} + +pub fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) { + HTTP_SERVER_RESPONSE_SIZE.record( + &TelemetryContext::current(), + size, + &tracker.response_size_attrs(), + ); +} diff --git a/src/telemetry/metrics/mod.rs b/src/telemetry/metrics/mod.rs new file mode 100644 index 00000000..b2b235a3 --- /dev/null +++ b/src/telemetry/metrics/mod.rs @@ -0,0 +1,3 @@ +pub mod http; + +pub use self::http::tower_layer::HttpMetricsLayer; diff --git a/src/o11y/mod.rs b/src/telemetry/mod.rs similarity index 64% rename from src/o11y/mod.rs rename to src/telemetry/mod.rs index d6616f78..aab616c9 100644 --- a/src/o11y/mod.rs +++ b/src/telemetry/mod.rs @@ -1,7 +1,16 @@ -mod logger; -mod tracers; +mod logs; +pub mod metrics; +mod traces; + +use std::time::Duration; use crate::cli::validator::parser::env_filter::CustomEnvFilter; +use once_cell::sync::Lazy; +use opentelemetry::sdk::resource::{ + EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector, +}; +use opentelemetry::sdk::Resource; +use opentelemetry::KeyValue; use tracing::Subscriber; use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::prelude::*; @@ -9,6 +18,28 @@ use tracing_subscriber::util::SubscriberInitExt; #[cfg(feature = "has-storage")] use tracing_subscriber::EnvFilter; +pub static OTEL_DEFAULT_RESOURCE: Lazy = Lazy::new(|| { + let res = Resource::from_detectors( + Duration::from_secs(5), + vec![ + // set service.name from env OTEL_SERVICE_NAME > env OTEL_RESOURCE_ATTRIBUTES > option_env! CARGO_BIN_NAME > unknown_service + Box::new(SdkProvidedResourceDetector), + // detect res from env OTEL_RESOURCE_ATTRIBUTES (resources string like key1=value1,key2=value2,...) 
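+ // e.g. OTEL_RESOURCE_ATTRIBUTES="service.name=surrealdb,service.version=1.0.0"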
+ Box::new(EnvResourceDetector::new()), + // set telemetry.sdk.{name, language, version} + Box::new(TelemetryResourceDetector), + ], + ); + + // If no external service.name is set, set it to surrealdb + if res.get("service.name".into()).unwrap_or("".into()).as_str() == "unknown_service" { + debug!("No service.name detected, use 'surrealdb'"); + res.merge(&Resource::new([KeyValue::new("service.name", "surrealdb")])) + } else { + res + } +}); + #[derive(Default, Debug, Clone)] pub struct Builder { log_level: Option, @@ -32,7 +63,8 @@ impl Builder { self.filter = Some(CustomEnvFilter(filter)); self } - /// Build a dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber + + /// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber pub fn build(self) -> Box { let registry = tracing_subscriber::registry(); let registry = registry.with(self.filter.map(|filter| { @@ -44,11 +76,12 @@ impl Builder { .with_filter(filter.0) .boxed() })); - let registry = registry.with(self.log_level.map(logger::new)); - let registry = registry.with(tracers::new()); + let registry = registry.with(self.log_level.map(logs::new)); + let registry = registry.with(traces::new()); Box::new(registry) } - /// Build a dispatcher and set it as global + + /// tracing pipeline pub fn init(self) { self.build().init() } @@ -60,10 +93,12 @@ mod tests { use tracing::{span, Level}; use tracing_subscriber::util::SubscriberInitExt; + use crate::telemetry; + #[tokio::test(flavor = "multi_thread")] async fn test_otlp_tracer() { println!("Starting mock otlp server..."); - let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await; + let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await; { let otlp_endpoint = format!("http://{}", addr); @@ -73,7 +108,7 @@ mod tests { ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ], || { - let _enter = super::builder().build().set_default(); + let _enter = telemetry::builder().build().set_default(); println!("Sending span..."); @@ -90,16 +125,8 @@ mod tests { println!("Waiting for request..."); let req = req_rx.recv().await.expect("missing export request"); - let first_span = req - .resource_spans - .first() - .unwrap() - .instrumentation_library_spans - .first() - .unwrap() - .spans - .first() - .unwrap(); + let first_span = + req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap(); assert_eq!("test-surreal-span", first_span.name); let first_event = first_span.events.first().unwrap(); assert_eq!("test-surreal-event", first_event.name); @@ -108,7 +135,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_tracing_filter() { println!("Starting mock otlp server..."); - let (addr, mut req_rx) = super::tracers::tests::mock_otlp_server().await; + let (addr, mut req_rx) = telemetry::traces::tests::mock_otlp_server().await; { let otlp_endpoint = format!("http://{}", addr); @@ -119,7 +146,7 @@ mod tests { ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ], || { - let _enter = super::builder().build().set_default(); + let _enter = telemetry::builder().build().set_default(); println!("Sending spans..."); @@ -144,14 +171,7 @@ mod tests { println!("Waiting for request..."); let req = req_rx.recv().await.expect("missing export request"); - let spans = &req - .resource_spans - .first() - .unwrap() - .instrumentation_library_spans - .first() - .unwrap() - .spans; + let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans; 
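+ // Only the span allowed by the tracing filter (the `debug` one) should have been exported.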
assert_eq!(1, spans.len()); assert_eq!("debug", spans.first().unwrap().name); diff --git a/src/o11y/tracers/mod.rs b/src/telemetry/traces/mod.rs similarity index 96% rename from src/o11y/tracers/mod.rs rename to src/telemetry/traces/mod.rs index d2dc68d5..b562e435 100644 --- a/src/o11y/tracers/mod.rs +++ b/src/telemetry/traces/mod.rs @@ -59,7 +59,9 @@ pub mod tests { request: tonic::Request, ) -> Result, tonic::Status> { self.tx.lock().unwrap().try_send(request.into_inner()).expect("Channel full"); - Ok(tonic::Response::new(ExportTraceServiceResponse {})) + Ok(tonic::Response::new(ExportTraceServiceResponse { + partial_success: None, + })) } } diff --git a/src/o11y/tracers/otlp.rs b/src/telemetry/traces/otlp.rs similarity index 80% rename from src/o11y/tracers/otlp.rs rename to src/telemetry/traces/otlp.rs index 94bd8705..f820583c 100644 --- a/src/o11y/tracers/otlp.rs +++ b/src/telemetry/traces/otlp.rs @@ -1,10 +1,11 @@ -use opentelemetry::sdk::{trace::Tracer, Resource}; +use opentelemetry::sdk::trace::Tracer; use opentelemetry::trace::TraceError; -use opentelemetry::KeyValue; use opentelemetry_otlp::WithExportConfig; use tracing::{Level, Subscriber}; use tracing_subscriber::{EnvFilter, Layer}; +use crate::telemetry::OTEL_DEFAULT_RESOURCE; + const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER"; pub fn new() -> Box + Send + Sync> @@ -15,12 +16,12 @@ where } fn tracer() -> Result { - let resource = Resource::new(vec![KeyValue::new("service.name", "surrealdb")]); - opentelemetry_otlp::new_pipeline() .tracing() .with_exporter(opentelemetry_otlp::new_exporter().tonic().with_env()) - .with_trace_config(opentelemetry::sdk::trace::config().with_resource(resource)) + .with_trace_config( + opentelemetry::sdk::trace::config().with_resource(OTEL_DEFAULT_RESOURCE.clone()), + ) .install_batch(opentelemetry::runtime::Tokio) } diff --git a/tests/cli.rs b/tests/cli.rs deleted file mode 100644 index 77c18217..00000000 --- a/tests/cli.rs +++ /dev/null @@ -1,354 +0,0 @@ -mod cli_integration { - // cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture - - use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild}; - use rand::{thread_rng, Rng}; - use serial_test::serial; - use std::fs; - use std::path::Path; - use std::process::{Command, Stdio}; - - /// Child is a (maybe running) CLI process. It can be killed by dropping it - struct Child { - inner: Option, - } - - impl Child { - /// Send some thing to the child's stdin - fn input(mut self, input: &str) -> Self { - let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap(); - use std::io::Write; - stdin.write_all(input.as_bytes()).unwrap(); - self - } - - fn kill(mut self) -> Self { - self.inner.as_mut().unwrap().kill().unwrap(); - self - } - - /// Read the child's stdout concatenated with its stderr. Returns Ok if the child - /// returns successfully, Err otherwise. - fn output(mut self) -> Result { - let output = self.inner.take().unwrap().wait_with_output().unwrap(); - - let mut buf = String::from_utf8(output.stdout).unwrap(); - buf.push_str(&String::from_utf8(output.stderr).unwrap()); - - if output.status.success() { - Ok(buf) - } else { - Err(buf) - } - } - } - - impl Drop for Child { - fn drop(&mut self) { - if let Some(inner) = self.inner.as_mut() { - let _ = inner.kill(); - } - } - } - - fn run_internal>(args: &str, current_dir: Option
) -> Child { - let mut path = std::env::current_exe().unwrap(); - assert!(path.pop()); - if path.ends_with("deps") { - assert!(path.pop()); - } - - // Note: Cargo automatically builds this binary for integration tests. - path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX)); - - let mut cmd = Command::new(path); - if let Some(dir) = current_dir { - cmd.current_dir(&dir); - } - cmd.stdin(Stdio::piped()); - cmd.stdout(Stdio::piped()); - cmd.stderr(Stdio::piped()); - cmd.args(args.split_ascii_whitespace()); - Child { - inner: Some(cmd.spawn().unwrap()), - } - } - - /// Run the CLI with the given args - fn run(args: &str) -> Child { - run_internal::(args, None) - } - - /// Run the CLI with the given args inside a temporary directory - fn run_in_dir>(args: &str, current_dir: P) -> Child { - run_internal(args, Some(current_dir)) - } - - fn tmp_file(name: &str) -> String { - let path = Path::new(env!("OUT_DIR")).join(name); - path.to_string_lossy().into_owned() - } - - #[test] - #[serial] - fn version() { - assert!(run("version").output().is_ok()); - } - - #[test] - #[serial] - fn help() { - assert!(run("help").output().is_ok()); - } - - #[test] - #[serial] - fn nonexistent_subcommand() { - assert!(run("nonexistent").output().is_err()); - } - - #[test] - #[serial] - fn nonexistent_option() { - assert!(run("version --turbo").output().is_err()); - } - - #[test] - #[serial] - fn start() { - let mut rng = thread_rng(); - - let port: u16 = rng.gen_range(13000..14000); - let addr = format!("127.0.0.1:{port}"); - - let pass = rng.gen::().to_string(); - - let start_args = - format!("start --bind {addr} --user root --pass {pass} memory --no-banner --log info"); - - println!("starting server with args: {start_args}"); - - let _server = run(&start_args); - - std::thread::sleep(std::time::Duration::from_millis(5000)); - - assert!(run(&format!("isready --conn http://{addr}")).output().is_ok()); - - // Create a record - { - let args = - format!("sql --conn http://{addr} --user root --pass {pass} --ns N --db D --multi"); - assert_eq!( - run(&args).input("CREATE thing:one;\n").output(), - Ok("[{ id: thing:one }]\n\n".to_owned()), - "failed to send sql: {args}" - ); - } - - // Export to stdout - { - let args = - format!("export --conn http://{addr} --user root --pass {pass} --ns N --db D -"); - let output = run(&args).output().expect("failed to run stdout export: {args}"); - assert!(output.contains("DEFINE TABLE thing SCHEMALESS PERMISSIONS NONE;")); - assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };")); - } - - // Export to file - let exported = { - let exported = tmp_file("exported.surql"); - let args = format!( - "export --conn http://{addr} --user root --pass {pass} --ns N --db D {exported}" - ); - run(&args).output().expect("failed to run file export: {args}"); - exported - }; - - // Import the exported file - { - let args = format!( - "import --conn http://{addr} --user root --pass {pass} --ns N --db D2 {exported}" - ); - run(&args).output().expect("failed to run import: {args}"); - } - - // Query from the import (pretty-printed this time) - { - let args = format!( - "sql --conn http://{addr} --user root --pass {pass} --ns N --db D2 --pretty" - ); - assert_eq!( - run(&args).input("SELECT * FROM thing;\n").output(), - Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()), - "failed to send sql: {args}" - ); - } - - // Unfinished backup CLI - { - let file = tmp_file("backup.db"); - let args = format!("backup --user root --pass {pass} http://{addr} {file}"); - 
run(&args).output().expect("failed to run backup: {args}"); - - // TODO: Once backups are functional, update this test. - assert_eq!(fs::read_to_string(file).unwrap(), "Save"); - } - - // Multi-statement (and multi-line) query including error(s) over WS - { - let args = format!( - "sql --conn ws://{addr} --user root --pass {pass} --ns N3 --db D3 --multi --pretty" - ); - let output = run(&args) - .input( - r#"CREATE thing:success; \ - CREATE thing:fail SET bad=rand('evil'); \ - SELECT * FROM sleep(10ms) TIMEOUT 1ms; \ - CREATE thing:also_success; - "#, - ) - .output() - .unwrap(); - - assert!(output.contains("thing:success"), "missing success in {output}"); - assert!(output.contains("rgument"), "missing argument error in {output}"); - assert!( - output.contains("time") && output.contains("out"), - "missing timeout error in {output}" - ); - assert!(output.contains("thing:also_success"), "missing also_success in {output}") - } - - // Multi-statement (and multi-line) transaction including error(s) over WS - { - let args = format!( - "sql --conn ws://{addr} --user root --pass {pass} --ns N4 --db D4 --multi --pretty" - ); - let output = run(&args) - .input( - r#"BEGIN; \ - CREATE thing:success; \ - CREATE thing:fail SET bad=rand('evil'); \ - SELECT * FROM sleep(10ms) TIMEOUT 1ms; \ - CREATE thing:also_success; \ - COMMIT; - "#, - ) - .output() - .unwrap(); - - assert_eq!( - output.lines().filter(|s| s.contains("transaction")).count(), - 3, - "missing failed txn errors in {output:?}" - ); - assert!(output.contains("rgument"), "missing argument error in {output}"); - } - - // Pass neither ns nor db - { - let args = format!("sql --conn http://{addr} --user root --pass {pass}"); - let output = run(&args) - .input("USE NS N5 DB D5; CREATE thing:one;\n") - .output() - .expect("neither ns nor db"); - assert!(output.contains("thing:one"), "missing thing:one in {output}"); - } - - // Pass only ns - { - let args = format!("sql --conn http://{addr} --user root --pass {pass} --ns N5"); - let output = run(&args) - .input("USE DB D5; SELECT * FROM thing:one;\n") - .output() - .expect("only ns"); - assert!(output.contains("thing:one"), "missing thing:one in {output}"); - } - - // Pass only db and expect an error - { - let args = format!("sql --conn http://{addr} --user root --pass {pass} --db D5"); - run(&args).output().expect_err("only db"); - } - } - - #[test] - #[serial] - fn start_tls() { - let mut rng = thread_rng(); - - let port: u16 = rng.gen_range(13000..14000); - let addr = format!("127.0.0.1:{port}"); - - let pass = rng.gen::().to_string(); - - // Test the crt/key args but the keys are self signed so don't actually connect. 
- let crt_path = tmp_file("crt.crt"); - let key_path = tmp_file("key.pem"); - - let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap(); - fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap(); - fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap(); - - let start_args = format!( - "start --bind {addr} --user root --pass {pass} memory --log info --web-crt {crt_path} --web-key {key_path}" - ); - - println!("starting server with args: {start_args}"); - - let server = run(&start_args); - - std::thread::sleep(std::time::Duration::from_millis(750)); - - let output = server.kill().output().unwrap_err(); - assert!(output.contains("Started web server"), "couldn't start web server: {output}"); - } - - #[test] - #[serial] - fn validate_found_no_files() { - let temp_dir = assert_fs::TempDir::new().unwrap(); - - temp_dir.child("file.txt").touch().unwrap(); - - assert!(run_in_dir("validate", &temp_dir).output().is_err()); - } - - #[test] - #[serial] - fn validate_succeed_for_valid_surql_files() { - let temp_dir = assert_fs::TempDir::new().unwrap(); - - let statement_file = temp_dir.child("statement.surql"); - - statement_file.touch().unwrap(); - statement_file.write_str("CREATE thing:success;").unwrap(); - - assert!(run_in_dir("validate", &temp_dir).output().is_ok()); - } - - #[test] - #[serial] - fn validate_failed_due_to_invalid_glob_pattern() { - let temp_dir = assert_fs::TempDir::new().unwrap(); - - const WRONG_GLOB_PATTERN: &str = "**/*{.txt"; - - let args = format!("validate \"{}\"", WRONG_GLOB_PATTERN); - - assert!(run_in_dir(&args, &temp_dir).output().is_err()); - } - - #[test] - #[serial] - fn validate_failed_due_to_invalid_surql_files_syntax() { - let temp_dir = assert_fs::TempDir::new().unwrap(); - - let statement_file = temp_dir.child("statement.surql"); - - statement_file.touch().unwrap(); - statement_file.write_str("CREATE $thing WHERE value = '';").unwrap(); - - assert!(run_in_dir("validate", &temp_dir).output().is_err()); - } -} diff --git a/tests/cli_integration.rs b/tests/cli_integration.rs new file mode 100644 index 00000000..61f70549 --- /dev/null +++ b/tests/cli_integration.rs @@ -0,0 +1,273 @@ +// cargo test --package surreal --bin surreal --no-default-features --features storage-mem --test cli -- cli_integration --nocapture + +mod common; + +use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild}; +use serial_test::serial; +use std::fs; + +use common::{PASS, USER}; + +#[test] +#[serial] +fn version() { + assert!(common::run("version").output().is_ok()); +} + +#[test] +#[serial] +fn help() { + assert!(common::run("help").output().is_ok()); +} + +#[test] +#[serial] +fn nonexistent_subcommand() { + assert!(common::run("nonexistent").output().is_err()); +} + +#[test] +#[serial] +fn nonexistent_option() { + assert!(common::run("version --turbo").output().is_err()); +} + +#[tokio::test] +#[serial] +async fn all_commands() { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let creds = format!("--user {USER} --pass {PASS}"); + // Create a record + { + let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi"); + assert_eq!( + common::run(&args).input("CREATE thing:one;\n").output(), + Ok("[{ id: thing:one }]\n\n".to_owned()), + "failed to send sql: {args}" + ); + } + + // Export to stdout + { + let args = format!("export --conn http://{addr} {creds} --ns N --db D -"); + let output = common::run(&args).output().expect("failed to run stdout export: {args}"); + assert!(output.contains("DEFINE TABLE 
thing SCHEMALESS PERMISSIONS NONE;")); + assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };")); + } + + // Export to file + let exported = { + let exported = common::tmp_file("exported.surql"); + let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}"); + common::run(&args).output().expect("failed to run file export: {args}"); + exported + }; + + // Import the exported file + { + let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}"); + common::run(&args).output().expect("failed to run import: {args}"); + } + + // Query from the import (pretty-printed this time) + { + let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty"); + assert_eq!( + common::run(&args).input("SELECT * FROM thing;\n").output(), + Ok("[\n\t{\n\t\tid: thing:one\n\t}\n]\n\n".to_owned()), + "failed to send sql: {args}" + ); + } + + // Unfinished backup CLI + { + let file = common::tmp_file("backup.db"); + let args = format!("backup {creds} http://{addr} {file}"); + common::run(&args).output().expect("failed to run backup: {args}"); + + // TODO: Once backups are functional, update this test. + assert_eq!(fs::read_to_string(file).unwrap(), "Save"); + } + + // Multi-statement (and multi-line) query including error(s) over WS + { + let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty"); + let output = common::run(&args) + .input( + r#"CREATE thing:success; \ + CREATE thing:fail SET bad=rand('evil'); \ + SELECT * FROM sleep(10ms) TIMEOUT 1ms; \ + CREATE thing:also_success; + "#, + ) + .output() + .unwrap(); + + assert!(output.contains("thing:success"), "missing success in {output}"); + assert!(output.contains("rgument"), "missing argument error in {output}"); + assert!( + output.contains("time") && output.contains("out"), + "missing timeout error in {output}" + ); + assert!(output.contains("thing:also_success"), "missing also_success in {output}") + } + + // Multi-statement (and multi-line) transaction including error(s) over WS + { + let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty"); + let output = common::run(&args) + .input( + r#"BEGIN; \ + CREATE thing:success; \ + CREATE thing:fail SET bad=rand('evil'); \ + SELECT * FROM sleep(10ms) TIMEOUT 1ms; \ + CREATE thing:also_success; \ + COMMIT; + "#, + ) + .output() + .unwrap(); + + assert_eq!( + output.lines().filter(|s| s.contains("transaction")).count(), + 3, + "missing failed txn errors in {output:?}" + ); + assert!(output.contains("rgument"), "missing argument error in {output}"); + } + + // Pass neither ns nor db + { + let args = format!("sql --conn http://{addr} {creds}"); + let output = common::run(&args) + .input("USE NS N5 DB D5; CREATE thing:one;\n") + .output() + .expect("neither ns nor db"); + assert!(output.contains("thing:one"), "missing thing:one in {output}"); + } + + // Pass only ns + { + let args = format!("sql --conn http://{addr} {creds} --ns N5"); + let output = common::run(&args) + .input("USE DB D5; SELECT * FROM thing:one;\n") + .output() + .expect("only ns"); + assert!(output.contains("thing:one"), "missing thing:one in {output}"); + } + + // Pass only db and expect an error + { + let args = format!("sql --conn http://{addr} {creds} --db D5"); + common::run(&args).output().expect_err("only db"); + } +} + +#[tokio::test] +#[serial] +async fn start_tls() { + let (_, server) = common::start_server(true, false).await.unwrap(); + + std::thread::sleep(std::time::Duration::from_millis(2000)); + let 
output = server.kill().output().err().unwrap(); + + // Test the crt/key args but the keys are self signed so don't actually connect. + assert!(output.contains("Started web server"), "couldn't start web server: {output}"); +} + +#[tokio::test] +#[serial] +async fn with_root_auth() { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let creds = format!("--user {USER} --pass {PASS}"); + let sql_args = format!("sql --conn http://{addr} --multi --pretty"); + + // Can query /sql over HTTP + { + let args = format!("{sql_args} {creds}"); + let input = "INFO FOR ROOT;"; + let output = common::run(&args).input(input).output(); + assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap()); + } + + // Can query /sql over WS + { + let args = format!("sql --conn ws://{addr} --multi --pretty {creds}"); + let input = "INFO FOR ROOT;"; + let output = common::run(&args).input(input).output(); + assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap()); + } + + // KV user can do exports + let exported = { + let exported = common::tmp_file("exported.surql"); + let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}"); + + common::run(&args).output().unwrap_or_else(|_| panic!("failed to run export: {args}")); + exported + }; + + // KV user can do imports + { + let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}"); + common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}")); + } + + // KV user can do backups + { + let file = common::tmp_file("backup.db"); + let args = format!("backup {creds} http://{addr} {file}"); + common::run(&args).output().unwrap_or_else(|_| panic!("failed to run backup: {args}")); + + // TODO: Once backups are functional, update this test. 
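+ // For now the backup file only contains the placeholder string "Save" returned by the server.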
+ assert_eq!(fs::read_to_string(file).unwrap(), "Save"); + } +} + +#[test] +#[serial] +fn validate_found_no_files() { + let temp_dir = assert_fs::TempDir::new().unwrap(); + + temp_dir.child("file.txt").touch().unwrap(); + + assert!(common::run_in_dir("validate", &temp_dir).output().is_err()); +} + +#[test] +#[serial] +fn validate_succeed_for_valid_surql_files() { + let temp_dir = assert_fs::TempDir::new().unwrap(); + + let statement_file = temp_dir.child("statement.surql"); + + statement_file.touch().unwrap(); + statement_file.write_str("CREATE thing:success;").unwrap(); + + assert!(common::run_in_dir("validate", &temp_dir).output().is_ok()); +} + +#[test] +#[serial] +fn validate_failed_due_to_invalid_glob_pattern() { + let temp_dir = assert_fs::TempDir::new().unwrap(); + + const WRONG_GLOB_PATTERN: &str = "**/*{.txt"; + + let args = format!("validate \"{}\"", WRONG_GLOB_PATTERN); + + assert!(common::run_in_dir(&args, &temp_dir).output().is_err()); +} + +#[test] +#[serial] +fn validate_failed_due_to_invalid_surql_files_syntax() { + let temp_dir = assert_fs::TempDir::new().unwrap(); + + let statement_file = temp_dir.child("statement.surql"); + + statement_file.touch().unwrap(); + statement_file.write_str("CREATE $thing WHERE value = '';").unwrap(); + + assert!(common::run_in_dir("validate", &temp_dir).output().is_err()); +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 00000000..9f5104cb --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,140 @@ +#![allow(dead_code)] +use rand::{thread_rng, Rng}; +use std::error::Error; +use std::fs; +use std::path::Path; +use std::process::{Command, Stdio}; +use tokio::time; + +pub const USER: &str = "root"; +pub const PASS: &str = "root"; + +/// Child is a (maybe running) CLI process. It can be killed by dropping it +pub struct Child { + inner: Option, +} + +impl Child { + /// Send some thing to the child's stdin + pub fn input(mut self, input: &str) -> Self { + let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap(); + use std::io::Write; + stdin.write_all(input.as_bytes()).unwrap(); + self + } + + pub fn kill(mut self) -> Self { + self.inner.as_mut().unwrap().kill().unwrap(); + self + } + + /// Read the child's stdout concatenated with its stderr. Returns Ok if the child + /// returns successfully, Err otherwise. + pub fn output(mut self) -> Result { + let output = self.inner.take().unwrap().wait_with_output().unwrap(); + + let mut buf = String::from_utf8(output.stdout).unwrap(); + buf.push_str(&String::from_utf8(output.stderr).unwrap()); + + if output.status.success() { + Ok(buf) + } else { + Err(buf) + } + } +} + +impl Drop for Child { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_mut() { + let _ = inner.kill(); + } + } +} + +pub fn run_internal>(args: &str, current_dir: Option
) -> Child { + let mut path = std::env::current_exe().unwrap(); + assert!(path.pop()); + if path.ends_with("deps") { + assert!(path.pop()); + } + + // Note: Cargo automatically builds this binary for integration tests. + path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX)); + + let mut cmd = Command::new(path); + if let Some(dir) = current_dir { + cmd.current_dir(&dir); + } + cmd.stdin(Stdio::piped()); + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::piped()); + cmd.args(args.split_ascii_whitespace()); + Child { + inner: Some(cmd.spawn().unwrap()), + } +} + +/// Run the CLI with the given args +pub fn run(args: &str) -> Child { + run_internal::(args, None) +} + +/// Run the CLI with the given args inside a temporary directory +pub fn run_in_dir>(args: &str, current_dir: P) -> Child { + run_internal(args, Some(current_dir)) +} + +pub fn tmp_file(name: &str) -> String { + let path = Path::new(env!("OUT_DIR")).join(name); + path.to_string_lossy().into_owned() +} + +pub async fn start_server( + tls: bool, + wait_is_ready: bool, +) -> Result<(String, Child), Box> { + let mut rng = thread_rng(); + + let port: u16 = rng.gen_range(13000..14000); + let addr = format!("127.0.0.1:{port}"); + + let mut extra_args = String::default(); + if tls { + // Test the crt/key args but the keys are self signed so don't actually connect. + let crt_path = tmp_file("crt.crt"); + let key_path = tmp_file("key.pem"); + + let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap(); + fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap(); + fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap(); + + extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str()); + } + + let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}"); + + println!("starting server with args: {start_args}"); + + let server = run(&start_args); + + if !wait_is_ready { + return Ok((addr, server)); + } + + // Wait 5 seconds for the server to start + let mut interval = time::interval(time::Duration::from_millis(500)); + println!("Waiting for server to start..."); + for _i in 0..10 { + interval.tick().await; + + if run(&format!("isready --conn http://{addr}")).output().is_ok() { + println!("Server ready!"); + return Ok((addr, server)); + } + } + + let server_out = server.kill().output().err().unwrap(); + println!("server output: {server_out}"); + Err("server failed to start".into()) +} diff --git a/tests/http_integration.rs b/tests/http_integration.rs new file mode 100644 index 00000000..940834e0 --- /dev/null +++ b/tests/http_integration.rs @@ -0,0 +1,1302 @@ +mod common; + +use std::time::Duration; + +use http::{header, Method}; +use reqwest::Client; +use serde_json::json; +use serial_test::serial; + +use crate::common::{PASS, USER}; + +#[tokio::test] +#[serial] +async fn basic_auth() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/sql"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Request without credentials, gives an anonymous session + { + let res = client.post(url).body("CREATE foo").send().await?; + 
assert_eq!(res.status(), 200); + let body = res.text().await?; + assert!( + body.contains("You don't have permission to perform this query type"), + "body: {}", + body + ); + } + + // Request with invalid credentials, returns 401 + { + let res = + client.post(url).basic_auth("user", Some("pass")).body("CREATE foo").send().await?; + assert_eq!(res.status(), 401); + } + + // Request with valid root credentials, gives a ROOT session + { + let res = client.post(url).basic_auth(USER, Some(PASS)).body("CREATE foo").send().await?; + assert_eq!(res.status(), 200); + let body = res.text().await?; + assert!(body.contains(r#"[{"result":[{"id":"foo:"#), "body: {}", body); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn bearer_auth() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/sql"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Create user + { + let res = client + .post(url) + .basic_auth(USER, Some(PASS)) + .body(r#"DEFINE LOGIN user ON DB PASSWORD 'pass'"#) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + } + + // Signin with user and get the token + let token: String; + { + let req_body = serde_json::to_string( + json!({ + "ns": "N", + "db": "D", + "user": "user", + "pass": "pass", + }) + .as_object() + .unwrap(), + ) + .unwrap(); + + let res = client.post(format!("http://{addr}/signin")).body(req_body).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + token = body["token"].as_str().unwrap().to_owned(); + } + + // Request with valid token, gives a LOGIN session + { + let res = client.post(url).bearer_auth(&token).body("CREATE foo").send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert!(body.contains(r#"[{"result":[{"id":"foo:"#), "body: {}", body); + + // Check the selected namespace and database + let res = client + .post(url) + .header("NS", "N2") + .header("DB", "D2") + .bearer_auth(&token) + .body("SELECT * FROM session::ns(); SELECT * FROM session::db()") + .send() + .await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert!(body.contains(r#""result":["N"]"#), "body: {}", body); + assert!(body.contains(r#""result":["D"]"#), "body: {}", body); + } + + // Request with invalid token, returns 401 + { + let res = client.post(url).bearer_auth("token").body("CREATE foo").send().await?; + assert_eq!(res.status(), 401, "body: {}", res.text().await?); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn client_ip_extractor() -> Result<(), Box> { + // TODO: test the client IP extractor + Ok(()) +} + +#[tokio::test] +#[serial] +async fn export_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/export"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = 
reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Create some data + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body("CREATE foo") + .send() + .await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + } + + // When no auth is provided, the endpoint returns a 403 + { + let res = client.get(url).send().await?; + assert_eq!(res.status(), 403, "body: {}", res.text().await?); + } + + // When auth is provided, it returns the contents of the DB + { + let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert!(body.contains("DEFINE TABLE foo"), "body: {}", body); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn health_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/health"); + + let res = Client::default().get(url).send().await?; + assert_eq!(res.status(), 200, "response: {:#?}", res); + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn import_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/import"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // When no auth is provided, the endpoint returns a 403 + { + let res = client.post(url).body("").send().await?; + assert_eq!(res.status(), 401, "body: {}", res.text().await?); + } + + // When auth is provided, it persists the import data + { + let data = r#" + -- -------------------------------- + -- OPTION + -- ------------------------------ + + OPTION IMPORT; + + -- ------------------------------ + -- TABLE: foo + -- ------------------------------ + + DEFINE TABLE foo SCHEMALESS PERMISSIONS NONE; + + -- ------------------------------ + -- TRANSACTION + -- ------------------------------ + + BEGIN TRANSACTION; + + -- ------------------------------ + -- TABLE DATA: foo + -- ------------------------------ + + UPDATE foo:bvklxkhtxumyrfzqoc5i CONTENT { id: foo:bvklxkhtxumyrfzqoc5i }; + + -- ------------------------------ + -- TRANSACTION + -- ------------------------------ + + COMMIT TRANSACTION; + "#; + let res = client.post(url).basic_auth(USER, Some(PASS)).body(data).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + + // Check that the data was persisted + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body("SELECT * FROM foo") + .send() + .await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert!(body.contains("foo:bvklxkhtxumyrfzqoc5i"), "body: {}", body); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn rpc_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/rpc"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + 
let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Test WebSocket upgrade + { + let res = client + .get(url) + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_VERSION, "13") + .header(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==") + .send() + .await? + .upgrade() + .await; + assert!(res.is_ok(), "upgrade err: {}", res.unwrap_err()); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn signin_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/signin"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Create a user + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body(r#"DEFINE LOGIN user ON DB PASSWORD 'pass'"#) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + } + + // Signin with valid credentials and get the token + { + let req_body = serde_json::to_string( + json!({ + "ns": "N", + "db": "D", + "user": "user", + "pass": "pass", + }) + .as_object() + .unwrap(), + ) + .unwrap(); + + let res = client.post(url).body(req_body).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body); + } + + // Signin with invalid credentials returns 403 + { + let req_body = serde_json::to_string( + json!({ + "ns": "N", + "db": "D", + "user": "user", + "pass": "invalid_pass", + }) + .as_object() + .unwrap(), + ) + .unwrap(); + + let res = client.post(url).body(req_body).send().await?; + assert_eq!(res.status(), 401, "body: {}", res.text().await?); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn signup_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/signup"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Create a scope + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body( + r#" + DEFINE SCOPE scope SESSION 24h + SIGNUP ( CREATE user SET email = $email, pass = crypto::argon2::generate($pass) ) + SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) ) + ; + "#, + ) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + } + + // Signup into the scope + { + let req_body = serde_json::to_string( + json!({ + "ns": "N", + "db": "D", + "sc": "scope", + "email": "email@email.com", + "pass": "pass", + }) + .as_object() + .unwrap(), + ) + .unwrap(); + + let res = client.post(url).body(req_body).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + 
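+ // A successful scope signup returns a JSON body with a non-empty "token" field.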
+ let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn sql_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/sql"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Options method works + { + let res = client.request(Method::OPTIONS, url).send().await?; + assert_eq!(res.status(), 200); + } + + // Creating a record without credentials is not allowed + { + let res = client.post(url).body("CREATE foo").send().await?; + assert_eq!(res.status(), 200); + + let body = res.text().await?; + assert!( + body.contains("You don't have permission to perform this query type"), + "body: {}", + body + ); + } + + // Creating a record with Accept JSON encoding is allowed + { + let res = client.post(url).basic_auth(USER, Some(PASS)).body("CREATE foo").send().await?; + assert_eq!(res.status(), 200); + + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + assert_eq!(body[0]["status"], "OK", "body: {}", body); + } + + // Creating a record with Accept CBOR encoding is allowed + { + let res = client + .post(url) + .basic_auth(USER, Some(PASS)) + .header(header::ACCEPT, "application/cbor") + .body("CREATE foo") + .send() + .await?; + assert_eq!(res.status(), 200); + + let _: serde_cbor::Value = serde_cbor::from_slice(&res.bytes().await?).unwrap(); + } + + // Creating a record with Accept PACK encoding is allowed + { + let res = client + .post(url) + .basic_auth(USER, Some(PASS)) + .header(header::ACCEPT, "application/pack") + .body("CREATE foo") + .send() + .await?; + assert_eq!(res.status(), 200); + + let _: serde_cbor::Value = serde_pack::from_slice(&res.bytes().await?).unwrap(); + } + + // Creating a record with Accept Surrealdb encoding is allowed + { + let res = client + .post(url) + .basic_auth(USER, Some(PASS)) + .header(header::ACCEPT, "application/surrealdb") + .body("CREATE foo") + .send() + .await?; + assert_eq!(res.status(), 200); + + // TODO: parse the result + } + + // Creating a record with an unsupported Accept header, returns a 415 + { + let res = client + .post(url) + .basic_auth(USER, Some(PASS)) + .header(header::ACCEPT, "text/plain") + .body("CREATE foo") + .send() + .await?; + assert_eq!(res.status(), 415); + } + + // Test WebSocket upgrade + { + let res = client + .get(url) + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_VERSION, "13") + .header(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==") + .send() + .await? 
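+ // Complete the WebSocket handshake; `upgrade()` only succeeds if the server switches protocols.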
+ .upgrade() + .await; + assert!(res.is_ok(), "upgrade err: {}", res.unwrap_err()); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn sync_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/sync"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // GET + { + let res = client.get(url).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert_eq!(body, r#"Save"#, "body: {}", body); + } + // POST + { + let res = client.post(url).body("").send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + let body = res.text().await?; + assert_eq!(body, r#"Load"#, "body: {}", body); + } + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn version_endpoint() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let url = &format!("http://{addr}/version"); + + let res = Client::default().get(url).send().await?; + assert_eq!(res.status(), 200, "response: {:#?}", res); + let body = res.text().await?; + assert!(body.starts_with("surrealdb-"), "body: {}", body); + + Ok(()) +} + +/// +/// Key endpoint tests +/// + +async fn seed_table( + client: &Client, + addr: &str, + table: &str, + num_records: usize, +) -> Result<(), Box> { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body(format!("CREATE |{table}:1..{num_records}| SET default = 'content'")) + .send() + .await?; + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + + assert_eq!( + body[0]["result"].as_array().unwrap().len(), + num_records, + "error seeding the table: {}", + body + ); + + Ok(()) +} + +#[tokio::test] +#[serial] +async fn key_endpoint_select_all() -> Result<(), Box> { + let (addr, _server) = common::start_server(false, true).await.unwrap(); + let table_name = "table"; + let num_records = 50; + let url = &format!("http://{addr}/key/{table_name}"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", "N".parse()?); + headers.insert("DB", "D".parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Seed the table + seed_table(&client, &addr, table_name, num_records).await?; + + // GET all records + { + let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body); + } + + // GET records with a limit + { + let res = + client.get(format!("{}?limit=10", url)).basic_auth(USER, Some(PASS)).send().await?; + assert_eq!(res.status(), 200, "body: {}", res.text().await?); + + let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); + assert_eq!(body[0]["result"].as_array().unwrap().len(), 10, "body: {}", body); + } + + // GET records with a start + { + let res = + 
+			client.get(format!("{}?start=10", url)).basic_auth(USER, Some(PASS)).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records - 10, "body: {}", body);
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:11", "body: {}", body);
+	}
+
+	// GET records with a start and limit
+	{
+		let res = client
+			.get(format!("{}?start=10&limit=10", url))
+			.basic_auth(USER, Some(PASS))
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 10, "body: {}", body);
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:11", "body: {}", body);
+	}
+
+	// GET without authentication returns no records
+	{
+		let res = client.get(url).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	// Create record with random ID
+	{
+		let table_name = "table";
+		let url = &format!("http://{addr}/key/{table_name}");
+
+		// Verify there are no records
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+
+		// Try to create the record
+		let res = client
+			.post(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was created
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["name"],
+			"record_name",
+			"body: {}",
+			body
+		);
+	}
+
+	// POST without authentication creates no records
+	{
+		let table_name = "table_noauth";
+		let url = &format!("http://{addr}/key/{table_name}");
+
+		// Try to create the record
+		let res = client.post(url).body(r#"{"name": "record_name"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the table is empty
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let num_records = 10;
+	let url =
+		&format!("http://{addr}/key/{table_name}");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	seed_table(&client, &addr, table_name, num_records).await?;
+
+	// Update all records
+	{
+		// Try to update the records
+		let res = client
+			.put(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were updated
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+
+		// Verify the records have the new data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert_eq!(record["name"], "record_name", "body: {}", body);
+		}
+		// Verify the records don't have the original data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert!(record["default"].is_null(), "body: {}", body);
+		}
+	}
+
+	// Update all records without authentication
+	{
+		// Try to update the records
+		let res = client.put(url).body(r#"{"noauth": "yes"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were not updated
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+
+		// Verify the records don't have the new data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert!(record["noauth"].is_null(), "body: {}", body);
+		}
+		// Verify the records have the original data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert_eq!(record["name"], "record_name", "body: {}", body);
+		}
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let num_records = 10;
+	let url = &format!("http://{addr}/key/{table_name}");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	seed_table(&client, &addr, table_name, num_records).await?;
+
+	// Modify all records
+	{
+		// Try to modify the records
+		let res = client
+			.patch(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were modified
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+
+		// Verify the records have the new data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert_eq!(record["name"], "record_name", "body: {}", body);
+		}
+		// Verify the records also have the original data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert_eq!(record["default"], "content", "body: {}", body);
+		}
+	}
+
+	// Modify all records without authentication
+	{
+		// Try to modify the records
+		let res = client.patch(url).body(r#"{"noauth": "yes"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were not modified
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+
+		// Verify the records don't have the new data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert!(record["noauth"].is_null(), "body: {}", body);
+		}
+		// Verify the records have the original data
+		for record in body[0]["result"].as_array().unwrap() {
+			assert_eq!(record["name"], "record_name", "body: {}", body);
+		}
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let num_records = 10;
+	let url = &format!("http://{addr}/key/{table_name}");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	// Delete all records
+	{
+		seed_table(&client, &addr, table_name, num_records).await?;
+
+		// Verify there are records
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+
+		// Try to delete the records
+		let res = client.delete(url).basic_auth(USER, Some(PASS)).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were deleted
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+	}
+
+	// Delete all records without authentication
+	{
+		seed_table(&client, &addr, table_name, num_records).await?;
+
+		// Try to delete the records
+		let res = client.delete(url).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were not deleted
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), num_records, "body: {}", body);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let url = &format!("http://{addr}/key/{table_name}/1");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client =
+		reqwest::Client::builder()
+			.connect_timeout(Duration::from_millis(10))
+			.default_headers(headers)
+			.build()?;
+
+	// Seed the table
+	seed_table(&client, &addr, table_name, 1).await?;
+
+	// GET one record
+	{
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+	}
+
+	// GET without authentication returns no record
+	{
+		let res = client.get(url).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	// Create record with known ID
+	{
+		let url = &format!("http://{addr}/key/{table_name}/new_id");
+
+		// Try to create the record
+		let res = client
+			.post(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was created with the given ID
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["id"],
+			"table:new_id",
+			"body: {}",
+			body
+		);
+	}
+
+	// Create record with known ID and query params
+	{
+		let url = &format!(
+			"http://{addr}/key/{table_name}/new_id_query?{params}",
+			params = "age=45&elems=[1,2,3]&other={test: true}"
+		);
+
+		// Try to create the record
+		let res = client
+			.post(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{ age: $age, elems: $elems, other: $other }"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was created with the given ID
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["id"],
+			"table:new_id_query",
+			"body: {}",
+			body
+		);
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["age"], 45, "body: {}", body);
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["elems"].as_array().unwrap().len(),
+			3,
+			"body: {}",
+			body
+		);
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["other"].as_object().unwrap()["test"],
+			true,
+			"body: {}",
+			body
+		);
+	}
+
+	// POST without authentication creates no records
+	{
+		let url = &format!("http://{addr}/key/{table_name}/noauth_id");
+
+		// Try to create the record
+		let res = client.post(url).body(r#"{"name": "record_name"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the table is empty
+		let res = client.get(url).basic_auth(USER,
+			Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 0, "body: {}", body);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let url = &format!("http://{addr}/key/{table_name}/1");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	seed_table(&client, &addr, table_name, 1).await?;
+
+	// Update one record
+	{
+		// Try to update the record
+		let res = client
+			.put(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was updated
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:1", "body: {}", body);
+
+		// Verify the record has the new data
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["name"],
+			"record_name",
+			"body: {}",
+			body
+		);
+
+		// Verify the record doesn't have the original data
+		assert!(body[0]["result"].as_array().unwrap()[0]["default"].is_null(), "body: {}", body);
+	}
+
+	// Update one record without authentication
+	{
+		// Try to update the record
+		let res = client.put(url).body(r#"{"noauth": "yes"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was not updated
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:1", "body: {}", body);
+
+		// Verify the record doesn't have the new data
+		assert!(body[0]["result"].as_array().unwrap()[0]["noauth"].is_null(), "body: {}", body);
+
+		// Verify the record has the original data
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["name"],
+			"record_name",
+			"body: {}",
+			body
+		);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let url = &format!("http://{addr}/key/{table_name}/1");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	seed_table(&client, &addr, table_name, 1).await?;
+
+	// Modify one record
+	{
+		// Try to modify one record
+		let res = client
+			.patch(url)
+			.basic_auth(USER, Some(PASS))
+			.body(r#"{"name": "record_name"}"#)
+			.send()
+			.await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the records were modified
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value =
+			serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:1", "body: {}", body);
+
+		// Verify the record has the new data
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["name"],
+			"record_name",
+			"body: {}",
+			body
+		);
+
+		// Verify the record has the original data too
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["default"],
+			"content",
+			"body: {}",
+			body
+		);
+	}
+
+	// Modify one record without authentication
+	{
+		// Try to modify the record
+		let res = client.patch(url).body(r#"{"noauth": "yes"}"#).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was not modified
+		let res = client.get(url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:1", "body: {}", body);
+
+		// Verify the record doesn't have the new data
+		assert!(body[0]["result"].as_array().unwrap()[0]["noauth"].is_null(), "body: {}", body);
+
+		// Verify the record has the original data too
+		assert_eq!(
+			body[0]["result"].as_array().unwrap()[0]["default"],
+			"content",
+			"body: {}",
+			body
+		);
+	}
+
+	Ok(())
+}
+
+#[tokio::test]
+#[serial]
+async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> {
+	let (addr, _server) = common::start_server(false, true).await.unwrap();
+	let table_name = "table";
+	let base_url = &format!("http://{addr}/key/{table_name}");
+
+	// Prepare HTTP client
+	let mut headers = reqwest::header::HeaderMap::new();
+	headers.insert("NS", "N".parse()?);
+	headers.insert("DB", "D".parse()?);
+	headers.insert(header::ACCEPT, "application/json".parse()?);
+	let client = reqwest::Client::builder()
+		.connect_timeout(Duration::from_millis(10))
+		.default_headers(headers)
+		.build()?;
+
+	// Delete all records
+	{
+		seed_table(&client, &addr, table_name, 2).await?;
+
+		// Verify there are records
+		let res = client.get(base_url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 2, "body: {}", body);
+
+		// Try to delete the record
+		let res =
+			client.delete(format!("{}/1", base_url)).basic_auth(USER, Some(PASS)).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify only one record was deleted
+		let res = client.get(base_url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+		assert_eq!(body[0]["result"].as_array().unwrap()[0]["id"], "table:2", "body: {}", body);
+	}
+
+	// Delete one record without authentication
+	{
+		// Try to delete the record
+		let res = client.delete(format!("{}/2", base_url)).send().await?;
+		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
+
+		// Verify the record was not deleted
+		let res = client.get(base_url).basic_auth(USER, Some(PASS)).send().await?;
+		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
+		assert_eq!(body[0]["result"].as_array().unwrap().len(), 1, "body: {}", body);
+	}
+
+	Ok(())
+}
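
Note (not part of the patch): as a quick reference, the snippet below condenses the request pattern these integration tests exercise against the HTTP /sql endpoint: namespace and database are selected with the NS and DB headers, credentials are sent via HTTP Basic auth, and JSON is requested via Accept. It is only an illustrative sketch; the address argument and the "root"/"root" credentials are assumed placeholders, whereas the tests themselves obtain USER, PASS and the bound address from their shared common module.

    use reqwest::header;

    // Hypothetical helper for illustration only: run one SurrealQL statement
    // against a running server and check the first result's status field.
    async fn create_foo(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
        let (user, pass) = ("root", "root"); // assumed placeholder credentials
        let res = reqwest::Client::new()
            .post(format!("http://{addr}/sql"))
            .basic_auth(user, Some(pass))
            .header("NS", "N")
            .header("DB", "D")
            .header(header::ACCEPT, "application/json")
            .body("CREATE foo")
            .send()
            .await?;
        // The endpoint returns an array of per-statement results.
        let body: serde_json::Value = serde_json::from_str(&res.text().await?)?;
        assert_eq!(body[0]["status"], "OK", "body: {}", body);
        Ok(())
    }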