commit 4e3075baad90951ed7745b6afcb6267158441d95 Author: foreverpyrite Date: Sat Sep 6 18:14:28 2025 -0400 Pre-testing phase diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b97b0a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +.aider* +.env diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..091590b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "r2client" +version = "0.1.0" +edition = "2024" + +[lib] + +[dependencies] +sha2 = "0.10.9" +bytes = "1.10.1" +reqwest = { version = "0.12.19", features = ["blocking"] } +chrono = "0.4.41" +hex = "0.4.3" +hmac = "0.12.1" +xmltree = "0.11.0" +thiserror = "2" +async-trait = "0.1.89" +async-std = { version = "1.0", optional = true } +tokio = { version = "1.0", features = ["rt-multi-thread"], optional = true } +futures-executor = { version = "0.3", optional = true } +urlencoding = "2.1.3" +http = "1.3.1" + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +dotenv = "0.15" + +[features] +async = ["tokio"] +default = ["async"] +sync = ["tokio/rt-multi-thread", "futures-executor"] +async-std-runtime = ["async-std"] diff --git a/src/_async.rs b/src/_async.rs new file mode 100644 index 0000000..57daa66 --- /dev/null +++ b/src/_async.rs @@ -0,0 +1,4 @@ +mod r2bucket; +mod r2client; +pub use r2bucket::R2Bucket; +pub use r2client::R2Client; diff --git a/src/_async/r2bucket.rs b/src/_async/r2bucket.rs new file mode 100644 index 0000000..fcd3cae --- /dev/null +++ b/src/_async/r2bucket.rs @@ -0,0 +1,84 @@ +use crate::R2Client; +use crate::R2Error; + +#[derive(Clone, Debug)] +pub struct R2Bucket { + bucket: String, + pub client: R2Client, +} + +impl R2Bucket { + pub fn new(bucket: String, client: R2Client) -> Self { + Self { bucket, client } + } + + pub fn from_credentials( + bucket: String, + access_key: String, + secret_key: String, + endpoint: String, + ) -> Self { + let client = R2Client::from_credentials(access_key, secret_key, endpoint); + Self { bucket, client } + } + + pub async fn upload_file( + &self, + local_file_path: &str, + r2_file_key: &str, + ) -> Result<(), R2Error> { + self.client + // I'm pasing None to let the R2Client derive the content type from the local_file_path + .upload_file(&self.bucket, local_file_path, r2_file_key, None) + .await + } + + pub async fn download_file(&self, r2_file_key: &str, local_path: &str) -> Result<(), R2Error> { + self.client + .download_file(&self.bucket, r2_file_key, local_path) + .await + } + + pub async fn list_files( + &self, + ) -> Result>, R2Error> { + self.client.list_files(&self.bucket).await + } + + pub async fn list_folders(&self) -> Result, R2Error> { + self.client.list_folders(&self.bucket).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::R2Client; + use std::env; + + fn get_test_bucket() -> R2Bucket { + dotenv::dotenv().ok(); + let access_key = + env::var("R2_ACCESS_KEY").unwrap_or_else(|_| "test_access_key".to_string()); + let secret_key = + env::var("R2_SECRET_KEY").unwrap_or_else(|_| "test_secret_key".to_string()); + let endpoint = env::var("R2_ENDPOINT") + .unwrap_or_else(|_| "https://example.r2.cloudflarestorage.com".to_string()); + let client = R2Client::from_credentials(access_key, secret_key, endpoint); + R2Bucket::new("test-bucket".to_string(), client) + } + + #[test] + fn test_bucket_construction() { + let bucket = get_test_bucket(); + assert_eq!(bucket.bucket, "test-bucket"); + } + + // Example async test (requires a runtime, so ignored by default) + // #[tokio::test] + // async fn test_upload_file() { + // let bucket = get_test_bucket(); + // let result = bucket.upload_file("Cargo.toml", "test-upload.toml").await; + // assert!(result.is_ok()); + // } +} diff --git a/src/_async/r2client.rs b/src/_async/r2client.rs new file mode 100644 index 0000000..0b5c722 --- /dev/null +++ b/src/_async/r2client.rs @@ -0,0 +1,289 @@ +use crate::mimetypes::Mime; +use crate::{R2Error, aws_signing}; +use http::Method; +use reqwest::header::{self, HeaderName, HeaderValue}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::str::FromStr; + +#[derive(Clone, Debug)] +pub struct R2Client { + access_key: String, + secret_key: String, + endpoint: String, +} + +impl R2Client { + fn get_env() -> Result<(String, String, String), R2Error> { + let keys = ["R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_ENDPOINT"]; + let values = keys + .map(|key| { std::env::var(key).map_err(|_| R2Error::Env(key.to_owned())) }.unwrap()); + Ok(values.into()) + } + + pub fn new() -> Self { + let (access_key, secret_key, endpoint) = Self::get_env().unwrap(); + + Self { + access_key, + secret_key, + endpoint, + } + } + + pub fn from_credentials(access_key: String, secret_key: String, endpoint: String) -> Self { + Self { + access_key, + secret_key, + endpoint, + } + } + + fn create_headers( + &self, + method: http::Method, + bucket: &str, + key: Option<&str>, + payload_hash: &str, + content_type: Option<&str>, + ) -> Result { + let uri = http::Uri::from_str(&self.build_url(bucket, key)) + .expect("invalid uri rip (make sure the build_url function works as intended)"); + let mut headers = Vec::new(); + if method == Method::GET { + headers.push(( + "x-amz-content-sha256".to_string(), + "UNSIGNED-PAYLOAD".to_string(), + )) + } + if let Some(content_type) = content_type { + headers.push(("content-type".to_string(), content_type.to_owned())) + } + + let (_, headers) = aws_signing::signature( + method, + uri, + headers, + payload_hash, + "s3", + "us-east-1", + &self.secret_key, + &self.access_key, + ); + let mut header_map = header::HeaderMap::new(); + for header in headers { + header_map.insert( + HeaderName::from_lowercase(&header.0.to_lowercase().as_bytes()) + .expect("shit tragic"), + HeaderValue::from_str(&header.1).expect("shit more tragic"), + ); + } + Ok(header_map) + } + + pub async fn upload_file( + &self, + bucket: &str, + local_file_path: &str, + r2_file_key: &str, + content_type: Option<&str>, + ) -> Result<(), R2Error> { + // --- Hash Payload -- + let file_data = std::fs::read(local_file_path)?; + let payload_hash = hex::encode(Sha256::digest(&file_data)); + + // Set HTTP Headers + let content_type = if let Some(content_type) = content_type { + Some(content_type) + } else { + Some(Mime::get_mimetype_from_fp(local_file_path)) + }; + let headers = self.create_headers( + Method::PUT, + bucket, + Some(r2_file_key), + &payload_hash, + content_type, + )?; + let file_url = format!("{}/{}/{}", self.endpoint, bucket, r2_file_key); + let client = reqwest::Client::new(); + let resp = client + .put(&file_url) + .headers(headers) + .body(file_data) + .send() + .await?; + let status = resp.status(); + let text = resp.text().await?; + if status.is_success() { + Ok(()) + } else { + Err(R2Error::Other(format!( + "Upload failed with status {}: {}", + status, text + ))) + } + } + pub async fn download_file( + &self, + bucket: &str, + key: &str, + local_path: &str, + ) -> Result<(), R2Error> { + let payload_hash = hex::encode(Sha256::digest("")); + let content_type = Mime::get_mimetype_from_fp(local_path); + let headers = self.create_headers( + Method::GET, + bucket, + Some(key), + &payload_hash, + Some(content_type), + )?; + let file_url = format!("{}/{}/{}", self.endpoint, bucket, key); + let client = reqwest::Client::new(); + let resp = client.get(&file_url).headers(headers).send().await?; + let status = resp.status(); + let content = resp.bytes().await?; + if status.is_success() { + std::fs::write(local_path, &content)?; + Ok(()) + } else { + Err(R2Error::Other(format!( + "Download failed with status {}", + status + ))) + } + } + async fn get_bucket_listing(&self, bucket: &str) -> Result { + let payload_hash = "UNSIGNED-PAYLOAD"; + let content_type = "application/xml"; + let headers = + self.create_headers(Method::GET, bucket, None, payload_hash, Some(content_type))?; + let url = self.build_url(bucket, None); + let client = reqwest::Client::new(); + let resp = client + .get(&url) + .headers(headers) + .send() + .await + .map_err(R2Error::from)?; + let status = resp.status(); + if status.is_success() { + resp.text().await.map_err(R2Error::from) + } else { + Err(R2Error::Other(format!("Failed to list bucket: {}", status))) + } + } + + pub async fn list_files(&self, bucket: &str) -> Result>, R2Error> { + let xml = self.get_bucket_listing(bucket).await?; + let mut files_dict: HashMap> = HashMap::new(); + let root = xmltree::Element::parse(xml.as_bytes()).map_err(R2Error::from)?; + for content in root + .children + .iter() + .filter_map(|c| c.as_element()) + .filter(|e| e.name == "Contents") + { + let key_elem = content.get_child("Key").and_then(|k| k.get_text()); + if let Some(file_key) = key_elem { + let (folder, file_name): (String, String) = if let Some(idx) = file_key.rfind('/') { + (file_key[..idx].to_string(), file_key[idx + 1..].to_string()) + } else { + ("".to_string(), file_key.to_string()) + }; + files_dict.entry(folder).or_default().push(file_name); + } + } + Ok(files_dict) + } + + pub async fn list_folders(&self, bucket: &str) -> Result, R2Error> { + let xml = self.get_bucket_listing(bucket).await?; + let mut folders = std::collections::HashSet::new(); + let root = xmltree::Element::parse(xml.as_bytes()).map_err(R2Error::from)?; + for content in root + .children + .iter() + .filter_map(|c| c.as_element()) + .filter(|e| e.name == "Contents") + { + let key_elem = content.get_child("Key").and_then(|k| k.get_text()); + if let Some(file_key) = key_elem { + if let Some(idx) = file_key.find('/') { + folders.insert(file_key[..idx].to_string()); + } + } + } + Ok(folders.into_iter().collect()) + } + + fn build_url(&self, bucket: &str, key: Option<&str>) -> String { + match key { + Some(k) => format!("{}/{}/{}", self.endpoint, bucket, k), + None => format!("{}/{}/", self.endpoint, bucket), + } + } +} +impl Default for R2Client { + fn default() -> Self { + let (access_key, secret_key, endpoint) = Self::get_env().unwrap(); + + Self { + access_key, + secret_key, + endpoint, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn r2client_from_env() -> R2Client { + unsafe { + std::env::set_var("R2_ACCESS_KEY", "AKIAEXAMPLE"); + std::env::set_var("R2_SECRET_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); + std::env::set_var("R2_ENDPOINT", "https://example.r2.cloudflarestorage.com"); + } + R2Client::new() + } + + #[test] + fn r2client_env() { + let r2client = r2client_from_env(); + + assert_eq!(r2client.access_key, "AKIAEXAMPLE"); + assert_eq!( + r2client.secret_key, + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + ); + assert_eq!( + r2client.endpoint, + "https://example.r2.cloudflarestorage.com" + ); + } + + #[test] + fn test_create_headers() { + let client = R2Client::from_credentials( + "AKIAEXAMPLE".to_string(), + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(), + "https://example.r2.cloudflarestorage.com".to_string(), + ); + let headers = client + .create_headers( + Method::PUT, + "bucket", + Some("key"), + "deadbeef", + Some("application/octet-stream"), + ) + .unwrap(); + assert!(headers.contains_key("x-amz-date")); + assert!(headers.contains_key("authorization")); + assert!(headers.contains_key("content-type")); + assert!(headers.contains_key("host")); + } +} diff --git a/src/aws_signing.rs b/src/aws_signing.rs new file mode 100644 index 0000000..b3b5333 --- /dev/null +++ b/src/aws_signing.rs @@ -0,0 +1,197 @@ +use chrono::Utc; +use hmac::{Hmac, Mac}; +use sha2::{Digest, Sha256}; + +type Hmac256 = Hmac; + +// --- Utility functions --- +fn lowercase(string: &str) -> String { + string.to_lowercase() +} + +fn hex>(data: T) -> String { + hex::encode(data) +} + +fn sha256hash>(data: T) -> [u8; 32] { + Sha256::digest(data).into() +} + +fn hmac_sha256(signing_key: &[u8], message: &str) -> Vec { + let mut mac = Hmac256::new_from_slice(signing_key).expect("bad key :pensive:"); + mac.update(message.as_bytes()); + mac.finalize().into_bytes().to_vec() +} + +fn trim(string: &str) -> String { + string.trim().to_string() +} + +fn hash>(payload: T) -> String { + hex(sha256hash(payload)) +} + +fn url_encode(url: &str) -> String { + let mut url = urlencoding::encode(url).into_owned(); + let encoded_to_replacement: [(&str, &str); 4] = + [("+", "%20"), ("*", "%2A"), ("%7E", "~"), ("%2F", "/")]; + for (encoded_chars_pattern, replacement) in encoded_to_replacement { + url = url.replace(encoded_chars_pattern, replacement) + } + url +} + +// --- Canonical request --- +fn create_canonical_request( + method: http::Method, + uri: http::Uri, + mut headers: Vec<(String, String)>, + hashed_payload: &str, +) -> (String, String, String) { + // HTTPMethod + let http_method = method.to_string(); + + // CanonicalURI = *path only* (spec forbids scheme+host here) + let canonical_uri = if uri.path().is_empty() { + "/".to_string() + } else { + uri.path().to_string() + }; + + // CanonicalQueryString (URL-encoded, sorted by key) + let canonical_query_string = if let Some(query_string) = uri.query() { + let mut pairs = query_string + .split('&') + .map(|query| { + let (k, v) = query.split_once('=').unwrap_or((query, "")); + (url_encode(k), url_encode(v)) + }) + .collect::>(); + pairs.sort_by(|a, b| a.0.cmp(&b.0)); + pairs + .into_iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join("&") + } else { + String::new() + }; + + // Ensure required headers (host and x-amz-date) are present + let host = uri + .host() + .expect("uri passed without a proper host") + .to_string(); + if !headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("host")) { + headers.push(("host".to_string(), host)); + } + + // CanonicalHeaders + SignedHeaders + let mut http_headers = headers + .iter() + .map(|(name, value)| (lowercase(name), trim(value))) + .collect::>(); + http_headers.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); + + let canonical_headers: String = http_headers + .iter() + .map(|(k, v)| format!("{}:{}\n", k, v)) + .collect(); + + let signed_headers: String = http_headers + .iter() + .map(|(k, _)| k.clone()) + .collect::>() + .join(";"); + + // Final canonical request + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + http_method, + canonical_uri, + canonical_query_string, + canonical_headers, + signed_headers, + hashed_payload + ); + + (canonical_request, signed_headers, canonical_headers) +} + +fn credential_scope(date: &str, region: &str, service: &str) -> String { + format!( + "{}/{}/{}/aws4_request", + date, + lowercase(region), + lowercase(service) + ) +} + +fn string_to_sign(scope: &str, amz_date: &str, canonical_request: &str) -> String { + format!( + "{}\n{}\n{}\n{}", + "AWS4-HMAC-SHA256", + amz_date, + scope, + hex(sha256hash(canonical_request)) + ) +} + +fn derive_signing_key(key: &str, date: &str, region: &str, service: &str) -> Vec { + let secret_key = format!("AWS4{}", key); + let date_key = hmac_sha256(secret_key.as_bytes(), date); + let date_region_key = hmac_sha256(&date_key, region); + let date_region_service_key = hmac_sha256(&date_region_key, service); + hmac_sha256(&date_region_service_key, "aws4_request") +} + +fn calculate_signature(signing_key: &[u8], string_to_sign: &str) -> Vec { + hmac_sha256(signing_key, string_to_sign) +} + +// --- API --- +pub fn signature( + method: http::Method, + uri: http::Uri, + mut headers: Vec<(String, String)>, + hashed_payload: &str, + service: &str, + region: &str, + secret_key: &str, + access_key: &str, +) -> (String, Vec<(String, String)>) { + let now = Utc::now(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = now.format("%Y%m%d").to_string(); + + // Add x-amz-date header if not already present + if !headers + .iter() + .any(|(k, _)| k.eq_ignore_ascii_case("x-amz-date")) + { + headers.push(("x-amz-date".to_string(), amz_date.clone())); + } + + // Canonical request + let (canonical_request, signed_headers, _canonical_headers) = + create_canonical_request(method, uri, headers.clone(), hashed_payload); + + // String to sign + let scope = credential_scope(&date_stamp, region, service); + let string_to_sign = string_to_sign(&scope, &amz_date, &canonical_request); + + // Signing key + signature + let signing_key = derive_signing_key(secret_key, &date_stamp, region, service); + let signature = hex(calculate_signature(&signing_key, &string_to_sign)); + + // Authorization header + let credential = format!("{}/{}", access_key, scope); + let auth_header = format!( + "{} Credential={}, SignedHeaders={}, Signature={}", + "AWS4-HMAC-SHA256", credential, signed_headers, signature + ); + + headers.push(("authorization".to_string(), auth_header)); + + (signature, headers) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..38cd4b9 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum R2Error { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + #[error("XML parse error: {0}")] + Xml(#[from] xmltree::ParseError), + #[error("Missing environment varibles: {0}")] + Env(String), + #[error("Other: {0}")] + Other(String), +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..02b7efa --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,19 @@ +mod aws_signing; +mod error; +mod mimetypes; +pub use error::R2Error; + +mod _async; +#[cfg(feature = "async")] +pub use _async::{R2Bucket, R2Client}; + +#[cfg(feature = "sync")] +pub mod sync; + +#[cfg(test)] +mod test { + // use crate::{R2Bucket, R2Client, sync}; + + #[test] + fn test() {} +} diff --git a/src/mimetypes.rs b/src/mimetypes.rs new file mode 100644 index 0000000..28ffcc1 --- /dev/null +++ b/src/mimetypes.rs @@ -0,0 +1,103 @@ + + +pub enum Mime{} + +impl Mime { + pub fn get_mimetype(key: &str) -> &'static str { + match key { + // Image formats + ".png" => "image/png", + ".jpg" | ".jpeg" => "image/jpeg", + ".gif" => "image/gif", + ".svg" => "image/svg+xml", + ".ico" => "image/x-icon", + ".webp" => "image/webp", + + // Audio formats + ".m4a" => "audio/x-m4a", + ".mp3" => "audio/mpeg", + ".wav" => "audio/wav", + ".ogg" => "audio/ogg", + + // Video formats + ".mp4" => "video/mp4", + ".avi" => "video/x-msvideo", + ".mov" => "video/quicktime", + ".flv" => "video/x-flv", + ".wmv" => "video/x-ms-wmv", + ".webm" => "video/webm", + + // Document formats + ".pdf" => "application/pdf", + ".doc" => "application/msword", + ".docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".ppt" => "application/vnd.ms-powerpoint", + ".pptx" => "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".xls" => "application/vnd.ms-excel", + ".xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".txt" => "text/plain", + + // Web formats + ".html" => "text/html", + ".css" => "text/css", + ".js" => "application/javascript", + ".json" => "application/json", + ".xml" => "application/xml", + + // Other formats + ".csv" => "text/csv", + ".zip" => "application/zip", + ".tar" => "application/x-tar", + ".gz" => "application/gzip", + ".rar" => "application/vnd.rar", + ".7z" => "application/x-7z-compressed", + ".eps" => "application/postscript", + ".sql" => "application/sql", + ".java" => "text/x-java-source", + _ => "application/octet-stream", + } + } + + pub fn get_mimetype_from_fp(file_path: &str) -> &str { + // Sorry I just really wanted to get it done in a one liner. + // This splits a filepath based off ".", in reverse order, so that the first element will + // be the file extension (e.g. "~/.config/test.jpeg" becomes "jpeg") + // This is formated back to ".jpeg" because it's how the match statement is working. + // I could very easily change it but idk it was an interesting thing. + Self::get_mimetype( + &format!( + ".{}", file_path.rsplit(".") + .next() + .unwrap_or("time_to_be_an_octet_stream_lmao") + ) + ) + } + +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn match_mime_test() { + assert_eq!(Mime::get_mimetype(".tar"), "application/x-tar"); + } + + #[test] + fn default_mime_test() { + assert_eq!(Mime::get_mimetype(".bf"), "application/octet-stream"); + } + + #[test] + fn mime_from_file() { + assert_eq!(Mime::get_mimetype_from_fp("test.ico"), "image/x-icon"); + } + + #[test] + fn mime_from_file_path() { + assert_eq!(Mime::get_mimetype_from_fp("/home/testuser/Documents/test.pdf"), "application/pdf"); + assert_eq!(Mime::get_mimetype_from_fp("./bucket_test/bucket_test_upload.txt"), "text/plain") + } + +} \ No newline at end of file diff --git a/src/sync.rs b/src/sync.rs new file mode 100644 index 0000000..57daa66 --- /dev/null +++ b/src/sync.rs @@ -0,0 +1,4 @@ +mod r2bucket; +mod r2client; +pub use r2bucket::R2Bucket; +pub use r2client::R2Client; diff --git a/src/sync/r2bucket.rs b/src/sync/r2bucket.rs new file mode 100644 index 0000000..b8d6e36 --- /dev/null +++ b/src/sync/r2bucket.rs @@ -0,0 +1,67 @@ +use crate::R2Error; +use crate::sync::R2Client; + +#[derive(Clone, Debug)] +pub struct R2Bucket { + bucket: String, + pub client: R2Client, +} + +impl R2Bucket { + pub fn new(bucket: String, client: R2Client) -> Self { + Self { bucket, client } + } + + pub fn from_credentials( + bucket: String, + access_key: String, + secret_key: String, + endpoint: String, + ) -> Self { + let client = R2Client::from_credentials(access_key, secret_key, endpoint); + Self { bucket, client } + } + + pub fn upload_file(&self, local_file_path: &str, r2_file_key: &str) -> Result<(), R2Error> { + self.client + .upload_file(&self.bucket, local_file_path, r2_file_key) + } + + pub fn download_file(&self, r2_file_key: &str, local_path: &str) -> Result<(), R2Error> { + self.client + .download_file(&self.bucket, r2_file_key, local_path) + } + + pub fn list_files(&self) -> Result>, R2Error> { + self.client.list_files(&self.bucket) + } + + pub fn list_folders(&self) -> Result, R2Error> { + self.client.list_folders(&self.bucket) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sync::R2Bucket; + use std::env; + + fn get_test_bucket() -> R2Bucket { + dotenv::dotenv().ok(); + let access_key = + env::var("R2_ACCESS_KEY").unwrap_or_else(|_| "test_access_key".to_string()); + let secret_key = + env::var("R2_SECRET_KEY").unwrap_or_else(|_| "test_secret_key".to_string()); + let endpoint = env::var("R2_ENDPOINT") + .unwrap_or_else(|_| "https://example.r2.cloudflarestorage.com".to_string()); + let client = R2Client::from_credentials(access_key, secret_key, endpoint); + R2Bucket::new("test-bucket".to_string(), client) + } + + #[test] + fn test_bucket_construction() { + let bucket = get_test_bucket(); + assert_eq!(bucket.bucket, "test-bucket"); + } +} diff --git a/src/sync/r2client.rs b/src/sync/r2client.rs new file mode 100644 index 0000000..36470a3 --- /dev/null +++ b/src/sync/r2client.rs @@ -0,0 +1,387 @@ +use crate::R2Error; +use crate::mimetypes::Mime; +use chrono::Utc; +use hmac::{Hmac, Mac}; +use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; + +type HmacSHA256 = Hmac; + +#[derive(Clone, Debug)] +pub struct R2Client { + access_key: String, + secret_key: String, + endpoint: String, +} + +impl R2Client { + fn get_env() -> Result<(String, String, String), R2Error> { + let keys = ["R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_ENDPOINT"]; + let values = keys + .map(|key| { std::env::var(key).map_err(|_| R2Error::Env(key.to_owned())) }.unwrap()); + Ok(values.into()) + } + + pub fn new() -> Self { + let (access_key, secret_key, endpoint) = Self::get_env().unwrap(); + + Self { + access_key, + secret_key, + endpoint, + } + } + + pub fn from_credentials(access_key: String, secret_key: String, endpoint: String) -> Self { + Self { + access_key, + secret_key, + endpoint, + } + } + + fn sign(&self, key: &[u8], msg: &str) -> Vec { + let mut mac = HmacSHA256::new_from_slice(key).expect("Invalid key length"); + mac.update(msg.as_bytes()); + mac.finalize().into_bytes().to_vec() + } + + fn get_signature_key(&self, date_stamp: &str, region: &str, service: &str) -> Vec { + let aws4_secret: String = format!("AWS4{}", self.secret_key); + let k_date = self.sign(aws4_secret.as_bytes(), date_stamp); + let k_region = self.sign(&k_date, region); + let k_service = self.sign(&k_region, service); + self.sign(&k_service, "aws4_request") + } + + fn create_headers( + &self, + method: &str, + bucket: &str, + key: &str, + payload_hash: &str, + content_type: &str, + ) -> Result { + // Robustly extract host from endpoint + let endpoint = self.endpoint.trim_end_matches('/'); + // Not proud of this, it is really dumb and hard to read, but it'll work I suppose...I think... + let host = endpoint + .split("//") + .nth(1) + .unwrap_or(endpoint) + .split('/') + .next() + .unwrap_or(endpoint) + .split(':') + .next() + .unwrap_or(endpoint) + .trim(); + if host.is_empty() { + return Err(R2Error::Other( + "Host header could not be determined from endpoint".to_string(), + )); + } + let t = Utc::now(); + let amz_date = t.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = t.format("%Y%m%d").to_string(); + + let mut headers_vec = [ + ("host", host), + ("x-amz-date", &amz_date), + ("x-amz-content-sha256", payload_hash), + ("content-type", content_type), + ]; + headers_vec.sort_by(|a, b| a.0.cmp(b.0)); + + let signed_headers = headers_vec + .iter() + .map(|(k, _)| *k) + .collect::>() + .join(";"); + let canonical_headers = headers_vec + .iter() + .map(|(k, v)| format!("{}:{}\n", k.to_lowercase(), v)) + .collect::(); + + let canonical_uri = format!("/{}/{}", bucket, key); + let canonical_request = format!( + "{method}\n{uri}\n\n{headers}\n{signed_headers}\n{payload_hash}", + method = method, + uri = canonical_uri, + headers = canonical_headers, + signed_headers = signed_headers, + payload_hash = payload_hash + ); + let credential_scope = format!("{}/{}/s3/aws4_request", date_stamp, "auto"); + let hashed_request = hex::encode(Sha256::digest(canonical_request.as_bytes())); + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashed_request}", + amz_date = amz_date, + credential_scope = credential_scope, + hashed_request = hashed_request + ); + let signing_key = self.get_signature_key(&date_stamp, "auto", "s3"); + let signature = hex::encode(self.sign(&signing_key, &string_to_sign)); + let authorization = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + self.access_key, credential_scope, signed_headers, signature + ); + + // Print all headers for debugging + println!("[r2client] DEBUG: Built headers:"); + println!(" host: {}", host); + println!(" x-amz-date: {}", amz_date); + println!(" x-amz-content-sha256: {}", payload_hash); + println!(" content-type: {}", content_type); + println!(" authorization: {}", authorization); + println!(" signed_headers: {}", signed_headers); + println!( + " canonical_headers: {}", + canonical_headers.replace("\n", "\\n") + ); + println!( + " canonical_request: {}", + canonical_request.replace("\n", "\\n") + ); + println!(" string_to_sign: {}", string_to_sign.replace("\n", "\\n")); + println!(" signature: {}", signature); + + let mut header_map = HeaderMap::new(); + header_map.insert( + HeaderName::from_static("x-amz-date"), + HeaderValue::from_str(&amz_date) + .map_err(|e| R2Error::Other(format!("Invalid x-amz-date: {e}")))?, + ); + header_map.insert( + HeaderName::from_static("x-amz-content-sha256"), + HeaderValue::from_str(payload_hash).map_err(|e| { + R2Error::Other(format!( + "Invalid x-amz-content-sha256: {payload_hash:?}: {e}" + )) + })?, + ); + header_map.insert( + HeaderName::from_static("authorization"), + HeaderValue::from_str(&authorization).map_err(|e| { + R2Error::Other(format!( + "Invalid authorization: {e}\nValue: {authorization}" + )) + })?, + ); + header_map.insert( + HeaderName::from_static("content-type"), + HeaderValue::from_str(content_type) + .map_err(|e| R2Error::Other(format!("Invalid content-type: {e}")))?, + ); + header_map.insert( + HeaderName::from_static("host"), + HeaderValue::from_str(host) + .map_err(|e| R2Error::Other(format!("Invalid host: {e}")))?, + ); + Ok(header_map) + } + + pub fn upload_file( + &self, + bucket: &str, + local_file_path: &str, + r2_file_key: &str, + ) -> Result<(), R2Error> { + let file_data = std::fs::read(local_file_path)?; + let mut hasher = Sha256::new(); + hasher.update(&file_data); + let payload_hash = hex::encode(hasher.finalize()); + // let content_type = "application/octet-stream"; + let content_type = Mime::get_mimetype_from_fp(local_file_path); + let headers = + self.create_headers("PUT", bucket, r2_file_key, &payload_hash, content_type)?; + let file_url = format!("{}/{}/{}", self.endpoint, bucket, r2_file_key); + let client = reqwest::blocking::Client::new(); + let resp = client + .put(&file_url) + .headers(headers) + .body(file_data) + .send()?; + let status = resp.status(); + let text = resp.text()?; + if status.is_success() { + Ok(()) + } else { + Err(R2Error::Other(format!( + "Upload failed with status {}: {}", + status, text + ))) + } + } + pub fn download_file(&self, bucket: &str, key: &str, local_path: &str) -> Result<(), R2Error> { + let payload_hash = "UNSIGNED-PAYLOAD"; + let content_type = "application/octet-stream"; + let headers = self.create_headers("GET", bucket, key, payload_hash, content_type)?; + let file_url = format!("{}/{}/{}", self.endpoint, bucket, key); + let client = reqwest::blocking::Client::new(); + let resp = client.get(&file_url).headers(headers).send()?; + let status = resp.status(); + let content = resp.bytes()?; + if status.is_success() { + std::fs::write(local_path, &content)?; + Ok(()) + } else { + Err(R2Error::Other(format!( + "Download failed with status {}", + status + ))) + } + } + fn get_bucket_listing(&self, bucket: &str) -> Result { + let payload_hash = "UNSIGNED-PAYLOAD"; + let content_type = "application/xml"; + let headers = self.create_headers("GET", bucket, "", payload_hash, content_type)?; + let url = self.build_url(bucket, None); + let client = reqwest::blocking::Client::new(); + let resp = client + .get(&url) + .headers(headers) + .send() + .map_err(R2Error::from)?; + let status = resp.status(); + if status.is_success() { + resp.text().map_err(R2Error::from) + } else { + Err(R2Error::Other(format!("Failed to list bucket: {}", status))) + } + } + + /// List all files in the specified bucket. Returns a HashMap of folder -> `Vec`. + pub fn list_files(&self, bucket: &str) -> Result>, R2Error> { + let xml = self.get_bucket_listing(bucket)?; + let mut files_dict: HashMap> = HashMap::new(); + let root = xmltree::Element::parse(xml.as_bytes()).map_err(R2Error::from)?; + for content in root + .children + .iter() + .filter_map(|c| c.as_element()) + .filter(|e| e.name == "Contents") + { + let key_elem = content.get_child("Key").and_then(|k| k.get_text()); + if let Some(file_key) = key_elem { + let (folder, file_name): (String, String) = if let Some(idx) = file_key.rfind('/') { + (file_key[..idx].to_string(), file_key[idx + 1..].to_string()) + } else { + ("".to_string(), file_key.to_string()) + }; + files_dict.entry(folder).or_default().push(file_name); + } + } + Ok(files_dict) + } + + /// List all folders in the specified bucket. Returns a Vec of folder names. + pub fn list_folders(&self, bucket: &str) -> Result, R2Error> { + let xml = self.get_bucket_listing(bucket)?; + let mut folders = std::collections::HashSet::new(); + let root = xmltree::Element::parse(xml.as_bytes()).map_err(R2Error::from)?; + for content in root + .children + .iter() + .filter_map(|c| c.as_element()) + .filter(|e| e.name == "Contents") + { + let key_elem = content.get_child("Key").and_then(|k| k.get_text()); + if let Some(file_key) = key_elem { + if let Some(idx) = file_key.find('/') { + folders.insert(file_key[..idx].to_string()); + } + } + } + Ok(folders.into_iter().collect()) + } + + fn build_url(&self, bucket: &str, key: Option<&str>) -> String { + match key { + Some(k) => format!("{}/{}/{}", self.endpoint, bucket, k), + None => format!("{}/{}/", self.endpoint, bucket), + } + } +} +impl Default for R2Client { + fn default() -> Self { + let (access_key, secret_key, endpoint) = Self::get_env().unwrap(); + + Self { + access_key, + secret_key, + endpoint, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn r2client_from_env() -> R2Client { + unsafe { + std::env::set_var("R2_ACCESS_KEY", "AKIAEXAMPLE"); + std::env::set_var("R2_SECRET_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); + std::env::set_var("R2_ENDPOINT", "https://example.r2.cloudflarestorage.com"); + } + R2Client::new() + } + + #[test] + fn r2client_env() { + let r2client = r2client_from_env(); + + assert_eq!(r2client.access_key, "AKIAEXAMPLE"); + assert_eq!( + r2client.secret_key, + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + ); + assert_eq!( + r2client.endpoint, + "https://example.r2.cloudflarestorage.com" + ); + } + + #[test] + fn test_sign_and_signature_key() { + let client = R2Client::from_credentials( + "AKIAEXAMPLE".to_string(), + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(), + "https://example.r2.cloudflarestorage.com".to_string(), + ); + let key = b"testkey"; + let msg = "testmsg"; + let sig = client.sign(key, msg); + assert_eq!(sig.len(), 32); // HMAC-SHA256 output is 32 bytes + + let date = "20250101"; + let region = "auto"; + let service = "s3"; + let signing_key = client.get_signature_key(date, region, service); + assert_eq!(signing_key.len(), 32); + } + + #[test] + fn test_create_headers() { + let client = R2Client::from_credentials( + "AKIAEXAMPLE".to_string(), + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(), + "https://example.r2.cloudflarestorage.com".to_string(), + ); + let headers = client + .create_headers( + "PUT", + "bucket", + "key", + "deadbeef", + "application/octet-stream", + ) + .unwrap(); + assert!(headers.contains_key("x-amz-date")); + assert!(headers.contains_key("authorization")); + assert!(headers.contains_key("content-type")); + assert!(headers.contains_key("host")); + } +} diff --git a/todo.md b/todo.md new file mode 100644 index 0000000..b94c17b --- /dev/null +++ b/todo.md @@ -0,0 +1,6 @@ +Okay I think I did everything, so to clean up: + +- [ ] Update the sync library +- [X] Make a .env with test-bucket creds +- [ ] Actually test the damn thing +- [ ] Cry (ALL OF THAT WORK, FOR WHAT!? A SINGLE `main.rs` ON GITHUB!?)