diff --git a/.env.example b/.env.example
new file mode 100755
index 0000000..ca0f400
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,3 @@
+DATABASE_URL="sqlite://app.db"
+# hash for 'almond12345'
+INSTANCE_KEY_HASH="ee64a01f59ebedeb149f6419b2d4c1510de817e06558fca2e3fcbfcdf29ae4e5"
\ No newline at end of file
diff --git a/almond.toml.example b/almond.example.toml
similarity index 67%
rename from almond.toml.example
rename to almond.example.toml
index 5e38353..b6b5226 100755
--- a/almond.toml.example
+++ b/almond.example.toml
@@ -1,4 +1,3 @@
host = "127.0.0.1"
port = 3000
-password = "almond12345"
videos_per_page = 10
diff --git a/src/instance.rs b/src/instance.rs
index 3b42e44..234fb18 100755
--- a/src/instance.rs
+++ b/src/instance.rs
@@ -1,14 +1,14 @@
-use std::{fs, io, path::Path};
+use std::fs;
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use thiserror::Error;
+use tracing::{error, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub host: String,
pub port: u16,
- pub password: String,
pub videos_per_page: usize,
}
@@ -17,7 +17,6 @@ impl Default for Config {
Self {
host: "0.0.0.0".into(),
port: 3000,
- password: "123456".into(),
videos_per_page: 10,
}
}
@@ -27,27 +26,43 @@ impl Default for Config {
#[derive(Debug, Clone)]
pub struct Instance {
pub config: Config,
+ pub key_hash: String,
pub pool: SqlitePool,
}
#[derive(Debug, Error)]
pub enum NewInstanceError {
#[error("Failed to open TOML configuration file")]
- ConfigLoad(#[from] io::Error),
+ ConfigLoad(#[from] std::io::Error),
#[error("Could not parse TOML configuration")]
ConfigParse(#[from] toml::de::Error),
#[error("Failed to create connection pool")]
PoolConnect(#[from] sqlx::Error),
+ #[error("Almond API key hash missing from environment variables")]
+ MissingHashEnv,
}
impl Instance {
- pub async fn new
(config_file: P, database_url: &str) -> Result
- where
- P: AsRef,
- {
- let config = toml::from_str(&fs::read_to_string(config_file)?)?;
- let pool = SqlitePool::connect(database_url).await?;
+ pub async fn new() -> Result {
+ let default_database_url = "sqlite://app.db".into();
+ let database_url =
+ std::env::var("DATABASE_URL").map_err(|_| {
+ warn!("Could not find DATABASE_URL environment variable! Using default '{default_database_url}'");
+ }).unwrap_or(default_database_url);
- Ok(Self { config, pool })
+ let config = toml::from_str(&fs::read_to_string("almond.toml")?)?;
+
+ let key_hash = std::env::var("INSTANCE_KEY_HASH").map_err(|_| {
+ error!("Could not find INSTANCE_KEY_HASH environment variable!");
+ NewInstanceError::MissingHashEnv
+ })?;
+
+ let pool = SqlitePool::connect(&database_url).await?;
+
+ Ok(Self {
+ config,
+ key_hash,
+ pool,
+ })
}
}
diff --git a/src/main.rs b/src/main.rs
index d5d559e..19bba3d 100755
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,11 +3,13 @@ use axum::{
routing::{get, post},
};
use instance::Instance;
+use middleware::auth;
use routes::{list_videos, upload_video};
use tokio::signal;
use tracing::info;
mod instance;
+mod middleware;
mod routes;
mod video;
@@ -16,22 +18,19 @@ async fn main() -> Result<(), Box> {
tracing_subscriber::fmt::init();
dotenvy::dotenv()?;
- let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite://app.db".into());
- let instance = Instance::new("almond.toml", &database_url).await?;
+ let instance = Instance::new().await?;
info!(
- "Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Password: {:?}\n+ Videos per page: {}",
- instance.config.host,
- instance.config.port,
- instance.config.password,
- instance.config.videos_per_page
+ "Instance configuration:\n+ Host: {}\n+ Port: {}\n+ Videos per page: {}",
+ instance.config.host, instance.config.port, instance.config.videos_per_page
);
let address = format!("{}:{}", instance.config.host, instance.config.port);
let almond = Router::new()
- .route("/", get(list_videos))
.route("/upload", post(upload_video))
+ .route_layer(axum::middleware::from_fn_with_state(instance.clone(), auth))
+ .route("/", get(list_videos))
.with_state(instance);
let listener = tokio::net::TcpListener::bind(address).await?;
diff --git a/src/middleware.rs b/src/middleware.rs
new file mode 100755
index 0000000..197967c
--- /dev/null
+++ b/src/middleware.rs
@@ -0,0 +1,57 @@
+use axum::{
+ extract::{Request, State},
+ http::StatusCode,
+ middleware::Next,
+ response::Response,
+};
+use serde::Deserialize;
+use sha3::{Digest, Sha3_256};
+
+use crate::instance::Instance;
+
+#[derive(Debug, PartialEq, Eq, Deserialize)]
+pub struct Key(pub String);
+
+#[derive(Debug)]
+pub enum KeyError {
+ Empty,
+ Invalid,
+}
+
+impl Key {
+ /// Checks if the API key is valid by testing it against a hash.
+ fn validate(&self, state: &Instance) -> Result<(), KeyError> {
+ let k = &self.0;
+ if k.is_empty() {
+ return Err(KeyError::Empty);
+ }
+
+ let hash = state.key_hash.clone();
+
+ if format!("{:x}", Sha3_256::digest(k.as_bytes())) == hash {
+ Ok(())
+ } else {
+ Err(KeyError::Invalid)
+ }
+ }
+}
+pub async fn auth(
+ State(state): State,
+ req: Request,
+ next: Next,
+) -> Result {
+ let Some(auth_header) = req
+ .headers()
+ .get("almond-api-key")
+ .and_then(|h| h.to_str().ok())
+ else {
+ tracing::error!("Could not find almond-api-key header");
+ return Err(StatusCode::UNAUTHORIZED);
+ };
+
+ let key = Key(auth_header.into());
+ match key.validate(&state) {
+ Ok(()) => Ok(next.run(req).await),
+ Err(_) => Err(StatusCode::UNAUTHORIZED),
+ }
+}
diff --git a/src/routes.rs b/src/routes.rs
index 1c54683..5865f40 100755
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -38,10 +38,10 @@ pub async fn list_videos(
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
- let page = query.page.unwrap_or(1).max(1);
let per_page = state.config.videos_per_page;
let total = videos.len();
let pages = (total + per_page - 1).div_ceil(per_page);
+ let page = query.page.unwrap_or(1).max(1).min(pages);
let start = per_page * (page - 1);
let end = (start + per_page).min(total);
@@ -87,19 +87,20 @@ pub async fn upload_video(
.await
.map_err(|e| match e {
VideoError::InvalidUrl | VideoError::UrlParse(_) => StatusCode::BAD_REQUEST,
+ VideoError::AlreadyExists => StatusCode::OK,
_ => StatusCode::INTERNAL_SERVER_ERROR,
});
match new_video {
Ok(video) => {
match sqlx::query!(
- r#"
+ "
INSERT INTO video (
id, url, youtube_id, title, description, author, author_id, author_url,
views, upload_date, likes, dislikes, file_name, file_size, sha256, thumbnail
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- "#,
+ ",
video.id,
video.url,
video.youtube_id,
diff --git a/src/video.rs b/src/video.rs
index 3596f66..b923c1e 100755
--- a/src/video.rs
+++ b/src/video.rs
@@ -14,6 +14,8 @@ pub enum VideoError {
UrlParse(#[from] ParseError),
#[error("URL is an invalid YouTube URL")]
InvalidUrl,
+ #[error("Video already exists in database")]
+ AlreadyExists,
#[error("IO Error: {0}")]
IOError(#[from] io::Error),
#[error("Could not serialize info JSON: {0}")]
@@ -118,11 +120,12 @@ impl Video {
// ? Uploading a video doesn't mean updating it, make a PUT route for that later
if info_json.exists() {
- warn!("Video already exists, skipping yt-dlp task");
- } else {
- Self::yt_dlp_task(url.as_str(), cookie).await?;
+ warn!("Video already exists, skipping");
+ return Err(VideoError::AlreadyExists);
}
+ Self::yt_dlp_task(url.as_str(), cookie).await?;
+
let info: Value = serde_json::from_str(&fs::read_to_string(info_json)?)?;
// info!("Info JSON for {youtube_id}\n{info}");