From 232b35a3049988d38c870bacc63c221627dabc76 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Wed, 16 Mar 2022 15:40:26 +0000 Subject: [PATCH] Ensure numbers do not overflow numeric bounds --- lib/src/sql/common.rs | 53 +++++++++++++++++++---------------------- lib/src/sql/duration.rs | 6 ++--- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/lib/src/sql/common.rs b/lib/src/sql/common.rs index 588a9c9f..c9906959 100644 --- a/lib/src/sql/common.rs +++ b/lib/src/sql/common.rs @@ -5,7 +5,6 @@ use nom::bytes::complete::tag; use nom::bytes::complete::take_while; use nom::bytes::complete::take_while_m_n; use nom::character::is_alphanumeric; -use nom::combinator::map; use nom::multi::many1; use nom::Err::Error; use std::ops::RangeBounds; @@ -30,21 +29,6 @@ pub fn is_digit(chr: char) -> bool { (0x30..=0x39).contains(&(chr as u8)) } -#[inline] -pub fn to_u32(s: &str) -> u32 { - str::FromStr::from_str(s).unwrap() -} - -#[inline] -pub fn to_u64(s: &str) -> u64 { - str::FromStr::from_str(s).unwrap() -} - -#[inline] -pub fn to_usize(s: &str) -> usize { - str::FromStr::from_str(s).unwrap() -} - #[inline] pub fn val_char(chr: char) -> bool { is_alphanumeric(chr as u8) || chr == '_' @@ -61,30 +45,41 @@ pub fn escape(s: &str, f: &dyn Fn(char) -> bool, c: &str) -> String { } pub fn take_u32(i: &str) -> IResult<&str, u32> { - let (i, v) = map(take_while(is_digit), to_u32)(i)?; - Ok((i, v)) + let (i, v) = take_while(is_digit)(i)?; + match str::FromStr::from_str(v) { + Ok(v) => Ok((i, v)), + _ => Err(Error(ParserError(i))), + } } pub fn take_u64(i: &str) -> IResult<&str, u64> { - let (i, v) = map(take_while(is_digit), to_u64)(i)?; - Ok((i, v)) + let (i, v) = take_while(is_digit)(i)?; + match str::FromStr::from_str(v) { + Ok(v) => Ok((i, v)), + _ => Err(Error(ParserError(i))), + } } pub fn take_usize(i: &str) -> IResult<&str, usize> { - let (i, v) = map(take_while(is_digit), to_usize)(i)?; - Ok((i, v)) + let (i, v) = take_while(is_digit)(i)?; + match str::FromStr::from_str(v) { + Ok(v) => Ok((i, v)), + _ => Err(Error(ParserError(i))), + } } pub fn take_digits(i: &str, n: usize) -> IResult<&str, u32> { - let (i, v) = map(take_while_m_n(n, n, is_digit), to_u32)(i)?; - Ok((i, v)) + let (i, v) = take_while_m_n(n, n, is_digit)(i)?; + match str::FromStr::from_str(v) { + Ok(v) => Ok((i, v)), + _ => Err(Error(ParserError(i))), + } } pub fn take_digits_range(i: &str, n: usize, range: impl RangeBounds) -> IResult<&str, u32> { - let (i, v) = map(take_while_m_n(n, n, is_digit), to_u32)(i)?; - if range.contains(&v) { - Ok((i, v)) - } else { - Err(Error(ParserError(i))) + let (i, v) = take_while_m_n(n, n, is_digit)(i)?; + match str::FromStr::from_str(v) { + Ok(v) if range.contains(&v) => Ok((i, v)), + _ => Err(Error(ParserError(i))), } } diff --git a/lib/src/sql/duration.rs b/lib/src/sql/duration.rs index 03cacce9..d5f61d97 100644 --- a/lib/src/sql/duration.rs +++ b/lib/src/sql/duration.rs @@ -1,14 +1,13 @@ +use crate::sql::common::take_u64; use crate::sql::datetime::Datetime; use crate::sql::error::IResult; use chrono::DurationRound; use nom::branch::alt; -use nom::bytes::complete::is_a; use nom::bytes::complete::tag; use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize}; use std::fmt; use std::ops; -use std::str::FromStr; use std::time; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)] @@ -131,8 +130,7 @@ pub fn duration_raw(i: &str) -> IResult<&str, Duration> { } fn part(i: &str) -> IResult<&str, u64> { - let (i, v) = is_a("1234567890")(i)?; - let v = u64::from_str(v).unwrap(); + let (i, v) = take_u64(i)?; Ok((i, v)) }