summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEric Hodel <drbrain@segment7.net>2023-12-27 06:15:48 -0800
committerGitHub <noreply@github.com>2023-12-27 14:15:48 +0000
commitd52e57612942cbe0c6a0dd774fcc2caac8f439d5 (patch)
tree6abc226ffa71156b0ac747529e7effaa21c75c15
parent86f50e0356e4b661be43c2aeba97a67d83910095 (diff)
feat: Add TLS to atuin-server (#1457)
* Add TLS to atuin-server atuin as a project already includes most of the dependencies necessary for server-side TLS. This allows `atuin server start` to use a TLS certificate when self-hosting in order to avoid the complication of wrapping it in a TLS-aware proxy server. Configuration is handled similar to the metrics server with its own struct and currently accepts only the private key and certificate file paths. Starting a TLS server and a TCP server are divergent because the tests need to bind to an arbitrary port to avoid collisions across tests. The API to accomplish this for a TLS server is much more verbose. * Fix clippy, fmt * Add TLS section to self-hosting
-rw-r--r--Cargo.lock33
-rw-r--r--atuin-server/Cargo.toml5
-rw-r--r--atuin-server/server.toml5
-rw-r--r--atuin-server/src/lib.rs75
-rw-r--r--atuin-server/src/settings.rs54
-rw-r--r--atuin/src/command/server.rs5
-rw-r--r--atuin/tests/sync.rs5
-rw-r--r--docs/docs/self-hosting/self-hosting.md11
8 files changed, 175 insertions, 18 deletions
diff --git a/Cargo.lock b/Cargo.lock
index f84e25c1..440f3024 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -110,6 +110,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
+name = "arc-swap"
+version = "1.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
+
+[[package]]
name = "argon2"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -269,15 +275,20 @@ dependencies = [
"atuin-common",
"atuin-server-database",
"axum",
+ "axum-server",
"base64 0.21.5",
"config",
"eyre",
"fs-err",
"http",
+ "hyper",
+ "hyper-rustls",
"metrics",
"metrics-exporter-prometheus",
"rand",
"reqwest",
+ "rustls",
+ "rustls-pemfile",
"semver",
"serde",
"serde_json",
@@ -373,6 +384,26 @@ dependencies = [
]
[[package]]
+name = "axum-server"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
+dependencies = [
+ "arc-swap",
+ "bytes",
+ "futures-util",
+ "http",
+ "http-body",
+ "hyper",
+ "pin-project-lite",
+ "rustls",
+ "rustls-pemfile",
+ "tokio",
+ "tokio-rustls",
+ "tower-service",
+]
+
+[[package]]
name = "backtrace"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1445,7 +1476,9 @@ dependencies = [
"futures-util",
"http",
"hyper",
+ "log",
"rustls",
+ "rustls-native-certs",
"tokio",
"tokio-rustls",
]
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml
index 445dfcb7..ecfef524 100644
--- a/atuin-server/Cargo.toml
+++ b/atuin-server/Cargo.toml
@@ -26,11 +26,16 @@ rand = { workspace = true }
tokio = { workspace = true }
async-trait = { workspace = true }
axum = "0.6.4"
+axum-server = { version = "0.5.1", features = ["tls-rustls"] }
http = "0.2"
+hyper = "0.14"
+hyper-rustls = "0.24"
fs-err = { workspace = true }
tower = "0.4"
tower-http = { version = "0.4", features = ["trace"] }
reqwest = { workspace = true }
+rustls = "0.21"
+rustls-pemfile = "1.0"
argon2 = "0.5.0"
semver = { workspace = true }
metrics-exporter-prometheus = "0.12.1"
diff --git a/atuin-server/server.toml b/atuin-server/server.toml
index b2468ddb..946769c9 100644
--- a/atuin-server/server.toml
+++ b/atuin-server/server.toml
@@ -27,3 +27,8 @@
# enable = false
# host = 127.0.0.1
# port = 9001
+
+# [tls]
+# enable = false
+# cert_path = ""
+# pkey_path = ""
diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs
index 2d2a9c78..b505a8ec 100644
--- a/atuin-server/src/lib.rs
+++ b/atuin-server/src/lib.rs
@@ -1,10 +1,13 @@
#![forbid(unsafe_code)]
+use std::net::SocketAddr;
+use std::sync::Arc;
use std::{future::Future, net::TcpListener};
use atuin_server_database::Database;
use axum::Router;
use axum::Server;
+use axum_server::Handle;
use eyre::{Context, Result};
mod handlers;
@@ -12,6 +15,7 @@ mod metrics;
mod router;
mod utils;
+use rustls::ServerConfig;
pub use settings::example_config;
pub use settings::Settings;
@@ -44,27 +48,26 @@ async fn shutdown_signal() {
pub async fn launch<Db: Database>(
settings: Settings<Db::Settings>,
- host: &str,
- port: u16,
+ addr: SocketAddr,
) -> Result<()> {
- launch_with_listener::<Db>(
- settings,
- TcpListener::bind((host, port)).context("could not connect to socket")?,
- shutdown_signal(),
- )
- .await
+ if settings.tls.enable {
+ launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
+ } else {
+ launch_with_tcp_listener::<Db>(
+ settings,
+ TcpListener::bind(addr).context("could not connect to socket")?,
+ shutdown_signal(),
+ )
+ .await
+ }
}
-pub async fn launch_with_listener<Db: Database>(
+pub async fn launch_with_tcp_listener<Db: Database>(
settings: Settings<Db::Settings>,
listener: TcpListener,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
- let db = Db::new(&settings.db_settings)
- .await
- .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
-
- let r = router::router(db, settings);
+ let r = make_router::<Db>(settings).await?;
Server::from_tcp(listener)
.context("could not launch server")?
@@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>(
Ok(())
}
+async fn launch_with_tls<Db: Database>(
+ settings: Settings<Db::Settings>,
+ addr: SocketAddr,
+ shutdown: impl Future<Output = ()>,
+) -> Result<()> {
+ let certificates = settings.tls.certificates()?;
+ let pkey = settings.tls.private_key()?;
+
+ let server_config = ServerConfig::builder()
+ .with_safe_defaults()
+ .with_no_client_auth()
+ .with_single_cert(certificates, pkey)?;
+
+ let server_config = Arc::new(server_config);
+ let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);
+
+ let r = make_router::<Db>(settings).await?;
+
+ let handle = Handle::new();
+
+ let server = axum_server::bind_rustls(addr, rustls_config)
+ .handle(handle.clone())
+ .serve(r.into_make_service());
+
+ tokio::select! {
+ _ = server => {}
+ _ = shutdown => {
+ handle.graceful_shutdown(None);
+ }
+ }
+
+ Ok(())
+}
+
// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
// the public.
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
@@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
Ok(())
}
+
+async fn make_router<Db: Database>(
+ settings: Settings<<Db as Database>::Settings>,
+) -> Result<Router, eyre::Error> {
+ let db = Db::new(&settings.db_settings)
+ .await
+ .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
+ let r = router::router(db, settings);
+ Ok(r)
+}
diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs
index d6f1867c..70008fbc 100644
--- a/atuin-server/src/settings.rs
+++ b/atuin-server/src/settings.rs
@@ -1,7 +1,7 @@
use std::{io::prelude::*, path::PathBuf};
use config::{Config, Environment, File as ConfigFile, FileFormat};
-use eyre::{eyre, Result};
+use eyre::{bail, eyre, Context, Result};
use fs_err::{create_dir_all, File};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -36,6 +36,7 @@ pub struct Settings<DbSettings> {
pub register_webhook_url: Option<String>,
pub register_webhook_username: String,
pub metrics: Metrics,
+ pub tls: Tls,
#[serde(flatten)]
pub db_settings: DbSettings,
@@ -67,6 +68,9 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
.set_default("metrics.enable", false)?
.set_default("metrics.host", "127.0.0.1")?
.set_default("metrics.port", 9001)?
+ .set_default("tls.enable", false)?
+ .set_default("tls.cert_path", "")?
+ .set_default("tls.key_path", "")?
.add_source(
Environment::with_prefix("atuin")
.prefix_separator("_")
@@ -97,3 +101,51 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
pub fn example_config() -> &'static str {
EXAMPLE_CONFIG
}
+
+#[derive(Clone, Debug, Default, Deserialize, Serialize)]
+pub struct Tls {
+ pub enable: bool,
+ pub cert_path: PathBuf,
+ pub pkey_path: PathBuf,
+}
+
+impl Tls {
+ pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> {
+ let cert_file = std::fs::File::open(&self.cert_path)
+ .with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?;
+ let mut reader = std::io::BufReader::new(cert_file);
+ let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
+ .with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?
+ .into_iter()
+ .map(rustls::Certificate)
+ .collect();
+
+ if certs.is_empty() {
+ bail!(
+ "tls.cert_path {:?} must have at least one certificate",
+ self.cert_path
+ );
+ }
+
+ Ok(certs)
+ }
+
+ pub fn private_key(&self) -> Result<rustls::PrivateKey> {
+ let pkey_file = std::fs::File::open(&self.pkey_path)
+ .with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?;
+ let mut reader = std::io::BufReader::new(pkey_file);
+ let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
+ .with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?;
+
+ if keys.is_empty() {
+ bail!(
+ "tls.pkey_path {:?} must have at least one private key",
+ self.pkey_path
+ );
+ }
+
+ let key = rustls::PrivateKey(keys[0].clone());
+
+ Ok(key)
+ }
+}
diff --git a/atuin/src/command/server.rs b/atuin/src/command/server.rs
index 4bcf19db..d45d6ef8 100644
--- a/atuin/src/command/server.rs
+++ b/atuin/src/command/server.rs
@@ -1,3 +1,5 @@
+use std::net::SocketAddr;
+
use atuin_server_postgres::Postgres;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@@ -39,6 +41,7 @@ impl Cmd {
let settings = Settings::new().wrap_err("could not load server settings")?;
let host = host.as_ref().unwrap_or(&settings.host).clone();
let port = port.unwrap_or(settings.port);
+ let addr = SocketAddr::new(host.parse()?, port);
if settings.metrics.enable {
tokio::spawn(launch_metrics_server(
@@ -47,7 +50,7 @@ impl Cmd {
));
}
- launch::<Postgres>(settings, &host, port).await
+ launch::<Postgres>(settings, addr).await
}
Self::DefaultConfig => {
println!("{}", example_config());
diff --git a/atuin/tests/sync.rs b/atuin/tests/sync.rs
index 765b9cb8..8c42b171 100644
--- a/atuin/tests/sync.rs
+++ b/atuin/tests/sync.rs
@@ -2,7 +2,7 @@ use std::{env, net::TcpListener, time::Duration};
use atuin_client::api_client;
use atuin_common::{api::AddHistoryRequest, utils::uuid_v7};
-use atuin_server::{launch_with_listener, Settings as ServerSettings};
+use atuin_server::{launch_with_tcp_listener, Settings as ServerSettings};
use atuin_server_postgres::{Postgres, PostgresSettings};
use futures_util::TryFutureExt;
use time::OffsetDateTime;
@@ -38,6 +38,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
register_webhook_username: String::new(),
db_settings: PostgresSettings { db_uri },
metrics: atuin_server::settings::Metrics::default(),
+ tls: atuin_server::settings::Tls::default(),
};
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
@@ -46,7 +47,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
let server = tokio::spawn(async move {
let _tracing_guard = dispatcher::set_default(&dispatch);
- if let Err(e) = launch_with_listener::<Postgres>(
+ if let Err(e) = launch_with_tcp_listener::<Postgres>(
server_settings,
listener,
shutdown_rx.unwrap_or_else(|_| ()),
diff --git a/docs/docs/self-hosting/self-hosting.md b/docs/docs/self-hosting/self-hosting.md
index 8379f43f..621b00f3 100644
--- a/docs/docs/self-hosting/self-hosting.md
+++ b/docs/docs/self-hosting/self-hosting.md
@@ -39,3 +39,14 @@ ATUIN_DB_URI="postgres://user:password@hostname/database"
| `db_uri` | A valid PostgreSQL URI, for saving history (default: false) |
| `path` | A path to prepend to all routes of the server (default: false) |
+### TLS
+
+The server supports TLS through the `[tls]` section:
+
+```toml
+[tls]
+enabled = true
+cert_path = "/path/to/letsencrypt/live/fully.qualified.domain/fullchain.pem"
+pkey_path = "/path/to/letsencrypt/live/fully.qualified.domain/privkey.pem"
+```
+