Add authentication middleware

This commit is contained in:
roaming97 2025-04-13 19:14:37 -06:00
parent 461ccaa4b9
commit bdee736864
7 changed files with 103 additions and 26 deletions

3
.env.example Executable file
View File

@ -0,0 +1,3 @@
DATABASE_URL="sqlite://app.db"
# hash for 'almond12345'
INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5"

View File

@ -1,4 +1,3 @@
host = "127.0.0.1"
port = 3000
password = "almond12345"
videos_per_page = 10

View File

@ -1,14 +1,14 @@
use std::{fs, io, path::Path};
use std::fs;
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use thiserror::Error;
use tracing::{error, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub host: String,
pub port: u16,
pub password: String,
pub videos_per_page: usize,
}
@ -17,7 +17,6 @@ impl Default for Config {
Self {
host: "0.0.0.0".into(),
port: 3000,
password: "123456".into(),
videos_per_page: 10,
}
}
@ -27,27 +26,43 @@ impl Default for Config {
#[derive(Debug, Clone)]
pub struct Instance {
pub config: Config,
pub key_hash: String,
pub pool: SqlitePool,
}
#[derive(Debug, Error)]
pub enum NewInstanceError {
#[error("Failed to open TOML configuration file")]
ConfigLoad(#[from] io::Error),
ConfigLoad(#[from] std::io::Error),
#[error("Could not parse TOML configuration")]
ConfigParse(#[from] toml::de::Error),
#[error("Failed to create connection pool")]
PoolConnect(#[from] sqlx::Error),
#[error("Almond API key hash missing from environment variables")]
MissingHashEnv,
}
impl Instance {
pub async fn new<P>(config_file: P, database_url: &str) -> Result<Self, NewInstanceError>
where
P: AsRef<Path>,
{
let config = toml::from_str(&fs::read_to_string(config_file)?)?;
let pool = SqlitePool::connect(database_url).await?;
pub async fn new() -> Result<Self, NewInstanceError> {
let default_database_url = "sqlite://app.db".into();
let database_url =
std::env::var("DATABASE_URL").map_err(|_| {
warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'");
}).unwrap_or(default_database_url);
Ok(Self { config, pool })
let config = toml::from_str(&fs::read_to_string("almond.toml")?)?;
let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| {
error!("Could not find INSTANCE_KEY_HASH environment variable!");
NewInstanceError::MissingHashEnv
})?;
let pool = SqlitePool::connect(&database_url).await?;
Ok(Self {
config,
key_hash,
pool,
})
}
}

View File

@ -3,11 +3,13 @@ use axum::{
routing::{get, post},
};
use instance::Instance;
use middleware::auth;
use routes::{list_videos, upload_video};
use tokio::signal;
use tracing::info;
mod instance;
mod middleware;
mod routes;
mod video;
@ -16,22 +18,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
dotenvy::dotenv()?;
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into());
let instance = Instance::new("almond.toml", &database_url).await?;
let instance = Instance::new().await?;
info!(
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}",
instance.config.host,
instance.config.port,
instance.config.password,
instance.config.videos_per_page
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}",
instance.config.host, instance.config.port, instance.config.videos_per_page
);
let address = format!("{}:{}", instance.config.host, instance.config.port);
let almond = Router::new()
.route("/", get(list_videos))
.route("/upload", post(upload_video))
.route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth))
.route("/", get(list_videos))
.with_state(instance);
let listener = tokio::net::TcpListener::bind(address).await?;

57
src/middleware.rs Executable file
View File

@ -0,0 +1,57 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use serde::Deserialize;
use sha3::{Digest, Sha3_256};
use crate::instance::Instance;
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct Key(pub String);
#[derive(Debug)]
pub enum KeyError {
Empty,
Invalid,
}
impl Key {
/// Checks if the API key is valid by testing it against a hash.
fn validate(&self, state: &Instance) -> Result<(), KeyError> {
let k = &self.0;
if k.is_empty() {
return Err(KeyError::Empty);
}
let hash = state.key_hash.clone();
if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash {
Ok(())
} else {
Err(KeyError::Invalid)
}
}
}
pub async fn auth(
State(state): State<Instance>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let Some(auth_header) = req
.headers()
.get("almond-api-key")
.and_then(|h| h.to_str().ok())
else {
tracing::error!("Could not find almond-api-key header");
return Err(StatusCode::UNAUTHORIZED);
};
let key = Key(auth_header.into());
match key.validate(&state) {
Ok(()) => Ok(next.run(req).await),
Err(_) => Err(StatusCode::UNAUTHORIZED),
}
}

View File

@ -38,10 +38,10 @@ pub async fn list_videos(
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
let page = query.page.unwrap_or(1).max(1);
let per_page = state.config.videos_per_page;
let total = videos.len();
let pages = (total + per_page - 1).div_ceil(per_page);
let page = query.page.unwrap_or(1).max(1).min(pages);
let start = per_page * (page - 1);
let end = (start + per_page).min(total);
@ -87,19 +87,20 @@ pub async fn upload_video(
.await
.map_err(|e| match e {
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
VideoError::AlreadyExists => StatusCode::OK,
_ => StatusCode::INTERNAL_SERVER_ERROR,
});
match new_video {
Ok(video) => {
match sqlx::query!(
r#"
"
INSERT INTO video (
id, url, youtube_id, title, description, author, author_id, author_url,
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
",
video.id,
video.url,
video.youtube_id,

View File

@ -14,6 +14,8 @@ pub enum VideoError {
UrlParse(#[from] ParseError),
#[error("URL is an invalid YouTube URL")]
InvalidUrl,
#[error("Video already exists in database")]
AlreadyExists,
#[error("IO Error: {0}")]
IOError(#[from] io::Error),
#[error("Could not serialize info JSON: {0}")]
@ -118,11 +120,12 @@ impl Video {
// ? Uploading a video doesn't mean updating it, make a PUT route for that later
if info_json.exists() {
warn!("Video already exists, skipping yt-dlp task");
} else {
Self::yt_dlp_task(url.as_str(), cookie).await?;
warn!("Video already exists, skipping");
return Err(VideoError::AlreadyExists);
}
Self::yt_dlp_task(url.as_str(), cookie).await?;
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
// info!("Info JSON for {youtube_id}\n{info}");