diff --git a/src/package/mod.rs b/src/package/mod.rs
index e4c1136f..4dc7cc92 100644
--- a/src/package/mod.rs
+++ b/src/package/mod.rs
@@ -18,3 +18,168 @@ mod store;
 mod r#type;
 
 pub use self::{compressed::Package, name::PackageName, r#type::PackageType, store::PackageStore};
+
+/// Constructors for the error cases reported by [`validate`].
+///
+/// Implemented by each concrete error type so that the shared validation logic can build the
+/// caller's own error without knowing its concrete type.
+trait ParseError {
+    /// The input was empty.
+    fn empty() -> Self;
+    /// The input is `current_length` bytes long, which exceeds the allowed maximum.
+    fn too_long(current_length: usize) -> Self;
+    /// The input starts with `first`, which is not an ASCII lowercase letter.
+    fn invalid_start(first: char) -> Self;
+    /// The character `found` at zero-based character index `pos` is not allowed.
+    fn invalid_character(found: char, pos: usize) -> Self;
+}
+
+/// Validation function for both package name and directories. They have very similar rules with
+/// just extra allowed characters being different at the moment.
+///
+/// Shared allowed characters are `a-z` for the first and `a-z0-9` + extras for the rest.
+///
+/// `extra_allowed_chars` is a set of ASCII bytes; non-ASCII characters are never allowed, so
+/// non-ASCII bytes in it have no effect. `max_len` is the maximum length of `raw` in bytes.
+///
+/// The checks run in this order, and the first failure is returned: empty input, length, first
+/// character, remaining characters. The position passed to [`ParseError::invalid_character`] is
+/// the zero-based character index of the first offending character.
+fn validate<E>(raw: &str, extra_allowed_chars: &[u8], max_len: usize) -> Result<(), E>
+where
+    E: ParseError,
+{
+    let first = match raw.chars().next() {
+        Some(first) => first,
+        None => return Err(E::empty()),
+    };
+
+    if raw.len() > max_len {
+        return Err(E::too_long(raw.len()));
+    }
+
+    if !first.is_ascii_lowercase() {
+        return Err(E::invalid_start(first));
+    }
+
+    let is_allowed = |c: char| {
+        c.is_ascii_lowercase()
+            || c.is_ascii_digit()
+            // The cast is lossless because `c` is known to be ASCII at this point
+            || (c.is_ascii() && extra_allowed_chars.contains(&(c as u8)))
+    };
+
+    // Walk characters rather than bytes so that the reported character and index stay correct
+    // (and nothing can panic) even when the input contains multi-byte UTF-8 sequences. The first
+    // character has already been checked above, hence the `skip(1)`.
+    match raw.chars().enumerate().skip(1).find(|&(_, c)| !is_allowed(c)) {
+        Some((pos, found)) => Err(E::invalid_character(found, pos)),
+        None => Ok(()),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+    enum ParserError {
+        Empty,
+        TooLong(usize),
+        InvalidStart(char),
+        InvalidCharacter(char, usize),
+    }
+
+    impl ParseError for ParserError {
+        fn empty() -> Self {
+            Self::Empty
+        }
+
+        fn too_long(current_length: usize) -> Self {
+            Self::TooLong(current_length)
+        }
+
+        fn invalid_start(first: char) -> Self {
+            Self::InvalidStart(first)
+        }
+
+        fn invalid_character(found: char, pos: usize) -> Self {
+            Self::InvalidCharacter(found, pos)
+        }
+    }
+
+    #[track_caller]
+    fn validate(raw: &str, extra_allowed_chars: &[u8], max_len: usize) -> Result<(), ParserError> {
+        super::validate(raw, extra_allowed_chars, max_len)
+    }
+
+    #[test]
+    fn empty_fails() {
+        let res = validate("", &[], 10);
+        assert_eq!(res, Err(ParserError::Empty));
+    }
+
+    #[test]
+    fn length_check() {
+        let res = validate("abcdefghijklm", &[], 5);
+        assert_eq!(res, Err(ParserError::TooLong(13)));
+
+        let res = validate("abcdefghijklm", &[], 10);
+        assert_eq!(res, Err(ParserError::TooLong(13)));
+
+        let res = validate("abcdefghijklm", &[], 15);
+        assert_eq!(res, Ok(()));
+    }
+
+    #[test]
+    fn invalid_start() {
+        let res = validate("Ab", &[b'A'], 5);
+        assert_eq!(res, Err(ParserError::InvalidStart('A')));
+
+        let res = validate("5b", &[b'5'], 5);
+        assert_eq!(res, Err(ParserError::InvalidStart('5')));
+
+        let res = validate("-b", &[b'_'], 5);
+        assert_eq!(res, Err(ParserError::InvalidStart('-')));
+
+        let res = validate("_b", &[b'_'], 5);
+        assert_eq!(res, Err(ParserError::InvalidStart('_')));
+
+        let res = validate("🦀b", "🦀".as_bytes(), 10);
+        assert_eq!(res, Err(ParserError::InvalidStart('🦀')));
+    }
+
+    #[test]
+    fn invalid_character() {
+        let res = validate("bAc", &[], 5);
+        assert_eq!(res, Err(ParserError::InvalidCharacter('A', 1)));
+
+        let res = validate("bowl-", &[], 5);
+        assert_eq!(res, Err(ParserError::InvalidCharacter('-', 4)));
+
+        let res = validate("bob_", &[], 5);
+        assert_eq!(res, Err(ParserError::InvalidCharacter('_', 3)));
+
+        let res = validate("bo🦀", &[], 10);
+        assert_eq!(res, Err(ParserError::InvalidCharacter('🦀', 2)));
+    }
+
+    #[test]
+    fn basic_format() {
+        let res = validate("abcdefghijklmnopqrstuvwxyz0123456789", &[], 36);
+        assert_eq!(res, Ok(()));
+    }
+
+    #[test]
+    fn extra_allowed_chars() {
+        let res = validate("b0A_2d-c", &[b'A', b'_', b'-'], 10);
+        assert_eq!(res, Ok(()));
+    }
+
+    #[test]
+    fn non_ascii_is_never_allowed() {
+        // Non-ASCII bytes in the extras must neither allow the character nor cause a panic
+        let res = validate("bé!", "é".as_bytes(), 10);
+        assert_eq!(res, Err(ParserError::InvalidCharacter('é', 1)));
+    }
+}
diff --git a/src/package/name.rs b/src/package/name.rs
index c6528023..f546b8f5 100644
--- a/src/package/name.rs
+++ b/src/package/name.rs
@@ -39,8 +39,29 @@ pub enum PackageNameError {
     InvalidCharacter(char, usize),
 }
 
+impl super::ParseError for PackageNameError {
+    #[inline]
+    fn empty() -> Self {
+        Self::Empty
+    }
+
+    #[inline]
+    fn too_long(current_length: usize) -> Self {
+        Self::TooLong(current_length)
+    }
+
+    #[inline]
+    fn invalid_start(first: char) -> Self {
+        Self::InvalidStart(first)
+    }
+
+    #[inline]
+    fn invalid_character(found: char, pos: usize) -> Self {
+        Self::InvalidCharacter(found, pos)
+    }
+}
+
 impl PackageName {
-    const MIN_LENGTH: usize = 1;
     const MAX_LENGTH: usize = 128;
 
     /// New package name from string.
@@ -55,53 +76,9 @@ impl PackageName {
         Self(value.into())
     }
 
-    /// Determine if this character is allowed at the start of a package name.
-    fn is_allowed_start(c: char) -> bool {
-        c.is_alphabetic()
-    }
-
-    /// Determine if this character is allowed anywhere in a package name.
-    fn is_allowed(c: char) -> bool {
-        let is_ascii_lowercase_alphanumeric =
-            |c: char| c.is_ascii_alphanumeric() && !c.is_ascii_uppercase();
-        match c {
-            '-' => true,
-            c if is_ascii_lowercase_alphanumeric(c) => true,
-            _ => false,
-        }
-    }
-
     /// Validate a package name.
     pub fn validate(name: impl AsRef<str>) -> Result<(), PackageNameError> {
-        let name = name.as_ref();
-
-        // validate length
-        if name.len() < Self::MIN_LENGTH {
-            return Err(PackageNameError::Empty);
-        }
-
-        if name.len() > Self::MAX_LENGTH {
-            return Err(PackageNameError::TooLong(name.len()));
-        }
-
-        // validate first character
-        match name.chars().next() {
-            Some(c) if Self::is_allowed_start(c) => {}
-            Some(c) => return Err(PackageNameError::InvalidStart(c)),
-            None => unreachable!(),
-        }
-
-        // validate all characters
-        let illegal = name
-            .chars()
-            .enumerate()
-            .find(|(_, c)| !Self::is_allowed(*c));
-
-        if let Some((index, c)) = illegal {
-            return Err(PackageNameError::InvalidCharacter(c, index));
-        }
-
-        Ok(())
+        super::validate(name.as_ref(), &[b'-'], Self::MAX_LENGTH)
     }
 }
 
@@ -148,7 +125,6 @@ mod test {
     #[test]
     fn ascii_lowercase() {
         assert_eq!(PackageName::new("abc"), Ok(PackageName("abc".into())));
-        assert_eq!(PackageName::new("abc"), Ok(PackageName("abc".into())));
     }
 
     #[test]
@@ -160,13 +136,13 @@ mod test {
     #[test]
     fn long() {
         assert_eq!(
-            PackageName::new("a".repeat(128)),
-            Ok(PackageName("a".repeat(128)))
+            PackageName::new("a".repeat(PackageName::MAX_LENGTH)),
+            Ok(PackageName("a".repeat(PackageName::MAX_LENGTH)))
         );
 
         assert_eq!(
-            PackageName::new("a".repeat(129)),
-            Err(PackageNameError::TooLong(129))
+            PackageName::new("a".repeat(PackageName::MAX_LENGTH + 1)),
+            Err(PackageNameError::TooLong(PackageName::MAX_LENGTH + 1))
         );
     }
 