use axum::{ routing::{get, post}, http::StatusCode, response::IntoResponse, extract::{Json, State}, Router, }; use tower_http::{cors::{CorsLayer, Any}, compression::CompressionLayer, trace::TraceLayer}; use http::{Method, header::CONTENT_TYPE}; use serde_json::json; use serde::Deserialize; use std::sync::{Arc, RwLock}; mod api; use api::{APIClient, APIError, Format}; #[derive(Deserialize)] struct DeezerTrackList { data: Vec } #[derive(Deserialize)] #[allow(non_snake_case)] struct DeezerTrack { TRACK_TOKEN: String } async fn root() -> &'static str { "marecchione gay af" } #[derive(Debug, Deserialize)] struct RequestParams { formats: Vec, ids: Vec, } async fn get_url(State(state): State>>, Json(req): Json) -> impl IntoResponse { if req.formats.is_empty() { return (StatusCode::BAD_REQUEST, "Format list cannot be empty".to_string()); } if req.ids.is_empty() { return (StatusCode::BAD_REQUEST, "ID list cannot be empty".to_string()); } let mut client = state.read().unwrap().clone(); let old_license = client.license_token.clone(); let resp: Result = client.api_call("song.getListData", &json!({"sng_ids":req.ids,"array_default":["SNG_ID","TRACK_TOKEN"]})).await; let track_list; match resp { Ok(t) => track_list = t, Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), }; if track_list.data.is_empty() { return (StatusCode::BAD_REQUEST, "No valid IDs found".to_string()); } let track_tokens: Vec<&str> = track_list.data.iter().map(|t| t.TRACK_TOKEN.as_str()).collect(); let media_result = client.get_media(&req.formats, track_tokens).await; let media_resp; match media_result { Ok(r) => media_resp = r, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Error while getting response from media.deezer.com".to_string()), }; if client.license_token != old_license { let mut client_write = state.write().unwrap(); *client_write = client; } (StatusCode::OK, media_resp.text().await.unwrap()) } #[tokio::main] async fn main() { let bind_addr = std::env::var("BIND_ADDR").unwrap_or("[::]".to_string()); let port = std::env::var("PORT").unwrap_or("8000".to_string()); let port: u16 = port.parse().unwrap_or(8000); let shared_state = Arc::new(RwLock::new(APIClient::new())); let cors = CorsLayer::new() .allow_origin(Any) .allow_methods([Method::GET, Method::POST, Method::OPTIONS, Method::HEAD]) .allow_headers([CONTENT_TYPE]); tracing_subscriber::fmt::init(); let app = Router::new() .route("/", get(root)) .route("/get_url", post(get_url)) .with_state(shared_state) .layer(cors) .layer(CompressionLayer::new()) .layer(TraceLayer::new_for_http()); let bind_addr = format!("{bind_addr}:{port}"); println!("Listening on {bind_addr}"); let bind_addr = bind_addr.parse().unwrap(); axum::Server::bind(&bind_addr) .serve(app.into_make_service()) .await .unwrap(); }