Ensure numbers do not overflow numeric bounds

This commit is contained in:
Tobie Morgan Hitchcock 2022-03-16 15:40:26 +00:00
parent cdef111786
commit 232b35a304
2 changed files with 26 additions and 33 deletions

View file

@ -5,7 +5,6 @@ use nom::bytes::complete::tag;
use nom::bytes::complete::take_while;
use nom::bytes::complete::take_while_m_n;
use nom::character::is_alphanumeric;
use nom::combinator::map;
use nom::multi::many1;
use nom::Err::Error;
use std::ops::RangeBounds;
@ -30,21 +29,6 @@ pub fn is_digit(chr: char) -> bool {
(0x30..=0x39).contains(&(chr as u8))
}
#[inline]
pub fn to_u32(s: &str) -> u32 {
str::FromStr::from_str(s).unwrap()
}
#[inline]
pub fn to_u64(s: &str) -> u64 {
str::FromStr::from_str(s).unwrap()
}
#[inline]
pub fn to_usize(s: &str) -> usize {
str::FromStr::from_str(s).unwrap()
}
#[inline]
pub fn val_char(chr: char) -> bool {
is_alphanumeric(chr as u8) || chr == '_'
@ -61,30 +45,41 @@ pub fn escape(s: &str, f: &dyn Fn(char) -> bool, c: &str) -> String {
}
pub fn take_u32(i: &str) -> IResult<&str, u32> {
let (i, v) = map(take_while(is_digit), to_u32)(i)?;
Ok((i, v))
let (i, v) = take_while(is_digit)(i)?;
match str::FromStr::from_str(v) {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
}
}
pub fn take_u64(i: &str) -> IResult<&str, u64> {
let (i, v) = map(take_while(is_digit), to_u64)(i)?;
Ok((i, v))
let (i, v) = take_while(is_digit)(i)?;
match str::FromStr::from_str(v) {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
}
}
pub fn take_usize(i: &str) -> IResult<&str, usize> {
let (i, v) = map(take_while(is_digit), to_usize)(i)?;
Ok((i, v))
let (i, v) = take_while(is_digit)(i)?;
match str::FromStr::from_str(v) {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
}
}
pub fn take_digits(i: &str, n: usize) -> IResult<&str, u32> {
let (i, v) = map(take_while_m_n(n, n, is_digit), to_u32)(i)?;
Ok((i, v))
let (i, v) = take_while_m_n(n, n, is_digit)(i)?;
match str::FromStr::from_str(v) {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
}
}
pub fn take_digits_range(i: &str, n: usize, range: impl RangeBounds<u32>) -> IResult<&str, u32> {
let (i, v) = map(take_while_m_n(n, n, is_digit), to_u32)(i)?;
if range.contains(&v) {
Ok((i, v))
} else {
Err(Error(ParserError(i)))
let (i, v) = take_while_m_n(n, n, is_digit)(i)?;
match str::FromStr::from_str(v) {
Ok(v) if range.contains(&v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
}
}

View file

@ -1,14 +1,13 @@
use crate::sql::common::take_u64;
use crate::sql::datetime::Datetime;
use crate::sql::error::IResult;
use chrono::DurationRound;
use nom::branch::alt;
use nom::bytes::complete::is_a;
use nom::bytes::complete::tag;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops;
use std::str::FromStr;
use std::time;
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)]
@ -131,8 +130,7 @@ pub fn duration_raw(i: &str) -> IResult<&str, Duration> {
}
fn part(i: &str) -> IResult<&str, u64> {
let (i, v) = is_a("1234567890")(i)?;
let v = u64::from_str(v).unwrap();
let (i, v) = take_u64(i)?;
Ok((i, v))
}