Add authentication middleware

This commit is contained in:
roaming97 2025-04-13 19:14:37 -06:00
parent 461ccaa4b9
commit bdee736864
7 changed files with 103 additions and 26 deletions

3
.env.example Executable file
View File

@ -0,0 +1,3 @@
DATABASE_URL="sqlite://app.db"
# hash for 'almond12345'
INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5"

View File

@ -1,4 +1,3 @@
host = "127.0.0.1" host = "127.0.0.1"
port = 3000 port = 3000
password = "almond12345"
videos_per_page = 10 videos_per_page = 10

View File

@ -1,14 +1,14 @@
use std::{fs, io, path::Path}; use std::fs;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use thiserror::Error; use thiserror::Error;
use tracing::{error, warn};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config { pub struct Config {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
pub password: String,
pub videos_per_page: usize, pub videos_per_page: usize,
} }
@ -17,7 +17,6 @@ impl Default for Config {
Self { Self {
host: "0.0.0.0".into(), host: "0.0.0.0".into(),
port: 3000, port: 3000,
password: "123456".into(),
videos_per_page: 10, videos_per_page: 10,
} }
} }
@ -27,27 +26,43 @@ impl Default for Config {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Instance { pub struct Instance {
pub config: Config, pub config: Config,
pub key_hash: String,
pub pool: SqlitePool, pub pool: SqlitePool,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum NewInstanceError { pub enum NewInstanceError {
#[error("Failed to open TOML configuration file")] #[error("Failed to open TOML configuration file")]
ConfigLoad(#[from] io::Error), ConfigLoad(#[from] std::io::Error),
#[error("Could not parse TOML configuration")] #[error("Could not parse TOML configuration")]
ConfigParse(#[from] toml::de::Error), ConfigParse(#[from] toml::de::Error),
#[error("Failed to create connection pool")] #[error("Failed to create connection pool")]
PoolConnect(#[from] sqlx::Error), PoolConnect(#[from] sqlx::Error),
#[error("Almond API key hash missing from environment variables")]
MissingHashEnv,
} }
impl Instance { impl Instance {
pub async fn new<P>(config_file: P, database_url: &str) -> Result<Self, NewInstanceError> pub async fn new() -> Result<Self, NewInstanceError> {
where let default_database_url = "sqlite://app.db".into();
P: AsRef<Path>, let database_url =
{ std::env::var("DATABASE_URL").map_err(|_| {
let config = toml::from_str(&fs::read_to_string(config_file)?)?; warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'");
let pool = SqlitePool::connect(database_url).await?; }).unwrap_or(default_database_url);
Ok(Self { config, pool }) let config = toml::from_str(&fs::read_to_string("almond.toml")?)?;
let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| {
error!("Could not find INSTANCE_KEY_HASH environment variable!");
NewInstanceError::MissingHashEnv
})?;
let pool = SqlitePool::connect(&database_url).await?;
Ok(Self {
config,
key_hash,
pool,
})
} }
} }

View File

@ -3,11 +3,13 @@ use axum::{
routing::{get, post}, routing::{get, post},
}; };
use instance::Instance; use instance::Instance;
use middleware::auth;
use routes::{list_videos, upload_video}; use routes::{list_videos, upload_video};
use tokio::signal; use tokio::signal;
use tracing::info; use tracing::info;
mod instance; mod instance;
mod middleware;
mod routes; mod routes;
mod video; mod video;
@ -16,22 +18,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
dotenvy::dotenv()?; dotenvy::dotenv()?;
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into()); let instance = Instance::new().await?;
let instance = Instance::new("almond.toml", &database_url).await?;
info!( info!(
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}", "Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}",
instance.config.host, instance.config.host, instance.config.port, instance.config.videos_per_page
instance.config.port,
instance.config.password,
instance.config.videos_per_page
); );
let address = format!("{}:{}", instance.config.host, instance.config.port); let address = format!("{}:{}", instance.config.host, instance.config.port);
let almond = Router::new() let almond = Router::new()
.route("/", get(list_videos))
.route("/upload", post(upload_video)) .route("/upload", post(upload_video))
.route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth))
.route("/", get(list_videos))
.with_state(instance); .with_state(instance);
let listener = tokio::net::TcpListener::bind(address).await?; let listener = tokio::net::TcpListener::bind(address).await?;

