// dzmedia/src/main.rs
// Snapshot: 2022-12-07 01:37:13 +01:00 — 108 lines, no trailing newline, 3.1 KiB, Rust

use axum::{
routing::{get, post},
http::StatusCode,
response::IntoResponse,
extract::{Json, State},
Router,
};
use tower_http::{cors::{CorsLayer, Any}, compression::CompressionLayer, trace::TraceLayer};
use http::{Method, header::CONTENT_TYPE};
use serde_json::json;
use serde::Deserialize;
use std::sync::{Arc, RwLock};
mod api;
use api::{APIClient, APIError, Format};
/// Deserialized result of the Deezer `song.getListData` call.
///
/// Only the `data` array is read; any other keys in the payload are ignored.
#[derive(Deserialize)]
struct DeezerTrackList {
    /// Track entries returned by Deezer. An empty list is treated by the
    /// `/get_url` handler as "no valid IDs".
    data: Vec<DeezerTrack>,
}
/// One track entry from `song.getListData`; only its track token is kept.
#[derive(Deserialize)]
// The field is named exactly like the upper-case JSON key `TRACK_TOKEN`, so
// serde maps it without a `#[serde(rename)]`; that name trips the lint.
#[allow(non_snake_case)]
struct DeezerTrack {
    /// Opaque token passed on to the media endpoint to request this track.
    TRACK_TOKEN: String,
}
/// Handler for `GET /`: always replies with the same fixed plain-text body.
///
/// NOTE(review): the body is a joke aimed at a person; consider replacing it
/// with a neutral message (e.g. a health-check "OK") in a follow-up.
async fn root() -> &'static str {
    const BANNER: &str = "marecchione gay af";
    BANNER
}
/// JSON request body accepted by `POST /get_url`.
#[derive(Debug, Deserialize)]
struct RequestParams {
    /// Media formats to request; the handler rejects an empty list with `400`.
    formats: Vec<Format>,
    /// Deezer song IDs to resolve; the handler rejects an empty list with `400`.
    ids: Vec<u32>,
}
/// Handler for `POST /get_url`: resolves Deezer song IDs to media URLs.
///
/// Steps:
/// 1. Reject the request with `400` if `formats` or `ids` is empty.
/// 2. Call `song.getListData` to obtain a `TRACK_TOKEN` for each requested ID.
///    If none comes back, reply `400` ("No valid IDs found").
/// 3. Ask `get_media` for the requested formats of those track tokens.
/// 4. If the client's `license_token` changed during the calls, store the
///    updated client back into the shared state so later requests reuse it.
/// 5. Relay the upstream response body unchanged with `200 OK`.
///
/// Failures of the API call, of the media request, or of reading the media
/// response body are reported as `500` with a plain-text message.
async fn get_url(
    State(state): State<Arc<RwLock<APIClient>>>,
    Json(req): Json<RequestParams>,
) -> impl IntoResponse {
    if req.formats.is_empty() {
        return (StatusCode::BAD_REQUEST, "Format list cannot be empty".to_string());
    }
    if req.ids.is_empty() {
        return (StatusCode::BAD_REQUEST, "ID list cannot be empty".to_string());
    }

    // Work on a private copy so no `std::sync` lock guard is held across the
    // `.await`s below. A poisoned lock only means another handler panicked;
    // the shared client is only ever replaced wholesale, so it is still usable.
    let mut client = state
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone();
    let old_license = client.license_token.clone();

    let resp: Result<DeezerTrackList, APIError> = client
        .api_call(
            "song.getListData",
            &json!({"sng_ids": req.ids, "array_default": ["SNG_ID", "TRACK_TOKEN"]}),
        )
        .await;
    let track_list = match resp {
        Ok(t) => t,
        Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
    };
    if track_list.data.is_empty() {
        return (StatusCode::BAD_REQUEST, "No valid IDs found".to_string());
    }

    let track_tokens: Vec<&str> = track_list
        .data
        .iter()
        .map(|t| t.TRACK_TOKEN.as_str())
        .collect();
    let media_resp = match client.get_media(&req.formats, track_tokens).await {
        Ok(r) => r,
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Error while getting response from media.deezer.com".to_string(),
            );
        }
    };

    // Persist a refreshed license token. The write guard lives only inside
    // this block, so it is released before the next `.await`.
    if client.license_token != old_license {
        let mut shared = state
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *shared = client;
    }

    // Reading the body can still fail (e.g. connection reset mid-transfer);
    // report it instead of panicking the handler task.
    match media_resp.text().await {
        Ok(body) => (StatusCode::OK, body),
        Err(_) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Error while reading response from media.deezer.com".to_string(),
        ),
    }
}
/// Resolves the listening socket address from a host string and a port.
///
/// First tries the historical `"{host}:{port}"` form. That form covers IPv4
/// (`0.0.0.0`) and bracketed IPv6 (`[::]`). It then falls back to parsing
/// `host` as a bare IP literal, so unbracketed IPv6 such as `::` or `::1` also
/// works (`"::1:8000"` is not a valid socket-address string). Returns `None`
/// when `host` is neither form.
fn parse_bind_addr(host: &str, port: u16) -> Option<std::net::SocketAddr> {
    let host = host.trim();
    if let Ok(addr) = format!("{host}:{port}").parse::<std::net::SocketAddr>() {
        return Some(addr);
    }
    host.parse::<std::net::IpAddr>()
        .ok()
        .map(|ip| std::net::SocketAddr::new(ip, port))
}

/// Entry point: configures and runs the HTTP server.
///
/// Environment variables:
/// - `BIND_ADDR`: IP address to listen on (default `[::]`). IPv6 may be
///   bracketed or bare.
/// - `PORT`: TCP port (default `8000`). An unparsable value falls back to
///   `8000` with a warning on stderr.
///
/// Routes are `GET /` and `POST /get_url`. They are wrapped in a CORS layer
/// (any origin; GET/POST/OPTIONS/HEAD; `Content-Type` header), response
/// compression, and HTTP tracing.
///
/// # Panics
/// If `BIND_ADDR` is not an IP address, or if the server fails to bind or serve.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let host = std::env::var("BIND_ADDR").unwrap_or_else(|_| "[::]".to_string());
    let port: u16 = match std::env::var("PORT") {
        Ok(raw) => raw.parse().unwrap_or_else(|_| {
            eprintln!("Ignoring invalid PORT value {raw:?}; using 8000");
            8000
        }),
        Err(_) => 8000,
    };
    let bind_addr = parse_bind_addr(&host, port)
        .unwrap_or_else(|| panic!("BIND_ADDR {host:?} is not a valid IP address"));

    let shared_state = Arc::new(RwLock::new(APIClient::new()));
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods([Method::GET, Method::POST, Method::OPTIONS, Method::HEAD])
        .allow_headers([CONTENT_TYPE]);
    let app = Router::new()
        .route("/", get(root))
        .route("/get_url", post(get_url))
        .with_state(shared_state)
        .layer(cors)
        .layer(CompressionLayer::new())
        .layer(TraceLayer::new_for_http());

    println!("Listening on {bind_addr}");
    axum::Server::bind(&bind_addr)
        .serve(app.into_make_service())
        .await
        .expect("HTTP server terminated with an error");
}