Add authentication middleware
This commit is contained in:
parent
461ccaa4b9
commit
bdee736864
3
.env.example
Executable file
3
.env.example
Executable file
@ -0,0 +1,3 @@
|
||||
DATABASE_URL="sqlite://app.db"
|
||||
# hash for 'almond12345'
|
||||
INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5"
|
@ -1,4 +1,3 @@
|
||||
host = "127.0.0.1"
|
||||
port = 3000
|
||||
password = "almond12345"
|
||||
videos_per_page = 10
|
@ -1,14 +1,14 @@
|
||||
use std::{fs, io, path::Path};
|
||||
use std::fs;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::SqlitePool;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, warn};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub password: String,
|
||||
pub videos_per_page: usize,
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@ impl Default for Config {
|
||||
Self {
|
||||
host: "0.0.0.0".into(),
|
||||
port: 3000,
|
||||
password: "123456".into(),
|
||||
videos_per_page: 10,
|
||||
}
|
||||
}
|
||||
@ -27,27 +26,43 @@ impl Default for Config {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Instance {
|
||||
pub config: Config,
|
||||
pub key_hash: String,
|
||||
pub pool: SqlitePool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NewInstanceError {
|
||||
#[error("Failed to open TOML configuration file")]
|
||||
ConfigLoad(#[from] io::Error),
|
||||
ConfigLoad(#[from] std::io::Error),
|
||||
#[error("Could not parse TOML configuration")]
|
||||
ConfigParse(#[from] toml::de::Error),
|
||||
#[error("Failed to create connection pool")]
|
||||
PoolConnect(#[from] sqlx::Error),
|
||||
#[error("Almond API key hash missing from environment variables")]
|
||||
MissingHashEnv,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub async fn new<P>(config_file: P, database_url: &str) -> Result<Self, NewInstanceError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
let config = toml::from_str(&fs::read_to_string(config_file)?)?;
|
||||
let pool = SqlitePool::connect(database_url).await?;
|
||||
pub async fn new() -> Result<Self, NewInstanceError> {
|
||||
let default_database_url = "sqlite://app.db".into();
|
||||
let database_url =
|
||||
std::env::var("DATABASE_URL").map_err(|_| {
|
||||
warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'");
|
||||
}).unwrap_or(default_database_url);
|
||||
|
||||
Ok(Self { config, pool })
|
||||
let config = toml::from_str(&fs::read_to_string("almond.toml")?)?;
|
||||
|
||||
let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| {
|
||||
error!("Could not find INSTANCE_KEY_HASH environment variable!");
|
||||
NewInstanceError::MissingHashEnv
|
||||
})?;
|
||||
|
||||
let pool = SqlitePool::connect(&database_url).await?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
key_hash,
|
||||
pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
15
src/main.rs
15
src/main.rs
@ -3,11 +3,13 @@ use axum::{
|
||||
routing::{get, post},
|
||||
};
|
||||
use instance::Instance;
|
||||
use middleware::auth;
|
||||
use routes::{list_videos, upload_video};
|
||||
use tokio::signal;
|
||||
use tracing::info;
|
||||
|
||||
mod instance;
|
||||
mod middleware;
|
||||
mod routes;
|
||||
mod video;
|
||||
|
||||
@ -16,22 +18,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tracing_subscriber::fmt::init();
|
||||
dotenvy::dotenv()?;
|
||||
|
||||
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into());
|
||||
let instance = Instance::new("almond.toml", &database_url).await?;
|
||||
let instance = Instance::new().await?;
|
||||
|
||||
info!(
|
||||
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}",
|
||||
instance.config.host,
|
||||
instance.config.port,
|
||||
instance.config.password,
|
||||
instance.config.videos_per_page
|
||||
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}",
|
||||
instance.config.host, instance.config.port, instance.config.videos_per_page
|
||||
);
|
||||
|
||||
let address = format!("{}:{}", instance.config.host, instance.config.port);
|
||||
|
||||
let almond = Router::new()
|
||||
.route("/", get(list_videos))
|
||||
.route("/upload", post(upload_video))
|
||||
.route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth))
|
||||
.route("/", get(list_videos))
|
||||
.with_state(instance);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(address).await?;
|
||||
|
57
src/middleware.rs
Executable file
57
src/middleware.rs
Executable file
@ -0,0 +1,57 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sha3::{Digest, Sha3_256};
|
||||
|
||||
use crate::instance::Instance;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize)]
|
||||
pub struct Key(pub String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum KeyError {
|
||||
Empty,
|
||||
Invalid,
|
||||
}
|
||||
|
||||
impl Key {
|
||||
/// Checks if the API key is valid by testing it against a hash.
|
||||
fn validate(&self, state: &Instance) -> Result<(), KeyError> {
|
||||
let k = &self.0;
|
||||
if k.is_empty() {
|
||||
return Err(KeyError::Empty);
|
||||
}
|
||||
|
||||
let hash = state.key_hash.clone();
|
||||
|
||||
if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(KeyError::Invalid)
|
||||
}
|
||||
}
|
||||
}
|
||||
pub async fn auth(
|
||||
State(state): State<Instance>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let Some(auth_header) = req
|
||||
.headers()
|
||||
.get("almond-api-key")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
else {
|
||||
tracing::error!("Could not find almond-api-key header");
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
let key = Key(auth_header.into());
|
||||
match key.validate(&state) {
|
||||
Ok(()) => Ok(next.run(req).await),
|
||||
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
}
|
@ -38,10 +38,10 @@ pub async fn list_videos(
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
let per_page = state.config.videos_per_page;
|
||||
let total = videos.len();
|
||||
let pages = (total + per_page - 1).div_ceil(per_page);
|
||||
let page = query.page.unwrap_or(1).max(1).min(pages);
|
||||
|
||||
let start = per_page * (page - 1);
|
||||
let end = (start + per_page).min(total);
|
||||
@ -87,19 +87,20 @@ pub async fn upload_video(
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
|
||||
VideoError::AlreadyExists => StatusCode::OK,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
});
|
||||
|
||||
match new_video {
|
||||
Ok(video) => {
|
||||
match sqlx::query!(
|
||||
r#"
|
||||
"
|
||||
INSERT INTO video (
|
||||
id, url, youtube_id, title, description, author, author_id, author_url,
|
||||
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
",
|
||||
video.id,
|
||||
video.url,
|
||||
video.youtube_id,
|
||||
|
@ -14,6 +14,8 @@ pub enum VideoError {
|
||||
UrlParse(#[from] ParseError),
|
||||
#[error("URL is an invalid YouTube URL")]
|
||||
InvalidUrl,
|
||||
#[error("Video already exists in database")]
|
||||
AlreadyExists,
|
||||
#[error("IO Error: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("Could not serialize info JSON: {0}")]
|
||||
@ -118,11 +120,12 @@ impl Video {
|
||||
|
||||
// ? Uploading a video doesn't mean updating it, make a PUT route for that later
|
||||
if info_json.exists() {
|
||||
warn!("Video already exists, skipping yt-dlp task");
|
||||
} else {
|
||||
Self::yt_dlp_task(url.as_str(), cookie).await?;
|
||||
warn!("Video already exists, skipping");
|
||||
return Err(VideoError::AlreadyExists);
|
||||
}
|
||||
|
||||
Self::yt_dlp_task(url.as_str(), cookie).await?;
|
||||
|
||||
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
|
||||
|
||||
// info!("Info JSON for {youtube_id}\n{info}");
|
||||
|
Loading…
x
Reference in New Issue
Block a user