Skip to content

Commit

Permalink
Merge pull request #110 from kornelski/bounds
Browse files Browse the repository at this point in the history
Avoid redundant bounds checks
  • Loading branch information
hsivonen authored Oct 31, 2024
2 parents eec0bde + 779da71 commit 8cec3a5
Showing 1 changed file with 73 additions and 53 deletions.
126 changes: 73 additions & 53 deletions src/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,9 +1266,8 @@ impl<'a> Utf16Source<'a> {
dest: &'b mut ByteDestination<'a>,
) -> CopyAsciiResult<(EncoderResult, usize, usize), (NonAscii, ByteTwoHandle<'b, 'a>)> {
let non_ascii_ret = {
let dst_len = dest.slice.len();
let src_remaining = &self.slice[self.pos..];
let dst_remaining = &mut dest.slice[dest.pos..];
let dst_remaining = dest.remaining();
let (pending, length) = if dst_remaining.len() < src_remaining.len() {
(EncoderResult::OutputFull, dst_remaining.len())
} else {
Expand All @@ -1279,13 +1278,13 @@ impl<'a> Utf16Source<'a> {
} {
None => {
self.pos += length;
dest.pos += length;
return CopyAsciiResult::Stop((pending, self.pos, dest.pos));
dest.advance(length);
return CopyAsciiResult::Stop((pending, self.pos, dest.written()));
}
Some((non_ascii, consumed)) => {
self.pos += consumed;
dest.pos += consumed;
if dest.pos + 1 < dst_len {
dest.advance(consumed);
if dest.remaining().len() >= 1 {
self.pos += 1; // commit to reading `non_ascii`
let unit = non_ascii;
let unit_minus_surrogate_start = unit.wrapping_sub(0xD800);
Expand Down Expand Up @@ -1322,7 +1321,7 @@ impl<'a> Utf16Source<'a> {
return CopyAsciiResult::Stop((
EncoderResult::OutputFull,
self.pos,
dest.pos,
dest.written(),
));
}
}
Expand All @@ -1336,9 +1335,8 @@ impl<'a> Utf16Source<'a> {
dest: &'b mut ByteDestination<'a>,
) -> CopyAsciiResult<(EncoderResult, usize, usize), (NonAscii, ByteFourHandle<'b, 'a>)> {
let non_ascii_ret = {
let dst_len = dest.slice.len();
let src_remaining = &self.slice[self.pos..];
let dst_remaining = &mut dest.slice[dest.pos..];
let dst_remaining = dest.remaining();
let (pending, length) = if dst_remaining.len() < src_remaining.len() {
(EncoderResult::OutputFull, dst_remaining.len())
} else {
Expand All @@ -1349,13 +1347,13 @@ impl<'a> Utf16Source<'a> {
} {
None => {
self.pos += length;
dest.pos += length;
return CopyAsciiResult::Stop((pending, self.pos, dest.pos));
dest.advance(length);
return CopyAsciiResult::Stop((pending, self.pos, dest.written()));
}
Some((non_ascii, consumed)) => {
self.pos += consumed;
dest.pos += consumed;
if dest.pos + 3 < dst_len {
dest.advance(consumed);
if dest.remaining().len() >= 3 {
self.pos += 1; // commit to reading `non_ascii`
let unit = non_ascii;
let unit_minus_surrogate_start = unit.wrapping_sub(0xD800);
Expand Down Expand Up @@ -1392,7 +1390,7 @@ impl<'a> Utf16Source<'a> {
return CopyAsciiResult::Stop((
EncoderResult::OutputFull,
self.pos,
dest.pos,
dest.written(),
));
}
}
Expand Down Expand Up @@ -1563,7 +1561,7 @@ impl<'a> Utf8Source<'a> {
) -> CopyAsciiResult<(EncoderResult, usize, usize), (NonAscii, ByteOneHandle<'b, 'a>)> {
let non_ascii_ret = {
let src_remaining = &self.slice[self.pos..];
let dst_remaining = &mut dest.slice[dest.pos..];
let dst_remaining = dest.remaining();
let (pending, length) = if dst_remaining.len() < src_remaining.len() {
(EncoderResult::OutputFull, dst_remaining.len())
} else {
Expand All @@ -1574,12 +1572,12 @@ impl<'a> Utf8Source<'a> {
} {
None => {
self.pos += length;
dest.pos += length;
return CopyAsciiResult::Stop((pending, self.pos, dest.pos));
dest.advance(length);
return CopyAsciiResult::Stop((pending, self.pos, dest.written()));
}
Some((non_ascii, consumed)) => {
self.pos += consumed;
dest.pos += consumed;
dest.advance(consumed);
// We don't need to check space in destination, because
// `ascii_to_ascii()` already did.
if non_ascii < 0xE0 {
Expand Down Expand Up @@ -1612,9 +1610,8 @@ impl<'a> Utf8Source<'a> {
dest: &'b mut ByteDestination<'a>,
) -> CopyAsciiResult<(EncoderResult, usize, usize), (NonAscii, ByteTwoHandle<'b, 'a>)> {
let non_ascii_ret = {
let dst_len = dest.slice.len();
let src_remaining = &self.slice[self.pos..];
let dst_remaining = &mut dest.slice[dest.pos..];
let dst_remaining = dest.remaining();
let (pending, length) = if dst_remaining.len() < src_remaining.len() {
(EncoderResult::OutputFull, dst_remaining.len())
} else {
Expand All @@ -1625,13 +1622,13 @@ impl<'a> Utf8Source<'a> {
} {
None => {
self.pos += length;
dest.pos += length;
return CopyAsciiResult::Stop((pending, self.pos, dest.pos));
dest.advance(length);
return CopyAsciiResult::Stop((pending, self.pos, dest.written()));
}
Some((non_ascii, consumed)) => {
self.pos += consumed;
dest.pos += consumed;
if dest.pos + 1 < dst_len {
dest.advance(consumed);
if dest.remaining().len() >= 1 {
if non_ascii < 0xE0 {
let point = ((u16::from(non_ascii) & 0x1F) << 6)
| (u16::from(self.slice[self.pos + 1]) & 0x3F);
Expand All @@ -1655,7 +1652,7 @@ impl<'a> Utf8Source<'a> {
return CopyAsciiResult::Stop((
EncoderResult::OutputFull,
self.pos,
dest.pos,
dest.written(),
));
}
}
Expand All @@ -1669,9 +1666,8 @@ impl<'a> Utf8Source<'a> {
dest: &'b mut ByteDestination<'a>,
) -> CopyAsciiResult<(EncoderResult, usize, usize), (NonAscii, ByteFourHandle<'b, 'a>)> {
let non_ascii_ret = {
let dst_len = dest.slice.len();
let src_remaining = &self.slice[self.pos..];
let dst_remaining = &mut dest.slice[dest.pos..];
let dst_remaining = dest.remaining();
let (pending, length) = if dst_remaining.len() < src_remaining.len() {
(EncoderResult::OutputFull, dst_remaining.len())
} else {
Expand All @@ -1682,13 +1678,13 @@ impl<'a> Utf8Source<'a> {
} {
None => {
self.pos += length;
dest.pos += length;
return CopyAsciiResult::Stop((pending, self.pos, dest.pos));
dest.advance(length);
return CopyAsciiResult::Stop((pending, self.pos, dest.written()));
}
Some((non_ascii, consumed)) => {
self.pos += consumed;
dest.pos += consumed;
if dest.pos + 3 < dst_len {
dest.advance(consumed);
if dest.remaining().len() >= 3 {
if non_ascii < 0xE0 {
let point = ((u16::from(non_ascii) & 0x1F) << 6)
| (u16::from(self.slice[self.pos + 1]) & 0x3F);
Expand All @@ -1712,7 +1708,7 @@ impl<'a> Utf8Source<'a> {
return CopyAsciiResult::Stop((
EncoderResult::OutputFull,
self.pos,
dest.pos,
dest.written(),
));
}
}
Expand Down Expand Up @@ -1928,74 +1924,98 @@ where

pub struct ByteDestination<'a> {
slice: &'a mut [u8],
pos: usize,
/// Pointer to the original start of the slice. It's never dereferenced.
start: *const u8,
}

impl<'a> ByteDestination<'a> {
#[inline(always)]
pub fn new(dst: &mut [u8]) -> ByteDestination {
ByteDestination { slice: dst, pos: 0 }
ByteDestination {
start: dst.as_ptr(),
slice: dst,
}
}
#[inline(always)]
pub fn remaining(&mut self) -> &mut [u8] {
&mut self.slice
}
#[inline(always)]
pub fn check_space_one<'b>(&'b mut self) -> Space<ByteOneHandle<'b, 'a>> {
if self.pos < self.slice.len() {
if self.slice.len() >= 1 {
Space::Available(ByteOneHandle::new(self))
} else {
Space::Full(self.written())
}
}
#[inline(always)]
pub fn check_space_two<'b>(&'b mut self) -> Space<ByteTwoHandle<'b, 'a>> {
if self.pos + 1 < self.slice.len() {
if self.slice.len() >= 2 {
Space::Available(ByteTwoHandle::new(self))
} else {
Space::Full(self.written())
}
}
#[inline(always)]
pub fn check_space_three<'b>(&'b mut self) -> Space<ByteThreeHandle<'b, 'a>> {
if self.pos + 2 < self.slice.len() {
if self.slice.len() >= 3 {
Space::Available(ByteThreeHandle::new(self))
} else {
Space::Full(self.written())
}
}
#[inline(always)]
pub fn check_space_four<'b>(&'b mut self) -> Space<ByteFourHandle<'b, 'a>> {
if self.pos + 3 < self.slice.len() {
if self.slice.len() >= 4 {
Space::Available(ByteFourHandle::new(self))
} else {
Space::Full(self.written())
}
}
#[inline(always)]
pub fn written(&self) -> usize {
self.pos
// ptr::byte_offset_from(), but safe
self.slice.as_ptr() as usize - self.start as usize
}
#[inline(always)]
fn write_one(&mut self, first: u8) {
self.slice[self.pos] = first;
self.pos += 1;
// take() is necessary to use the slice's full lifetime, rather than a shorter reborrow via self
let (dst, rest) = core::mem::take(&mut self.slice).split_first_mut().unwrap();
self.slice = rest;

*dst = first;
}
#[inline(always)]
fn write_two(&mut self, first: u8, second: u8) {
self.slice[self.pos] = first;
self.slice[self.pos + 1] = second;
self.pos += 2;
let (dst, rest) = core::mem::take(&mut self.slice).split_at_mut(2);
self.slice = rest;

dst[0] = first;
dst[1] = second;
}
#[inline(always)]
fn write_three(&mut self, first: u8, second: u8, third: u8) {
self.slice[self.pos] = first;
self.slice[self.pos + 1] = second;
self.slice[self.pos + 2] = third;
self.pos += 3;
let (dst, rest) = core::mem::take(&mut self.slice).split_at_mut(3);
self.slice = rest;

dst[0] = first;
dst[1] = second;
dst[2] = third;
}
#[inline(always)]
fn write_four(&mut self, first: u8, second: u8, third: u8, fourth: u8) {
self.slice[self.pos] = first;
self.slice[self.pos + 1] = second;
self.slice[self.pos + 2] = third;
self.slice[self.pos + 3] = fourth;
self.pos += 4;
// consecutive assignments to self.slice[pos+n] would have four bounds checks
let (dst, rest) = core::mem::take(&mut self.slice).split_at_mut(4);
self.slice = rest;

dst[0] = first;
dst[1] = second;
dst[2] = third;
dst[3] = fourth;
}
/// Assume this many bytes have been written
#[inline(always)]
pub fn advance(&mut self, length: usize) {
self.slice = &mut core::mem::take(&mut self.slice)[length..];
}
}

0 comments on commit 8cec3a5

Please sign in to comment.