From bdee73686424bfd159bb8d20225415bc5e7ec1ff Mon Sep 17 00:00:00 2001 From: roaming97 Date: Sun, 13 Apr 2025 19:14:37 -0600 Subject: [PATCH] Add authentication middleware --- .env.example | 3 ++ almond.toml.example => almond.example.toml | 1 - src/instance.rs | 37 +++++++++----- src/main.rs | 15 +++--- src/middleware.rs | 57 ++++++++++++++++++++++ src/routes.rs | 7 +-- src/video.rs | 9 ++-- 7 files changed, 103 insertions(+), 26 deletions(-) create mode 100755 .env.example rename almond.toml.example => almond.example.toml (67%) create mode 100755 src/middleware.rs diff --git a/.env.example b/.env.example new file mode 100755 index 0000000..ca0f400 --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +DATABASE_URL="sqlite://app.db" +# hash for 'almond12345' +INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5" \ No newline at end of file diff --git a/almond.toml.example b/almond.example.toml similarity index 67% rename from almond.toml.example rename to almond.example.toml index 5e38353..b6b5226 100755 --- a/almond.toml.example +++ b/almond.example.toml @@ -1,4 +1,3 @@ host = "127.0.0.1" port = 3000 -password = "almond12345" videos_per_page = 10 diff --git a/src/instance.rs b/src/instance.rs index 3b42e44..234fb18 100755 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,14 +1,14 @@ -use std::{fs, io, path::Path}; +use std::fs; use serde::{Deserialize, Serialize}; use sqlx::SqlitePool; use thiserror::Error; +use tracing::{error, warn}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { pub host: String, pub port: u16, - pub password: String, pub videos_per_page: usize, } @@ -17,7 +17,6 @@ impl Default for Config { Self { host: "0.0.0.0".into(), port: 3000, - password: "123456".into(), videos_per_page: 10, } } @@ -27,27 +26,43 @@ impl Default for Config { #[derive(Debug, Clone)] pub struct Instance { pub config: Config, + pub key_hash: String, pub pool: SqlitePool, } #[derive(Debug, Error)] pub enum NewInstanceError { #[error("Failed to open TOML configuration file")] - ConfigLoad(#[from] io::Error), + ConfigLoad(#[from] std::io::Error), #[error("Could not parse TOML configuration")] ConfigParse(#[from] toml::de::Error), #[error("Failed to create connection pool")] PoolConnect(#[from] sqlx::Error), + #[error("Almond API key hash missing from environment variables")] + MissingHashEnv, } impl Instance { - pub async fn new