57
src/middleware.rs Executable file
View File

@ -0,0 +1,57 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use serde::Deserialize;
use sha3::{Digest, Sha3_256};
use crate::instance::Instance;
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct Key(pub String);
#[derive(Debug)]
pub enum KeyError {
Empty,
Invalid,
}
impl Key {
/// Checks if the API key is valid by testing it against a hash.
fn validate(&self, state: &Instance) -> Result<(), KeyError> {
let k = &self.0;
if k.is_empty() {
return Err(KeyError::Empty);
}
let hash = state.key_hash.clone();
if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash {
Ok(())
} else {
Err(KeyError::Invalid)
}
}
}
pub async fn auth(
State(state): State<Instance>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let Some(auth_header) = req
.headers()
.get("almond-api-key")
.and_then(|h| h.to_str().ok())
else {
tracing::error!("Could not find almond-api-key header");
return Err(StatusCode::UNAUTHORIZED);
};
let key = Key(auth_header.into());
match key.validate(&state) {
Ok(()) => Ok(next.run(req).await),
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
}

View File

@ -38,10 +38,10 @@ pub async fn list_videos(
return Err(StatusCode::INTERNAL_SERVER_ERROR); return Err(StatusCode::INTERNAL_SERVER_ERROR);
}; };
let page = query.page.unwrap_or(1).max(1);
let per_page = state.config.videos_per_page; let per_page = state.config.videos_per_page;
let total = videos.len(); let total = videos.len();
let pages = (total + per_page - 1).div_ceil(per_page); let pages = (total + per_page - 1).div_ceil(per_page);
let page = query.page.unwrap_or(1).max(1).min(pages);
let start = per_page * (page - 1); let start = per_page * (page - 1);
let end = (start + per_page).min(total); let end = (start + per_page).min(total);
@ -87,19 +87,20 @@ pub async fn upload_video(
.await .await
.map_err(|e| match e { .map_err(|e| match e {
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST, VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
VideoError::AlreadyExists => StatusCode::OK,
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR,
}); });
match new_video { match new_video {
Ok(video) => { Ok(video) => {
match sqlx::query!( match sqlx::query!(
r#" "
INSERT INTO video ( INSERT INTO video (
id, url, youtube_id, title, description, author, author_id, author_url, id, url, youtube_id, title, description, author, author_id, author_url,
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#, ",
video.id, video.id,
video.url, video.url,
video.youtube_id, video.youtube_id,

View File

@ -14,6 +14,8 @@ pub enum VideoError {
UrlParse(#[from] ParseError), UrlParse(#[from] ParseError),
#[error("URL is an invalid YouTube URL")] #[error("URL is an invalid YouTube URL")]
InvalidUrl, InvalidUrl,
#[error("Video already exists in database")]
AlreadyExists,
#[error("IO Error: {0}")] #[error("IO Error: {0}")]
IOError(#[from] io::Error), IOError(#[from] io::Error),
#[error("Could not serialize info JSON: {0}")] #[error("Could not serialize info JSON: {0}")]
@ -118,11 +120,12 @@ impl Video {
// ? Uploading a video doesn't mean updating it, make a PUT route for that later // ? Uploading a video doesn't mean updating it, make a PUT route for that later
if info_json.exists() { if info_json.exists() {
warn!("Video already exists, skipping yt-dlp task"); warn!("Video already exists, skipping");
} else { return Err(VideoError::AlreadyExists);
Self::yt_dlp_task(url.as_str(), cookie).await?;
} }
Self::yt_dlp_task(url.as_str(), cookie).await?;
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?; let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
// info!("Info JSON for {youtube_id}\n{info}"); // info!("Info JSON for {youtube_id}\n{info}");