switched to axum, added BIND_ADDR env var

This commit is contained in:
uh wot 2022-08-16 20:36:15 +02:00
parent ae75a24129
commit 6a4bec9e70
Signed by: uhwot
GPG Key ID: CB2454984587B781
4 changed files with 288 additions and 734 deletions

927
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -6,10 +6,12 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
actix-web = "4.0" axum = "0.5"
actix-cors = "0.6" http = "0.2"
tower-http = { version = "0.3", features = ["cors", "trace", "compression-br", "compression-deflate", "compression-gzip"] }
tracing-subscriber = "0.3"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
reqwest = { version = "0.11", features = ["json", "rustls-tls", "cookies", "gzip"], default-features = false } reqwest = { version = "0.11", features = ["json", "rustls-tls", "cookies", "gzip"], default-features = false }
serde_json = "1.0" serde_json = "1.0"
thiserror = "1.0" thiserror = "1.0"
env_logger = "0.9"

View File

@ -9,7 +9,7 @@ cmd = "./dzmedia"
[env] [env]
PORT = "8080" PORT = "8080"
RUST_LOG = "actix_web=info" RUST_LOG = "tower_http=trace"
[[services]] [[services]]
internal_port = 8080 internal_port = 8080

View File

@ -1,5 +1,11 @@
use actix_web::{App, HttpResponse, HttpServer, Responder, get, post, web, middleware, http::header}; use axum::{
use actix_cors::Cors; routing::{get, post},
http::StatusCode,
response::IntoResponse,
Json, Router, Extension,
};
use tower_http::{cors::{CorsLayer, Any}, compression::CompressionLayer, trace::TraceLayer};
use http::{Method, header::CONTENT_TYPE};
use serde_json::json; use serde_json::json;
use serde::Deserialize; use serde::Deserialize;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
@ -18,8 +24,7 @@ struct DeezerTrack {
TRACK_TOKEN: String TRACK_TOKEN: String
} }
#[get("/")] async fn root() -> &'static str {
async fn index() -> &'static str {
"marecchione gay af" "marecchione gay af"
} }
@ -29,13 +34,12 @@ struct RequestParams {
ids: Vec<u32>, ids: Vec<u32>,
} }
#[post("/get_url")] async fn get_url(Json(req): Json<RequestParams>, Extension(state): Extension<Arc<RwLock<APIClient>>>) -> impl IntoResponse {
async fn get_url(req: web::Json<RequestParams>, state: web::Data<Arc<RwLock<APIClient>>>) -> impl Responder {
if req.formats.is_empty() { if req.formats.is_empty() {
return HttpResponse::BadRequest().body("Format list cannot be empty"); return (StatusCode::BAD_REQUEST, "Format list cannot be empty".to_string());
} }
if req.ids.is_empty() { if req.ids.is_empty() {
return HttpResponse::BadRequest().body("ID list cannot be empty"); return (StatusCode::BAD_REQUEST, "ID list cannot be empty".to_string());
} }
let mut client = state.read().unwrap().clone(); let mut client = state.read().unwrap().clone();
@ -45,11 +49,11 @@ async fn get_url(req: web::Json<RequestParams>, state: web::Data<Arc<RwLock<APIC
let track_list; let track_list;
match resp { match resp {
Ok(t) => track_list = t, Ok(t) => track_list = t,
Err(e) => return HttpResponse::InternalServerError().body(e.to_string()), Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
}; };
if track_list.data.is_empty() { if track_list.data.is_empty() {
return HttpResponse::BadRequest().body("No valid IDs found"); return (StatusCode::BAD_REQUEST, "No valid IDs found".to_string());
} }
let track_tokens: Vec<&str> = track_list.data.iter().map(|t| t.TRACK_TOKEN.as_str()).collect(); let track_tokens: Vec<&str> = track_list.data.iter().map(|t| t.TRACK_TOKEN.as_str()).collect();
@ -58,7 +62,7 @@ async fn get_url(req: web::Json<RequestParams>, state: web::Data<Arc<RwLock<APIC
let media_resp; let media_resp;
match media_result { match media_result {
Ok(r) => media_resp = r, Ok(r) => media_resp = r,
Err(_) => return HttpResponse::InternalServerError().body("Error while getting response from media.deezer.com"), Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Error while getting response from media.deezer.com".to_string()),
}; };
if client.license_token != old_license { if client.license_token != old_license {
@ -66,41 +70,38 @@ async fn get_url(req: web::Json<RequestParams>, state: web::Data<Arc<RwLock<APIC
*client_write = client; *client_write = client;
} }
HttpResponse::Ok() (StatusCode::OK, media_resp.text().await.unwrap())
.content_type("application/json")
.body(media_resp.text().await.unwrap())
} }
#[actix_web::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() {
let bind_addr = std::env::var("BIND_ADDR").unwrap_or("[::]".to_string());
let port = std::env::var("PORT").unwrap_or("8000".to_string()); let port = std::env::var("PORT").unwrap_or("8000".to_string());
let port: u16 = port.parse().unwrap_or(8000); let port: u16 = port.parse().unwrap_or(8000);
let bind_addr = format!("[::]:{}", port); let shared_state = Arc::new(RwLock::new(APIClient::new()));
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS, Method::HEAD])
.allow_headers([CONTENT_TYPE]);
tracing_subscriber::fmt::init();
let app = Router::new()
.route("/", get(root))
.route("/get_url", post(get_url))
.layer(Extension(shared_state))
.layer(cors)
.layer(CompressionLayer::new())
.layer(TraceLayer::new_for_http());
let bind_addr = format!("{bind_addr}:{port}");
println!("Listening on {bind_addr}"); println!("Listening on {bind_addr}");
let bind_addr = bind_addr.parse().unwrap();
let client = web::Data::new( axum::Server::bind(&bind_addr)
Arc::new(RwLock::new(APIClient::new())) .serve(app.into_make_service())
); .await
.unwrap();
env_logger::init();
HttpServer::new(move || {
let cors = Cors::default()
.allow_any_origin()
.allowed_methods(vec!["GET", "POST", "OPTIONS", "HEAD"])
.allowed_header(header::CONTENT_TYPE)
.max_age(3600);
App::new()
.wrap(middleware::Logger::default())
.wrap(cors)
.wrap(middleware::Compress::default())
.app_data(client.clone())
.service(index)
.service(get_url)
})
.bind(bind_addr)?
.run()
.await
} }