(config_file: P, database_url: &str) -> Result - where - P: AsRef, - { - let config = toml::from_str(&fs::read_to_string(config_file)?)?; - let pool = SqlitePool::connect(database_url).await?; + pub async fn new() -> Result { + let default_database_url = "sqlite://app.db".into(); + let database_url = + std::env::var("DATABASE_URL").map_err(|_| { + warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'"); + }).unwrap_or(default_database_url); - Ok(Self { config, pool }) + let config = toml::from_str(&fs::read_to_string("almond.toml")?)?; + + let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| { + error!("Could not find INSTANCE_KEY_HASH environment variable!"); + NewInstanceError::MissingHashEnv + })?; + + let pool = SqlitePool::connect(&database_url).await?; + + Ok(Self { + config, + key_hash, + pool, + }) } } diff --git a/src/main.rs b/src/main.rs index d5d559e..19bba3d 100755 --- a/src/main.rs +++ b/src/main.rs @@ -3,11 +3,13 @@ use axum::{ routing::{get, post}, }; use instance::Instance; +use middleware::auth; use routes::{list_videos, upload_video}; use tokio::signal; use tracing::info; mod instance; +mod middleware; mod routes; mod video; @@ -16,22 +18,19 @@ async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); dotenvy::dotenv()?; - let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into()); - let instance = Instance::new("almond.toml", &database_url).await?; + let instance = Instance::new().await?; info!( - "Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}", - instance.config.host, - instance.config.port, - instance.config.password, - instance.config.videos_per_page + "Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}", + instance.config.host, instance.config.port, instance.config.videos_per_page ); let address = format!("{}:{}", instance.config.host, instance.config.port); let almond = Router::new() - .route("/", get(list_videos)) .route("/upload", post(upload_video)) + .route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth)) + .route("/", get(list_videos)) .with_state(instance); let listener = tokio::net::TcpListener::bind(address).await?; diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100755 index 0000000..197967c --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,57 @@ +use axum::{ + extract::{Request, State}, + http::StatusCode, + middleware::Next, + response::Response, +}; +use serde::Deserialize; +use sha3::{Digest, Sha3_256}; + +use crate::instance::Instance; + +#[derive(Debug, PartialEq, Eq, Deserialize)] +pub struct Key(pub String); + +#[derive(Debug)] +pub enum KeyError { + Empty, + Invalid, +} + +impl Key { + /// Checks if the API key is valid by testing it against a hash. + fn validate(&self, state: &Instance) -> Result<(), KeyError> { + let k = &self.0; + if k.is_empty() { + return Err(KeyError::Empty); + } + + let hash = state.key_hash.clone(); + + if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash { + Ok(()) + } else { + Err(KeyError::Invalid) + } + } +} +pub async fn auth( + State(state): State, + req: Request, + next: Next, +) -> Result { + let Some(auth_header) = req + .headers() + .get("almond-api-key") + .and_then(|h| h.to_str().ok()) + else { + tracing::error!("Could not find almond-api-key header"); + return Err(StatusCode::UNAUTHORIZED); + }; + + let key = Key(auth_header.into()); + match key.validate(&state) { + Ok(()) => Ok(next.run(req).await), + Err(_) => Err(StatusCode::UNAUTHORIZED), + } +} diff --git a/src/routes.rs b/src/routes.rs index 1c54683..5865f40 100755 --- a/src/routes.rs +++ b/src/routes.rs @@ -38,10 +38,10 @@ pub async fn list_videos( return Err(StatusCode::INTERNAL_SERVER_ERROR); }; - let page = query.page.unwrap_or(1).max(1); let per_page = state.config.videos_per_page; let total = videos.len(); let pages = (total + per_page - 1).div_ceil(per_page); + let page = query.page.unwrap_or(1).max(1).min(pages); let start = per_page * (page - 1); let end = (start + per_page).min(total); @@ -87,19 +87,20 @@ pub async fn upload_video( .await .map_err(|e| match e { VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST, + VideoError::AlreadyExists => StatusCode::OK, _ => StatusCode::INTERNAL_SERVER_ERROR, }); match new_video { Ok(video) => { match sqlx::query!( - r#" + " INSERT INTO video ( id, url, youtube_id, title, description, author, author_id, author_url, views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - "#, + ", video.id, video.url, video.youtube_id, diff --git a/src/video.rs b/src/video.rs index 3596f66..b923c1e 100755 --- a/src/video.rs +++ b/src/video.rs @@ -14,6 +14,8 @@ pub enum VideoError { UrlParse(#[from] ParseError), #[error("URL is an invalid YouTube URL")] InvalidUrl, + #[error("Video already exists in database")] + AlreadyExists, #[error("IO Error: {0}")] IOError(#[from] io::Error), #[error("Could not serialize info JSON: {0}")] @@ -118,11 +120,12 @@ impl Video { // ? Uploading a video doesn't mean updating it, make a PUT route for that later if info_json.exists() { - warn!("Video already exists, skipping yt-dlp task"); - } else { - Self::yt_dlp_task(url.as_str(), cookie).await?; + warn!("Video already exists, skipping"); + return Err(VideoError::AlreadyExists); } + Self::yt_dlp_task(url.as_str(), cookie).await?; + let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?; // info!("Info JSON for {youtube_id}\n{info}");