Add authentication middleware
This commit is contained in:
parent
461ccaa4b9
commit
bdee736864
3
.env.example
Executable file
3
.env.example
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
DATABASE_URL="sqlite://app.db"
|
||||||
|
# hash for 'almond12345'
|
||||||
|
INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5"
|
@ -1,4 +1,3 @@
|
|||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
port = 3000
|
port = 3000
|
||||||
password = "almond12345"
|
|
||||||
videos_per_page = 10
|
videos_per_page = 10
|
@ -1,14 +1,14 @@
|
|||||||
use std::{fs, io, path::Path};
|
use std::fs;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tracing::{error, warn};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub host: String,
|
pub host: String,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub password: String,
|
|
||||||
pub videos_per_page: usize,
|
pub videos_per_page: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17,7 +17,6 @@ impl Default for Config {
|
|||||||
Self {
|
Self {
|
||||||
host: "0.0.0.0".into(),
|
host: "0.0.0.0".into(),
|
||||||
port: 3000,
|
port: 3000,
|
||||||
password: "123456".into(),
|
|
||||||
videos_per_page: 10,
|
videos_per_page: 10,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -27,27 +26,43 @@ impl Default for Config {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Instance {
|
pub struct Instance {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
|
pub key_hash: String,
|
||||||
pub pool: SqlitePool,
|
pub pool: SqlitePool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum NewInstanceError {
|
pub enum NewInstanceError {
|
||||||
#[error("Failed to open TOML configuration file")]
|
#[error("Failed to open TOML configuration file")]
|
||||||
ConfigLoad(#[from] io::Error),
|
ConfigLoad(#[from] std::io::Error),
|
||||||
#[error("Could not parse TOML configuration")]
|
#[error("Could not parse TOML configuration")]
|
||||||
ConfigParse(#[from] toml::de::Error),
|
ConfigParse(#[from] toml::de::Error),
|
||||||
#[error("Failed to create connection pool")]
|
#[error("Failed to create connection pool")]
|
||||||
PoolConnect(#[from] sqlx::Error),
|
PoolConnect(#[from] sqlx::Error),
|
||||||
|
#[error("Almond API key hash missing from environment variables")]
|
||||||
|
MissingHashEnv,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Instance {
|
impl Instance {
|
||||||
pub async fn new<P>(config_file: P, database_url: &str) -> Result<Self, NewInstanceError>
|
pub async fn new() -> Result<Self, NewInstanceError> {
|
||||||
where
|
let default_database_url = "sqlite://app.db".into();
|
||||||
P: AsRef<Path>,
|
let database_url =
|
||||||
{
|
std::env::var("DATABASE_URL").map_err(|_| {
|
||||||
let config = toml::from_str(&fs::read_to_string(config_file)?)?;
|
warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'");
|
||||||
let pool = SqlitePool::connect(database_url).await?;
|
}).unwrap_or(default_database_url);
|
||||||
|
|
||||||
Ok(Self { config, pool })
|
let config = toml::from_str(&fs::read_to_string("almond.toml")?)?;
|
||||||
|
|
||||||
|
let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| {
|
||||||
|
error!("Could not find INSTANCE_KEY_HASH environment variable!");
|
||||||
|
NewInstanceError::MissingHashEnv
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let pool = SqlitePool::connect(&database_url).await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
key_hash,
|
||||||
|
pool,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
15
src/main.rs
15
src/main.rs
@ -3,11 +3,13 @@ use axum::{
|
|||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
use instance::Instance;
|
use instance::Instance;
|
||||||
|
use middleware::auth;
|
||||||
use routes::{list_videos, upload_video};
|
use routes::{list_videos, upload_video};
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
mod instance;
|
mod instance;
|
||||||
|
mod middleware;
|
||||||
mod routes;
|
mod routes;
|
||||||
mod video;
|
mod video;
|
||||||
|
|
||||||
@ -16,22 +18,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
dotenvy::dotenv()?;
|
dotenvy::dotenv()?;
|
||||||
|
|
||||||
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into());
|
let instance = Instance::new().await?;
|
||||||
let instance = Instance::new("almond.toml", &database_url).await?;
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}",
|
"Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}",
|
||||||
instance.config.host,
|
instance.config.host, instance.config.port, instance.config.videos_per_page
|
||||||
instance.config.port,
|
|
||||||
instance.config.password,
|
|
||||||
instance.config.videos_per_page
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let address = format!("{}:{}", instance.config.host, instance.config.port);
|
let address = format!("{}:{}", instance.config.host, instance.config.port);
|
||||||
|
|
||||||
let almond = Router::new()
|
let almond = Router::new()
|
||||||
.route("/", get(list_videos))
|
|
||||||
.route("/upload", post(upload_video))
|
.route("/upload", post(upload_video))
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth))
|
||||||
|
.route("/", get(list_videos))
|
||||||
.with_state(instance);
|
.with_state(instance);
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind(address).await?;
|
let listener = tokio::net::TcpListener::bind(address).await?;
|
||||||
|
57
src/middleware.rs
Executable file
57
src/middleware.rs
Executable file
@ -0,0 +1,57 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Request, State},
|
||||||
|
http::StatusCode,
|
||||||
|
middleware::Next,
|
||||||
|
response::Response,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use sha3::{Digest, Sha3_256};
|
||||||
|
|
||||||
|
use crate::instance::Instance;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Deserialize)]
|
||||||
|
pub struct Key(pub String);
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum KeyError {
|
||||||
|
Empty,
|
||||||
|
Invalid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Key {
|
||||||
|
/// Checks if the API key is valid by testing it against a hash.
|
||||||
|
fn validate(&self, state: &Instance) -> Result<(), KeyError> {
|
||||||
|
let k = &self.0;
|
||||||
|
if k.is_empty() {
|
||||||
|
return Err(KeyError::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let hash = state.key_hash.clone();
|
||||||
|
|
||||||
|
if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(KeyError::Invalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub async fn auth(
|
||||||
|
State(state): State<Instance>,
|
||||||
|
req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
let Some(auth_header) = req
|
||||||
|
.headers()
|
||||||
|
.get("almond-api-key")
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
else {
|
||||||
|
tracing::error!("Could not find almond-api-key header");
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = Key(auth_header.into());
|
||||||
|
match key.validate(&state) {
|
||||||
|
Ok(()) => Ok(next.run(req).await),
|
||||||
|
Err(_) => Err(StatusCode::UNAUTHORIZED),
|
||||||
|
}
|
||||||
|
}
|
@ -38,10 +38,10 @@ pub async fn list_videos(
|
|||||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
};
|
};
|
||||||
|
|
||||||
let page = query.page.unwrap_or(1).max(1);
|
|
||||||
let per_page = state.config.videos_per_page;
|
let per_page = state.config.videos_per_page;
|
||||||
let total = videos.len();
|
let total = videos.len();
|
||||||
let pages = (total + per_page - 1).div_ceil(per_page);
|
let pages = (total + per_page - 1).div_ceil(per_page);
|
||||||
|
let page = query.page.unwrap_or(1).max(1).min(pages);
|
||||||
|
|
||||||
let start = per_page * (page - 1);
|
let start = per_page * (page - 1);
|
||||||
let end = (start + per_page).min(total);
|
let end = (start + per_page).min(total);
|
||||||
@ -87,19 +87,20 @@ pub async fn upload_video(
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| match e {
|
.map_err(|e| match e {
|
||||||
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
|
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
|
||||||
|
VideoError::AlreadyExists => StatusCode::OK,
|
||||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
});
|
});
|
||||||
|
|
||||||
match new_video {
|
match new_video {
|
||||||
Ok(video) => {
|
Ok(video) => {
|
||||||
match sqlx::query!(
|
match sqlx::query!(
|
||||||
r#"
|
"
|
||||||
INSERT INTO video (
|
INSERT INTO video (
|
||||||
id, url, youtube_id, title, description, author, author_id, author_url,
|
id, url, youtube_id, title, description, author, author_id, author_url,
|
||||||
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
|
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
"#,
|
",
|
||||||
video.id,
|
video.id,
|
||||||
video.url,
|
video.url,
|
||||||
video.youtube_id,
|
video.youtube_id,
|
||||||
|
@ -14,6 +14,8 @@ pub enum VideoError {
|
|||||||
UrlParse(#[from] ParseError),
|
UrlParse(#[from] ParseError),
|
||||||
#[error("URL is an invalid YouTube URL")]
|
#[error("URL is an invalid YouTube URL")]
|
||||||
InvalidUrl,
|
InvalidUrl,
|
||||||
|
#[error("Video already exists in database")]
|
||||||
|
AlreadyExists,
|
||||||
#[error("IO Error: {0}")]
|
#[error("IO Error: {0}")]
|
||||||
IOError(#[from] io::Error),
|
IOError(#[from] io::Error),
|
||||||
#[error("Could not serialize info JSON: {0}")]
|
#[error("Could not serialize info JSON: {0}")]
|
||||||
@ -118,11 +120,12 @@ impl Video {
|
|||||||
|
|
||||||
// ? Uploading a video doesn't mean updating it, make a PUT route for that later
|
// ? Uploading a video doesn't mean updating it, make a PUT route for that later
|
||||||
if info_json.exists() {
|
if info_json.exists() {
|
||||||
warn!("Video already exists, skipping yt-dlp task");
|
warn!("Video already exists, skipping");
|
||||||
} else {
|
return Err(VideoError::AlreadyExists);
|
||||||
Self::yt_dlp_task(url.as_str(), cookie).await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Self::yt_dlp_task(url.as_str(), cookie).await?;
|
||||||
|
|
||||||
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
|
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
|
||||||
|
|
||||||
// info!("Info JSON for {youtube_id}\n{info}");
|
// info!("Info JSON for {youtube_id}\n{info}");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user