diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b97b0a3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+/target
+.aider*
+.env
diff --git a/r2client/Cargo.toml b/r2client/Cargo.toml
new file mode 100644
index 0000000..6c11923
--- /dev/null
+++ b/r2client/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "r2client"
+version = "0.2.0"
+edition = "2024"
+
+[lib]
+
+[dependencies]
+reqwest = "0.12.19"
+xmltree = "0.11.0"
+thiserror = "2"
+http = "1.3.1"
+aws_sigv4 = { path = "../aws_sigv4/" }
+log = "0.4.28"
+
+[dev-dependencies]
+tokio = { version = "1", features = ["full", "macros", "rt-multi-thread"] }
+dotenv = "0.15"
+
+[features]
+async = []
+default = ["async"]
+sync = ["reqwest/blocking"]
diff --git a/r2client/src/_async.rs b/r2client/src/_async.rs
new file mode 100644
index 0000000..57daa66
--- /dev/null
+++ b/r2client/src/_async.rs
@@ -0,0 +1,4 @@
+mod r2bucket;
+mod r2client;
+pub use r2bucket::R2Bucket;
+pub use r2client::R2Client;
diff --git a/r2client/src/_async/r2bucket.rs b/r2client/src/_async/r2bucket.rs
new file mode 100644
index 0000000..692a4fe
--- /dev/null
+++ b/r2client/src/_async/r2bucket.rs
@@ -0,0 +1,62 @@
+use crate::_async::R2Client;
+use crate::R2Error;
+use std::collections::HashMap;
+
+#[derive(Debug)]
+pub struct R2Bucket {
+    bucket: String,
+    pub client: R2Client,
+}
+
+impl R2Bucket {
+    pub fn new(bucket: String) -> Self {
+        Self {
+            bucket,
+            client: R2Client::new(),
+        }
+    }
+
+    pub fn from_client(bucket: String, client: R2Client) -> Self {
+        Self { bucket, client }
+    }
+
+    pub fn from_credentials(
+        bucket: String,
+        access_key: String,
+        secret_key: String,
+        endpoint: String,
+    ) -> Self {
+        let client = R2Client::from_credentials(access_key, secret_key, endpoint);
+        Self { bucket, client }
+    }
+
+    pub async fn upload_file(
+        &self,
+        local_file_path: &str,
+        r2_file_key: &str,
+    ) -> Result<(), R2Error> {
+        self.client
+            // Pass None so R2Client derives the content type from local_file_path
+            .upload_file(&self.bucket, local_file_path, r2_file_key, None)
+            .await
+    }
+
+    pub async fn download_file(&self, r2_file_key: &str, local_path: &str) -> Result<(), R2Error> {
+        self.client
+            .download_file(&self.bucket, r2_file_key, local_path, None)
+            .await
+    }
+
+    pub async fn list_files(&self) -> Result<HashMap<String, Vec<String>>, R2Error> {
+        self.client.list_files(&self.bucket).await
+    }
+
+    pub async fn list_folders(&self) -> Result<Vec<String>, R2Error> {
+        self.client.list_folders(&self.bucket).await
+    }
+
+    pub async fn delete_file(&self, r2_file_key: &str) -> Result<(), R2Error> {
+        self.client.delete(&self.bucket, r2_file_key).await
+    }
+}
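Reviewer note: a minimal caller-side sketch of the async surface added above, assuming the caller depends on tokio for the runtime; the bucket name, key, and paths are placeholders, and credentials come from `R2_ACCESS_KEY`, `R2_SECRET_KEY`, and `R2_ENDPOINT`:

```rust
use r2client::{R2Bucket, R2Error};

#[tokio::main]
async fn main() -> Result<(), R2Error> {
    // R2Bucket::new builds its R2Client from the three R2_* environment variables.
    let bucket = R2Bucket::new("my-test-bucket".to_string());
    bucket.upload_file("./local.txt", "docs/remote.txt").await?;
    bucket.download_file("docs/remote.txt", "./copy.txt").await?;
    bucket.delete_file("docs/remote.txt").await?;
    Ok(())
}
```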
diff --git a/r2client/src/_async/r2client.rs b/r2client/src/_async/r2client.rs
new file mode 100644
index 0000000..4d2c109
--- /dev/null
+++ b/r2client/src/_async/r2client.rs
@@ -0,0 +1,304 @@
+use crate::R2Error;
+use crate::mimetypes::get_mimetype_from_fp;
+use aws_sigv4::SigV4Credentials;
+use http::Method;
+use log::trace;
+use reqwest::header::HeaderMap;
+use std::collections::HashMap;
+use std::str::FromStr;
+
+#[derive(Debug)]
+pub struct R2Client {
+    sigv4: SigV4Credentials,
+    endpoint: String,
+}
+impl R2Client {
+    fn get_env() -> Result<(String, String, String), R2Error> {
+        let get = |key: &str| std::env::var(key).map_err(|_| R2Error::Env(key.to_owned()));
+        Ok((
+            get("R2_ACCESS_KEY")?,
+            get("R2_SECRET_KEY")?,
+            get("R2_ENDPOINT")?,
+        ))
+    }
+
+    /// Builds a client from the `R2_ACCESS_KEY`, `R2_SECRET_KEY`, and
+    /// `R2_ENDPOINT` environment variables.
+    ///
+    /// Panics if any of them is unset; use [`R2Client::from_credentials`] to
+    /// handle missing configuration without panicking.
+    pub fn new() -> Self {
+        let (access_key, secret_key, endpoint) = Self::get_env().unwrap();
+
+        Self {
+            sigv4: SigV4Credentials::new("s3", "auto", access_key, secret_key),
+            endpoint,
+        }
+    }
+
+    pub fn from_credentials(access_key: String, secret_key: String, endpoint: String) -> Self {
+        Self {
+            sigv4: SigV4Credentials::new("s3", "auto", access_key, secret_key),
+            endpoint,
+        }
+    }
+
+    fn create_headers(
+        &self,
+        method: http::Method,
+        bucket: &str,
+        key: Option<&str>,
+        payload: impl AsRef<[u8]>,
+        content_type: Option<&str>,
+        extra_headers: Option<Vec<(String, String)>>,
+    ) -> Result<HeaderMap, R2Error> {
+        let uri = http::Uri::from_str(&self.build_url(bucket, key))
+            .expect("invalid URI (make sure build_url works as intended)");
+        let mut headers = extra_headers.unwrap_or_default();
+        headers.push((
+            "host".to_string(),
+            uri.host().expect("URI should have a host").to_owned(),
+        ));
+        if let Some(content_type) = content_type {
+            headers.push(("content-type".to_string(), content_type.to_owned()))
+        }
+
+        let (_, header_map) = self.sigv4.signature(method, uri, headers, payload);
+        Ok(header_map)
+    }
+
+    pub async fn upload_file(
+        &self,
+        bucket: &str,
+        local_file_path: &str,
+        r2_file_key: &str,
+        content_type: Option<&str>,
+    ) -> Result<(), R2Error> {
+        // Payload (file data)
+        let payload = std::fs::read(local_file_path)?;
+        trace!(
+            "[upload_file] Payload hash for signing: {}",
+            aws_sigv4::hash(&payload)
+        );
+
+        // Set HTTP headers, deriving the content type from the file path when
+        // the caller didn't supply one.
+        let content_type =
+            Some(content_type.unwrap_or_else(|| get_mimetype_from_fp(local_file_path)));
+        let headers = self.create_headers(
+            Method::PUT,
+            bucket,
+            Some(r2_file_key),
+            &payload,
+            content_type,
+            None,
+        )?;
+        trace!("[upload_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(r2_file_key));
+        let client = reqwest::Client::new();
+        let resp = client
+            .put(&file_url)
+            .headers(headers)
+            .body(payload)
+            .send()
+            .await?;
+        let status = resp.status();
+        let text = resp.text().await?;
+        if status.is_success() {
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!(
+                    "uploading file {local_file_path} to bucket \"{bucket}\" under file key \"{r2_file_key}\""
+                ),
+                status,
+                text,
+            ))
+        }
+    }
+    pub async fn download_file(
+        &self,
+        bucket: &str,
+        key: &str,
+        local_path: &str,
+        extra_headers: Option<Vec<(String, String)>>,
+    ) -> Result<(), R2Error> {
+        // The AWS docs say S3 accepts the literal string UNSIGNED-PAYLOAD as the
+        // x-amz-content-sha256 value:
+        // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#:~:text=For%20Amazon%20S3%2C%20include%20the%20literal%20string%20UNSIGNED%2DPAYLOAD%20when%20constructing%20a%20canonical%20request%2C%20and%20set%20the%20same%20value%20as%20the%20x%2Damz%2Dcontent%2Dsha256%20header%20value%20when%20sending%20the%20request.
+        // Public implementations using it are hard to find, so we sign an empty
+        // payload instead, which works for these body-less requests.
+        let payload = "";
+        trace!("[download_file] Payload for signing: (empty)");
+        let headers =
+            self.create_headers(Method::GET, bucket, Some(key), payload, None, extra_headers)?;
+        trace!("[download_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(key));
+        let client = reqwest::Client::new();
+        let resp = client.get(&file_url).headers(headers).send().await?;
+        let status = resp.status();
+        if status.is_success() {
+            std::fs::write(local_path, resp.bytes().await?)?;
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!("downloading file \"{key}\" from bucket \"{bucket}\""),
+                status,
+                resp.text().await?,
+            ))
+        }
+    }
+    pub async fn delete(&self, bucket: &str, remote_key: &str) -> Result<(), R2Error> {
+        let payload = "";
+        trace!("[delete_file] Payload for signing: (empty)");
+        let headers = self.create_headers(
+            Method::DELETE,
+            bucket,
+            Some(remote_key),
+            payload,
+            None,
+            None,
+        )?;
+        trace!("[delete_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(remote_key));
+        let client = reqwest::Client::new();
+        let resp = client.delete(&file_url).headers(headers).send().await?;
+        let status = resp.status();
+        if status.is_success() {
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!("deleting file \"{remote_key}\" from bucket \"{bucket}\""),
+                status,
+                resp.text().await?,
+            ))
+        }
+    }
+    async fn get_bucket_listing(&self, bucket: &str) -> Result<String, R2Error> {
+        let payload = "";
+        trace!("[get_bucket_listing] Payload for signing: (empty)");
+        let headers = self.create_headers(Method::GET, bucket, None, payload, None, None)?;
+        trace!("[get_bucket_listing] Headers sent to request: {headers:#?}");
+        let url = self.build_url(bucket, None);
+        let client = reqwest::Client::new();
+        let resp = client.get(&url).headers(headers).send().await?;
+        let status = resp.status();
+        if status.is_success() {
+            Ok(resp.text().await?)
+        } else {
+            Err(R2Error::FailedRequest(
+                String::from("listing bucket contents"),
+                status,
+                resp.text().await?,
+            ))
+        }
+    }
+
+    pub async fn list_files(&self, bucket: &str) -> Result<HashMap<String, Vec<String>>, R2Error> {
+        let xml = self.get_bucket_listing(bucket).await?;
+        let mut files_dict: HashMap<String, Vec<String>> = HashMap::new();
+        let root = xmltree::Element::parse(xml.as_bytes())?;
+        for content in root
+            .children
+            .iter()
+            .filter_map(|c| c.as_element())
+            .filter(|e| e.name == "Contents")
+        {
+            let key_elem = content.get_child("Key").and_then(|k| k.get_text());
+            if let Some(file_key) = key_elem {
+                // Split "folder/file" keys at the last '/'; keys without one land
+                // under the "" (bucket root) entry.
+                let (folder, file_name): (String, String) = if let Some(idx) = file_key.rfind('/') {
+                    (file_key[..idx].to_string(), file_key[idx + 1..].to_string())
+                } else {
+                    ("".to_string(), file_key.to_string())
+                };
+                files_dict.entry(folder).or_default().push(file_name);
+            }
+        }
+        Ok(files_dict)
+    }
+
+    pub async fn list_folders(&self, bucket: &str) -> Result<Vec<String>, R2Error> {
+        let xml = self.get_bucket_listing(bucket).await?;
+        let mut folders = std::collections::HashSet::new();
+        let root = xmltree::Element::parse(xml.as_bytes())?;
+        for content in root
+            .children
+            .iter()
+            .filter_map(|c| c.as_element())
+            .filter(|e| e.name == "Contents")
+        {
+            let key_elem = content.get_child("Key").and_then(|k| k.get_text());
+            if let Some(file_key) = key_elem
+                && let Some(idx) = file_key.find('/')
+            {
+                folders.insert(file_key[..idx].to_string());
+            }
+        }
+        Ok(folders.into_iter().collect())
+    }
+
+    fn build_url(&self, bucket: &str, key: Option<&str>) -> String {
+        match key {
+            Some(k) => {
+                let encoded_key = aws_sigv4::url_encode(k);
+                format!("{}/{}/{}", self.endpoint, bucket, encoded_key)
+            }
+            None => format!("{}/{}/", self.endpoint, bucket),
+        }
+    }
+}
+impl Default for R2Client {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn r2client_from_env() -> R2Client {
+        unsafe {
+            std::env::set_var("R2_ACCESS_KEY", "AKIAEXAMPLE");
+            std::env::set_var("R2_SECRET_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY");
+            std::env::set_var("R2_ENDPOINT", "https://example.r2.cloudflarestorage.com");
+        }
+        R2Client::new()
+    }
+
+    #[test]
+    fn r2client_env() {
+        let r2client = r2client_from_env();
+
+        // The access/secret keys live inside `sigv4` and aren't public yet, so
+        // only the endpoint can be asserted here.
+        // assert_eq!(r2client.access_key, "AKIAEXAMPLE");
+        // assert_eq!(
+        //     r2client.secret_key,
+        //     "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+        // );
+        assert_eq!(
+            r2client.endpoint,
+            "https://example.r2.cloudflarestorage.com"
+        );
+    }
+
+    #[test]
+    fn test_create_headers() {
+        let client = R2Client::from_credentials(
+            "AKIAEXAMPLE".to_string(),
+            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
+            "https://example.r2.cloudflarestorage.com".to_string(),
+        );
+        let headers = client
+            .create_headers(
+                Method::PUT,
+                "bucket",
+                Some("key"),
+                "deadbeef",
+                Some("application/octet-stream"),
+                None,
+            )
+            .unwrap();
+        assert!(headers.contains_key("x-amz-date"));
+        assert!(headers.contains_key("authorization"));
+        assert!(headers.contains_key("content-type"));
+        assert!(headers.contains_key("host"));
+    }
+}
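The `R2Bucket` wrapper always passes `None` for `extra_headers`, but `R2Client::download_file` signs any extra headers into the request. A sketch (not part of this diff) of using that for a ranged GET; the bucket, key, and local path are placeholders, and it assumes the endpoint honors `Range` on signed requests (a 206 response still passes the `is_success` check):

```rust
use r2client::{R2Client, R2Error};

// Hypothetical helper: fetch only the first kilobyte of an object.
async fn head_of_object(client: &R2Client) -> Result<(), R2Error> {
    client
        .download_file(
            "my-bucket",
            "logs/big-file.log",
            "./first-kb.log",
            Some(vec![("range".to_string(), "bytes=0-1023".to_string())]),
        )
        .await
}
```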
diff --git a/r2client/src/error.rs b/r2client/src/error.rs
new file mode 100644
index 0000000..9438385
--- /dev/null
+++ b/r2client/src/error.rs
@@ -0,0 +1,15 @@
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum R2Error {
+    #[error("I/O error: {0}")]
+    Io(#[from] std::io::Error),
+    #[error("HTTP error: {0}")]
+    Http(#[from] reqwest::Error),
+    #[error("XML parse error: {0}")]
+    Xml(#[from] xmltree::ParseError),
+    #[error("Missing environment variable: {0}")]
+    Env(String),
+    #[error("Request failed during operation {0}: {1}\n{2}")]
+    FailedRequest(String, http::StatusCode, String),
+}
diff --git a/r2client/src/lib.rs b/r2client/src/lib.rs
new file mode 100644
index 0000000..e2c4068
--- /dev/null
+++ b/r2client/src/lib.rs
@@ -0,0 +1,10 @@
+mod error;
+mod mimetypes;
+pub use error::R2Error;
+
+#[cfg(feature = "async")]
+mod _async;
+#[cfg(feature = "async")]
+pub use _async::{R2Bucket, R2Client};
+
+#[cfg(feature = "sync")]
+pub mod sync;
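Since `FailedRequest` exposes the status code, callers can branch on it; a sketch of treating a 404 download as "not found" instead of a hard error (the helper is illustrative and assumes the caller also depends on the `http` crate for `StatusCode`):

```rust
use r2client::{R2Bucket, R2Error};

// Hypothetical helper: Ok(None) on 404, Ok(Some(())) on success.
async fn try_download(bucket: &R2Bucket, key: &str, dest: &str) -> Result<Option<()>, R2Error> {
    match bucket.download_file(key, dest).await {
        Ok(()) => Ok(Some(())),
        Err(R2Error::FailedRequest(_op, status, _body))
            if status == http::StatusCode::NOT_FOUND =>
        {
            Ok(None)
        }
        Err(e) => Err(e),
    }
}
```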
diff --git a/r2client/src/mimetypes.rs b/r2client/src/mimetypes.rs
new file mode 100644
index 0000000..c3ae6c4
--- /dev/null
+++ b/r2client/src/mimetypes.rs
@@ -0,0 +1,112 @@
+pub fn get_mimetype(key: &str) -> &'static str {
+    match key {
+        // Image formats
+        ".png" => "image/png",
+        ".jpg" | ".jpeg" => "image/jpeg",
+        ".gif" => "image/gif",
+        ".svg" => "image/svg+xml",
+        ".ico" => "image/x-icon",
+        ".webp" => "image/webp",
+
+        // Audio formats
+        ".m4a" => "audio/x-m4a",
+        ".mp3" => "audio/mpeg",
+        ".wav" => "audio/wav",
+        ".ogg" => "audio/ogg",
+
+        // Video formats
+        ".mp4" => "video/mp4",
+        ".avi" => "video/x-msvideo",
+        ".mov" => "video/quicktime",
+        ".flv" => "video/x-flv",
+        ".wmv" => "video/x-ms-wmv",
+        ".webm" => "video/webm",
+
+        // Document formats
+        ".pdf" => "application/pdf",
+        ".doc" => "application/msword",
+        ".docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+        ".ppt" => "application/vnd.ms-powerpoint",
+        ".pptx" => "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+        ".xls" => "application/vnd.ms-excel",
+        ".xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+        ".txt" => "text/plain",
+
+        // Web formats
+        ".html" => "text/html",
+        ".css" => "text/css",
+        ".js" => "application/javascript",
+        ".json" => "application/json",
+        ".xml" => "application/xml",
+
+        // Other formats
+        ".csv" => "text/csv",
+        ".zip" => "application/zip",
+        ".tar" => "application/x-tar",
+        ".gz" => "application/gzip",
+        ".rar" => "application/vnd.rar",
+        ".7z" => "application/x-7z-compressed",
+        ".eps" => "application/postscript",
+        ".sql" => "application/sql",
+        ".java" => "text/x-java-source",
+        _ => "application/octet-stream",
+    }
+}
+
+pub fn get_mimetype_from_fp(file_path: &str) -> &str {
+    // Split the file path on '.' in reverse order so the first element is the
+    // extension (e.g. "~/.config/test.jpeg" yields "jpeg"), then prepend the
+    // '.' again because get_mimetype matches on ".jpeg"-style keys.
+    // TODO: the match statement could key on the bare extension instead, which
+    // would make this reformatting step unnecessary; since it's only used
+    // inside this crate, it is left as is for now.
+    get_mimetype(&format!(
+        ".{}",
+        file_path
+            .rsplit('.')
+            .next()
+            // rsplit always yields at least one item, so this fallback (which
+            // maps to application/octet-stream) should be unreachable.
+            .unwrap_or("")
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn match_mime_test() {
+        assert_eq!(get_mimetype(".tar"), "application/x-tar");
+    }
+
+    #[test]
+    fn default_mime_test() {
+        assert_eq!(get_mimetype(".bf"), "application/octet-stream");
+    }
+
+    #[test]
+    fn mime_from_file() {
+        assert_eq!(get_mimetype_from_fp("test.ico"), "image/x-icon");
+    }
+
+    #[test]
+    fn mime_from_file_path() {
+        assert_eq!(
+            get_mimetype_from_fp("/home/testuser/Documents/test.pdf"),
+            "application/pdf"
+        );
+        assert_eq!(
+            get_mimetype_from_fp("./bucket_test/bucket_test_upload.txt"),
+            "text/plain"
+        )
+    }
+
+    #[test]
+    fn no_ext() {
+        assert_eq!(
+            get_mimetype_from_fp("edge_case_lmao"),
+            "application/octet-stream"
+        )
+    }
+}
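A possible follow-up to the TODO above: `std::path` can do the extension splitting, which also makes the no-extension case explicit. A sketch that keeps the `'.'`-prefixed match keys and reuses `get_mimetype` from this file (the function name is hypothetical, not part of the diff):

```rust
use std::path::Path;

fn get_mimetype_from_path(file_path: &str) -> &'static str {
    Path::new(file_path)
        .extension()
        .and_then(|ext| ext.to_str())
        // Re-add the '.' because get_mimetype matches ".ext"-style keys.
        .map(|ext| get_mimetype(&format!(".{ext}")))
        .unwrap_or("application/octet-stream")
}
```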
diff --git a/r2client/src/sync.rs b/r2client/src/sync.rs
new file mode 100644
index 0000000..57daa66
--- /dev/null
+++ b/r2client/src/sync.rs
@@ -0,0 +1,4 @@
+mod r2bucket;
+mod r2client;
+pub use r2bucket::R2Bucket;
+pub use r2client::R2Client;
diff --git a/r2client/src/sync/r2bucket.rs b/r2client/src/sync/r2bucket.rs
new file mode 100644
index 0000000..e9fec51
--- /dev/null
+++ b/r2client/src/sync/r2bucket.rs
@@ -0,0 +1,54 @@
+use crate::sync::R2Client;
+use crate::R2Error;
+use std::collections::HashMap;
+
+#[derive(Debug)]
+pub struct R2Bucket {
+    bucket: String,
+    pub client: R2Client,
+}
+
+impl R2Bucket {
+    pub fn new(bucket: String) -> Self {
+        Self {
+            bucket,
+            client: R2Client::new(),
+        }
+    }
+
+    pub fn from_client(bucket: String, client: R2Client) -> Self {
+        Self { bucket, client }
+    }
+
+    pub fn from_credentials(
+        bucket: String,
+        access_key: String,
+        secret_key: String,
+        endpoint: String,
+    ) -> Self {
+        let client = R2Client::from_credentials(access_key, secret_key, endpoint);
+        Self { bucket, client }
+    }
+
+    pub fn upload_file(&self, local_file_path: &str, r2_file_key: &str) -> Result<(), R2Error> {
+        self.client
+            // Pass None so R2Client derives the content type from local_file_path
+            .upload_file(&self.bucket, local_file_path, r2_file_key, None)
+    }
+
+    pub fn download_file(&self, r2_file_key: &str, local_path: &str) -> Result<(), R2Error> {
+        self.client
+            .download_file(&self.bucket, r2_file_key, local_path, None)
+    }
+
+    pub fn list_files(&self) -> Result<HashMap<String, Vec<String>>, R2Error> {
+        self.client.list_files(&self.bucket)
+    }
+
+    pub fn list_folders(&self) -> Result<Vec<String>, R2Error> {
+        self.client.list_folders(&self.bucket)
+    }
+
+    pub fn delete_file(&self, r2_file_key: &str) -> Result<(), R2Error> {
+        self.client.delete(&self.bucket, r2_file_key)
+    }
+}
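The blocking wrapper mirrors the async surface one-to-one; the same round trip as the earlier async sketch, compiled with `--no-default-features --features sync` (bucket, key, and paths are placeholders again):

```rust
use r2client::R2Error;
use r2client::sync::R2Bucket;

fn main() -> Result<(), R2Error> {
    // Same three R2_* environment variables as the async client.
    let bucket = R2Bucket::new("my-test-bucket".to_string());
    bucket.upload_file("./local.txt", "docs/remote.txt")?;
    bucket.download_file("docs/remote.txt", "./copy.txt")?;
    bucket.delete_file("docs/remote.txt")?;
    Ok(())
}
```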
diff --git a/r2client/src/sync/r2client.rs b/r2client/src/sync/r2client.rs
new file mode 100644
index 0000000..0176fd7
--- /dev/null
+++ b/r2client/src/sync/r2client.rs
@@ -0,0 +1,302 @@
+use crate::R2Error;
+use crate::mimetypes::get_mimetype_from_fp;
+use aws_sigv4::SigV4Credentials;
+use http::Method;
+use log::trace;
+use reqwest::header::HeaderMap;
+use std::collections::HashMap;
+use std::str::FromStr;
+
+#[derive(Debug)]
+pub struct R2Client {
+    sigv4: SigV4Credentials,
+    endpoint: String,
+}
+impl R2Client {
+    fn get_env() -> Result<(String, String, String), R2Error> {
+        let get = |key: &str| std::env::var(key).map_err(|_| R2Error::Env(key.to_owned()));
+        Ok((
+            get("R2_ACCESS_KEY")?,
+            get("R2_SECRET_KEY")?,
+            get("R2_ENDPOINT")?,
+        ))
+    }
+
+    /// Builds a client from the `R2_ACCESS_KEY`, `R2_SECRET_KEY`, and
+    /// `R2_ENDPOINT` environment variables.
+    ///
+    /// Panics if any of them is unset; use [`R2Client::from_credentials`] to
+    /// handle missing configuration without panicking.
+    pub fn new() -> Self {
+        let (access_key, secret_key, endpoint) = Self::get_env().unwrap();
+
+        Self {
+            sigv4: SigV4Credentials::new("s3", "auto", access_key, secret_key),
+            endpoint,
+        }
+    }
+
+    pub fn from_credentials(access_key: String, secret_key: String, endpoint: String) -> Self {
+        Self {
+            sigv4: SigV4Credentials::new("s3", "auto", access_key, secret_key),
+            endpoint,
+        }
+    }
+
+    fn create_headers(
+        &self,
+        method: http::Method,
+        bucket: &str,
+        key: Option<&str>,
+        payload: impl AsRef<[u8]>,
+        content_type: Option<&str>,
+        extra_headers: Option<Vec<(String, String)>>,
+    ) -> Result<HeaderMap, R2Error> {
+        let uri = http::Uri::from_str(&self.build_url(bucket, key))
+            .expect("invalid URI (make sure build_url works as intended)");
+        let mut headers = extra_headers.unwrap_or_default();
+        headers.push((
+            "host".to_string(),
+            uri.host().expect("URI should have a host").to_owned(),
+        ));
+        if let Some(content_type) = content_type {
+            headers.push(("content-type".to_string(), content_type.to_owned()))
+        }
+
+        let (_, header_map) = self.sigv4.signature(method, uri, headers, payload);
+        Ok(header_map)
+    }
+
+    pub fn upload_file(
+        &self,
+        bucket: &str,
+        local_file_path: &str,
+        r2_file_key: &str,
+        content_type: Option<&str>,
+    ) -> Result<(), R2Error> {
+        // Payload (file data)
+        let payload = std::fs::read(local_file_path)?;
+        trace!(
+            "[upload_file] Payload hash for signing: {}",
+            aws_sigv4::hash(&payload)
+        );
+
+        // Set HTTP headers, deriving the content type from the file path when
+        // the caller didn't supply one.
+        let content_type =
+            Some(content_type.unwrap_or_else(|| get_mimetype_from_fp(local_file_path)));
+        let headers = self.create_headers(
+            Method::PUT,
+            bucket,
+            Some(r2_file_key),
+            &payload,
+            content_type,
+            None,
+        )?;
+        trace!("[upload_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(r2_file_key));
+        let client = reqwest::blocking::Client::new();
+        let resp = client
+            .put(&file_url)
+            .headers(headers)
+            .body(payload)
+            .send()?;
+        let status = resp.status();
+        let text = resp.text()?;
+        if status.is_success() {
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!(
+                    "uploading file {local_file_path} to bucket \"{bucket}\" under file key \"{r2_file_key}\""
+                ),
+                status,
+                text,
+            ))
+        }
+    }
+    pub fn download_file(
+        &self,
+        bucket: &str,
+        key: &str,
+        local_path: &str,
+        extra_headers: Option<Vec<(String, String)>>,
+    ) -> Result<(), R2Error> {
+        // The AWS docs say S3 accepts the literal string UNSIGNED-PAYLOAD as the
+        // x-amz-content-sha256 value:
+        // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#:~:text=For%20Amazon%20S3%2C%20include%20the%20literal%20string%20UNSIGNED%2DPAYLOAD%20when%20constructing%20a%20canonical%20request%2C%20and%20set%20the%20same%20value%20as%20the%20x%2Damz%2Dcontent%2Dsha256%20header%20value%20when%20sending%20the%20request.
+        // Public implementations using it are hard to find, so we sign an empty
+        // payload instead, which works for these body-less requests.
+        let payload = "";
+        trace!("[download_file] Payload for signing: (empty)");
+        let headers =
+            self.create_headers(Method::GET, bucket, Some(key), payload, None, extra_headers)?;
+        trace!("[download_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(key));
+        let client = reqwest::blocking::Client::new();
+        let resp = client.get(&file_url).headers(headers).send()?;
+        let status = resp.status();
+        if status.is_success() {
+            std::fs::write(local_path, resp.bytes()?)?;
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!("downloading file \"{key}\" from bucket \"{bucket}\""),
+                status,
+                resp.text()?,
+            ))
+        }
+    }
+    pub fn delete(&self, bucket: &str, remote_key: &str) -> Result<(), R2Error> {
+        let payload = "";
+        trace!("[delete_file] Payload for signing: (empty)");
+        let headers = self.create_headers(
+            Method::DELETE,
+            bucket,
+            Some(remote_key),
+            payload,
+            None,
+            None,
+        )?;
+        trace!("[delete_file] Headers sent to request: {headers:#?}");
+        let file_url = self.build_url(bucket, Some(remote_key));
+        let client = reqwest::blocking::Client::new();
+        let resp = client.delete(&file_url).headers(headers).send()?;
+        let status = resp.status();
+        if status.is_success() {
+            Ok(())
+        } else {
+            Err(R2Error::FailedRequest(
+                format!("deleting file \"{remote_key}\" from bucket \"{bucket}\""),
+                status,
+                resp.text()?,
+            ))
+        }
+    }
+    fn get_bucket_listing(&self, bucket: &str) -> Result<String, R2Error> {
+        let payload = "";
+        trace!("[get_bucket_listing] Payload for signing: (empty)");
+        let headers = self.create_headers(Method::GET, bucket, None, payload, None, None)?;
+        trace!("[get_bucket_listing] Headers sent to request: {headers:#?}");
+        let url = self.build_url(bucket, None);
+        let client = reqwest::blocking::Client::new();
+        let resp = client.get(&url).headers(headers).send()?;
+        let status = resp.status();
+        if status.is_success() {
+            Ok(resp.text()?)
+        } else {
+            Err(R2Error::FailedRequest(
+                String::from("listing bucket contents"),
+                status,
+                resp.text()?,
+            ))
+        }
+    }
+
+    pub fn list_files(&self, bucket: &str) -> Result<HashMap<String, Vec<String>>, R2Error> {
+        let xml = self.get_bucket_listing(bucket)?;
+        let mut files_dict: HashMap<String, Vec<String>> = HashMap::new();
+        let root = xmltree::Element::parse(xml.as_bytes())?;
+        for content in root
+            .children
+            .iter()
+            .filter_map(|c| c.as_element())
+            .filter(|e| e.name == "Contents")
+        {
+            let key_elem = content.get_child("Key").and_then(|k| k.get_text());
+            if let Some(file_key) = key_elem {
+                // Split "folder/file" keys at the last '/'; keys without one land
+                // under the "" (bucket root) entry.
+                let (folder, file_name): (String, String) = if let Some(idx) = file_key.rfind('/') {
+                    (file_key[..idx].to_string(), file_key[idx + 1..].to_string())
+                } else {
+                    ("".to_string(), file_key.to_string())
+                };
+                files_dict.entry(folder).or_default().push(file_name);
+            }
+        }
+        Ok(files_dict)
+    }
+
+    pub fn list_folders(&self, bucket: &str) -> Result<Vec<String>, R2Error> {
+        let xml = self.get_bucket_listing(bucket)?;
+        let mut folders = std::collections::HashSet::new();
+        let root = xmltree::Element::parse(xml.as_bytes())?;
+        for content in root
+            .children
+            .iter()
+            .filter_map(|c| c.as_element())
+            .filter(|e| e.name == "Contents")
+        {
+            let key_elem = content.get_child("Key").and_then(|k| k.get_text());
+            if let Some(file_key) = key_elem
+                && let Some(idx) = file_key.find('/')
+            {
+                folders.insert(file_key[..idx].to_string());
+            }
+        }
+        Ok(folders.into_iter().collect())
+    }
+
+    fn build_url(&self, bucket: &str, key: Option<&str>) -> String {
+        match key {
+            Some(k) => {
+                let encoded_key = aws_sigv4::url_encode(k);
+                format!("{}/{}/{}", self.endpoint, bucket, encoded_key)
+            }
+            None => format!("{}/{}/", self.endpoint, bucket),
+        }
+    }
+}
+impl Default for R2Client {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn r2client_from_env() -> R2Client {
+        unsafe {
+            std::env::set_var("R2_ACCESS_KEY", "AKIAEXAMPLE");
+            std::env::set_var("R2_SECRET_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY");
+            std::env::set_var("R2_ENDPOINT", "https://example.r2.cloudflarestorage.com");
+        }
+        R2Client::new()
+    }
+
+    #[test]
+    fn r2client_env() {
+        let r2client = r2client_from_env();
+
+        // The access/secret keys live inside `sigv4` and aren't public yet, so
+        // only the endpoint can be asserted here.
+        // assert_eq!(r2client.access_key, "AKIAEXAMPLE");
+        // assert_eq!(
+        //     r2client.secret_key,
+        //     "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+        // );
+        assert_eq!(
+            r2client.endpoint,
+            "https://example.r2.cloudflarestorage.com"
+        );
+    }
+
+    #[test]
+    fn test_create_headers() {
+        let client = R2Client::from_credentials(
+            "AKIAEXAMPLE".to_string(),
+            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
+            "https://example.r2.cloudflarestorage.com".to_string(),
+        );
+        let headers = client
+            .create_headers(
+                Method::PUT,
+                "bucket",
+                Some("key"),
+                "deadbeef",
+                Some("application/octet-stream"),
+                None,
+            )
+            .unwrap();
+        assert!(headers.contains_key("x-amz-date"));
+        assert!(headers.contains_key("authorization"));
+        assert!(headers.contains_key("content-type"));
+        assert!(headers.contains_key("host"));
+    }
+}
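One observation that applies to both implementations: every operation constructs a fresh `reqwest::Client` (or `reqwest::blocking::Client`), so connections are not pooled across calls. A suggested refactor sketch, not what this diff does, would hold one client per `R2Client` (the field name is illustrative):

```rust
use aws_sigv4::SigV4Credentials;

#[derive(Debug)]
pub struct R2Client {
    sigv4: SigV4Credentials,
    endpoint: String,
    // reqwest::Client wraps an Arc-backed connection pool internally, so
    // storing and cloning one is cheap; each method would then call
    // self.http.put(...) etc. instead of reqwest::Client::new().
    http: reqwest::Client,
}
```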
diff --git a/r2client/tests/r2_tests.rs b/r2client/tests/r2_tests.rs
new file mode 100644
index 0000000..f5dcf00
--- /dev/null
+++ b/r2client/tests/r2_tests.rs
@@ -0,0 +1,137 @@
+use std::fs;
+use std::io::Write;
+
+fn create_test_file(path: &str, content: &str) {
+    let mut file = fs::File::create(path).unwrap();
+    file.write_all(content.as_bytes()).unwrap();
+}
+
+#[cfg(feature = "sync")]
+mod sync_tests {
+    use super::create_test_file;
+    use r2client::sync::R2Bucket;
+    use std::env;
+    use std::fs;
+
+    fn setup_bucket() -> R2Bucket {
+        dotenv::dotenv().ok();
+        let bucket = env::var("R2_BUCKET").expect("R2_BUCKET not set for integration tests");
+        let access_key = env::var("R2_ACCESS_KEY").expect("R2_ACCESS_KEY not set");
+        let secret_key = env::var("R2_SECRET_KEY").expect("R2_SECRET_KEY not set");
+        let endpoint = env::var("R2_ENDPOINT").expect("R2_ENDPOINT not set");
+        R2Bucket::from_credentials(bucket, access_key, secret_key, endpoint)
+    }
+
+    #[test]
+    fn test_sync_e2e() {
+        let bucket = setup_bucket();
+        let test_content = "Hello, R2 sync world!";
+        let local_upload_path = "test_upload_sync.txt";
+        let r2_file_key = "test/test_upload_sync.txt";
+        let local_download_path = "test_download_sync.txt";
+
+        create_test_file(local_upload_path, test_content);
+
+        // 1. Upload file
+        bucket
+            .upload_file(local_upload_path, r2_file_key)
+            .expect("Sync upload failed");
+
+        // 2. List files and check that it exists
+        let files = bucket.list_files().expect("Sync list_files failed");
+        assert!(
+            files
+                .get("test")
+                .unwrap()
+                .contains(&"test_upload_sync.txt".to_string())
+        );
+
+        // 3. List folders and check that it exists
+        let folders = bucket.list_folders().expect("Sync list_folders failed");
+        assert!(folders.contains(&"test".to_string()));
+
+        // 4. Download file
+        bucket
+            .download_file(r2_file_key, local_download_path)
+            .expect("Sync download failed");
+
+        // 5. Verify content
+        let downloaded_content = fs::read_to_string(local_download_path).unwrap();
+        assert_eq!(test_content, downloaded_content);
+
+        // Cleanup
+        fs::remove_file(local_upload_path).unwrap();
+        fs::remove_file(local_download_path).unwrap();
+    }
+}
+
+#[cfg(feature = "async")]
+mod async_tests {
+    use super::create_test_file;
+    use r2client::R2Bucket;
+    use std::env;
+    use std::fs;
+
+    fn setup_bucket() -> R2Bucket {
+        dotenv::dotenv().ok();
+        let bucket = env::var("R2_BUCKET").expect("R2_BUCKET not set for integration tests");
+        let access_key = env::var("R2_ACCESS_KEY").expect("R2_ACCESS_KEY not set");
+        let secret_key = env::var("R2_SECRET_KEY").expect("R2_SECRET_KEY not set");
+        let endpoint = env::var("R2_ENDPOINT").expect("R2_ENDPOINT not set");
+        R2Bucket::from_credentials(bucket, access_key, secret_key, endpoint)
+    }
+
+    #[tokio::test]
+    async fn test_async_e2e() {
+        let bucket = setup_bucket();
+        let test_content = "Hello, R2 async world!";
+        let local_upload_path = "test_upload_async.txt";
+        let r2_file_key = "test/test_upload_async.txt";
+        let local_download_path = "test_download_async.txt";
+
+        create_test_file(local_upload_path, test_content);
+
+        // 0. List files first to confirm a plain GET request succeeds
+        let files = bucket.list_files().await.expect("Async list_files failed");
+        println!("{files:#?}");
+
+        // 1. Upload file
+        bucket
+            .upload_file(local_upload_path, r2_file_key)
+            .await
+            .expect("Async upload failed");
+
+        // 2. List files and check that it exists
+        let files = bucket.list_files().await.expect("Async list_files failed");
+        assert!(
+            files
+                .get("test")
+                .unwrap()
+                .contains(&"test_upload_async.txt".to_string())
+        );
+
+        // 3. List folders and check that it exists
+        let folders = bucket
+            .list_folders()
+            .await
+            .expect("Async list_folders failed");
+        assert!(folders.contains(&"test".to_string()));
+
+        // 4. Download file
+        bucket
+            .download_file(r2_file_key, local_download_path)
+            .await
+            .expect("Async download failed");
+
+        // 5. Verify content
+        let downloaded_content = fs::read_to_string(local_download_path).unwrap();
+        assert_eq!(test_content, downloaded_content);
+
+        // Cleanup
+        fs::remove_file(local_upload_path).unwrap();
+        fs::remove_file(local_download_path).unwrap();
+
+        // 6. Delete file
+        bucket.delete_file(r2_file_key).await.unwrap();
+    }
+}
diff --git a/r2client/todo.md b/r2client/todo.md
new file mode 100644
index 0000000..bad86f6
--- /dev/null
+++ b/r2client/todo.md
@@ -0,0 +1,12 @@
+## For release:
+  - [ ] Create a crate::Result<T> that is Result<T, R2Error>, and have Ok carry the status code
+  - [ ] Consider dropping more dependencies, using hyper or some lower-level stuff for async, and then http for blocking
+  - [ ] A way to view the file contents (UTF-8 valid) would be cool
+  - [ ] Add functions that will list files with their metadata (perhaps a simple R2File type?)
+  - [ ] Clear out all print statements and consider logging (this is a library, after all)
+
+## Dev (since we're so back):
+  - [X] Update the sync library
+  - [X] Make a .env with test-bucket creds
+  - [X] Actually test the damn thing
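The integration tests read credentials through dotenv, which is why `.env` is in the new `.gitignore` and why the Dev checklist includes making one. A sample layout with placeholder values (the variable names are taken from the tests; the values here are illustrative only):

```
R2_BUCKET=my-test-bucket
R2_ACCESS_KEY=AKIAEXAMPLE
R2_SECRET_KEY=wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY
R2_ENDPOINT=https://<account-id>.r2.cloudflarestorage.com
```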