Correctness::compute branchless optimization #10

Open
FrancescoRuta opened this issue Mar 7, 2022 · 7 comments

@FrancescoRuta

FrancescoRuta commented Mar 7, 2022

I was thinking about a possible branchless implementation for the function Correctness::compute and I wrote this:

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Correctness {
	Wrong = 0,
	Misplaced = 1,
	Correct = 2,
}


impl Correctness {
	
	
	// Performs equality check
	// Example:
	// Input:
	//     answer = 01101 00101 01000 00010 01111
	//     guess  = 01101 01111 01011 00010 01010
	//     result = 11111 00000 00000 11111 00000
	fn check_for_eq(answer: u32, guess: u32) -> u32 {
		let mut result = !(answer ^ guess);
		result &= (result >> 1) & (result >> 2) & (result >> 3) & (result >> 4) & 0b00001_00001_00001_00001_00001u32;
		result |= (result << 1) | (result << 2) | (result << 3) | (result << 4);
		result
	}
	
	pub fn compute(answer: &str, guess: &str) -> [Correctness; 5] {
		const TO_NUMBER: u8 = 'a' as u8 - 1;
		
		
		let answer = answer.as_bytes();
		let guess = guess.as_bytes();
		
		// Assuming answer and guess are lowercase and use only alphabetic characters,
		// we can represent them using 5 bits per letter (in this case 'a' = 1)
		let mut answer = (answer[0] - TO_NUMBER) as u32 | 
			(((answer[1] - TO_NUMBER) as u32) << 5) | 
			(((answer[2] - TO_NUMBER) as u32) << 10) | 
			(((answer[3] - TO_NUMBER) as u32) << 15) | 
			(((answer[4] - TO_NUMBER) as u32) << 20);
		
		let mut guess = (guess[0] - TO_NUMBER) as u32 | 
			(((guess[1] - TO_NUMBER) as u32) << 5) | 
			(((guess[2] - TO_NUMBER) as u32) << 10) | 
			(((guess[3] - TO_NUMBER) as u32) << 15) | 
			(((guess[4] - TO_NUMBER) as u32) << 20);
		
		let green = Self::check_for_eq(answer, guess);
		
		// Removing green letters:
		// setting used letters from guess to 11111
		guess |= green;
		// setting used letters from answer to 00000
		answer &= !green;
		
		// To detect yellow letters we can shift guess 4 times and check the equality
		let mut yellow = 0u32;
		for _ in 0..4 {
			guess = (guess >> 5) | (guess << 20);
			let n = Self::check_for_eq(answer, guess);
			// setting used letters from guess to 11111
			guess |= n;
			// setting used letters from answer to 00000
			answer &= !n;
			// yellow must be moved with guess to avoid overwrites
			yellow = n | (yellow >> 5) | (yellow << 20);
		}
		// moving yellow to its original position
		yellow = (yellow >> 5) | (yellow << 20);
		
		let result = (green & 0b00010_00010_00010_00010_00010u32) | yellow & 0b00001_00001_00001_00001_00001;
		
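		// Safety: for 0 <= n < 5, ((result >> (n * 5)) & 0b11111) as u8
		//             - can only be 0, 1 or 2: green letters are removed before the
		//               yellow pass, so a position is never both green (2) and yellow (1)
		//             - the enum is #[repr(u8)] with discriminants 0, 1 and 2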
		[
			unsafe { std::mem::transmute((result & 0b11111) as u8) },
			unsafe { std::mem::transmute(((result >> 5) & 0b11111) as u8) },
			unsafe { std::mem::transmute(((result >> 10) & 0b11111) as u8) },
			unsafe { std::mem::transmute(((result >> 15) & 0b11111) as u8) },
			unsafe { std::mem::transmute(((result >> 20) & 0b11111) as u8) },
		]
	}
	
}
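The packing and the lane-wise equality trick used above can be tried on their own. Below is a minimal standalone sketch, assuming lowercase ASCII five-letter words. `check_for_eq` has the same logic as the method above, lifted out of the `impl`; `pack` and `rotate_lanes` are hypothetical helpers that mirror the inline packing and the `(x >> 5) | (x << 20)` rotation. Unlike the loop above, `rotate_lanes` masks its result to 25 bits so no stray bits collect above bit 24.

```rust
/// Lanes are 5 bits wide; bits 0..=24 hold the five letters.
const LOW_BITS: u32 = 0b00001_00001_00001_00001_00001;
const MASK25: u32 = (1 << 25) - 1;

/// Hypothetical helper: packs a lowercase five-letter word, 5 bits per letter,
/// with 'a' = 1, so the values 0 and 31 never encode a real letter.
/// Assumes exactly five bytes in b'a'..=b'z'.
fn pack(word: &str) -> u32 {
    word.bytes()
        .take(5)
        .enumerate()
        .fold(0, |acc, (i, b)| acc | (u32::from(b - b'a' + 1) << (5 * i)))
}

/// Returns 0b11111 in every lane where `a` and `b` are equal, 0b00000 elsewhere.
fn check_for_eq(a: u32, b: u32) -> u32 {
    // bit is 1 wherever a and b agree
    let mut r = !(a ^ b);
    // keep a lane's low bit only if all five bits of that lane agree
    r &= (r >> 1) & (r >> 2) & (r >> 3) & (r >> 4) & LOW_BITS;
    // smear the surviving low bit across the whole lane
    r | (r << 1) | (r << 2) | (r << 3) | (r << 4)
}

/// Hypothetical helper: rotates the five lanes by one position.
/// Lane i receives lane i + 1, and lane 4 receives lane 0.
fn rotate_lanes(x: u32) -> u32 {
    ((x >> 5) | (x << 20)) & MASK25
}
```

For example, `check_for_eq(pack("crane"), pack("crate"))` keeps every lane except lane 3 (the `n`/`t` position), and `rotate_lanes(pack("abcde"))` equals `pack("bcdea")`.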

I benchmarked it in release mode with this:

use std::arch::asm;
use std::time::Instant;

fn main() {
	let b0 = benchmark(|| Correctness::compute("abcde", "fghij"));
	let b1 = benchmark(|| jonhoo::Correctness::compute("abcde", "fghij"));
	println!("{} ns/op", b0);
	println!("{} ns/op", b1);
	println!("b1/b0 = {}", b1 / b0);
}

// Returns the average time per call in nanoseconds
fn benchmark<T, F: Fn() -> T>(f: F) -> f64 {
	const ITERATIONS: usize = 1_000_000_000;
	let now = Instant::now();
	for _ in 0..ITERATIONS {
		execute_it_pls(f());
	}
	let el = now.elapsed().as_millis();
	// milliseconds / (ITERATIONS / 10^6) = nanoseconds per iteration
	el as f64 / (ITERATIONS as f64 / 1_000_000.0)
}

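// Hands a pointer to the value to an empty inline-asm block, so the optimizer
// must assume the value is read and cannot drop it as dead code
fn execute_it_pls<T>(dummy: T) -> T {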
	let ptr = (&dummy) as *const _;
	unsafe { asm!("/* {0} */", in(reg) ptr) }
	dummy
}

And it seems to be about 33 times faster on my PC:

0.207 ns/op
6.902 ns/op
b1/b0 = 33.34299516908213
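A caveat on these figures: 0.207 ns/op is below one clock cycle on a typical 3 to 5 GHz desktop CPU. Because both arguments are string literals and `execute_it_pls` only hides the result, the optimizer may compute the pure call once and reuse it, in which case the ratio would not reflect the per-call cost. A sketch of a harness that also hides the inputs, assuming Rust 1.66 or later for `std::hint::black_box` (not yet stable when this thread was written); `ns_per_op` and `stand_in` are hypothetical names, and `stand_in` is only a placeholder for the function being measured.

```rust
use std::hint::black_box;
use std::time::Instant;

/// Hypothetical placeholder for the function under test (counts positions
/// where the two words agree); swap in `Correctness::compute` instead.
fn stand_in(answer: &str, guess: &str) -> usize {
    answer.bytes().zip(guess.bytes()).filter(|(a, g)| a == g).count()
}

/// Average nanoseconds per call of `f(answer, guess)`. The inputs and the
/// output both pass through `black_box`, so the compiler cannot treat the
/// arguments as constants and hoist the call out of the loop.
fn ns_per_op<T>(iterations: u32, answer: &str, guess: &str, f: impl Fn(&str, &str) -> T) -> f64 {
    let start = Instant::now();
    for _ in 0..iterations {
        black_box(f(black_box(answer), black_box(guess)));
    }
    start.elapsed().as_nanos() as f64 / iterations as f64
}
```

Usage would look like `ns_per_op(1_000_000_000, "abcde", "fghij", stand_in)`. Cycling through a list of different word pairs would be a further safeguard against the measurement depending on one particular input.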
@jonhoo
Owner

jonhoo commented Mar 8, 2022

Nice! You should try cloning the code and modifying Correctness::compute directly so you can run the test suite. I suspect this doesn't handle the case where a letter appears once in the answer and twice in the guess, both times in the wrong position. Specifically, I think it would end up marking both occurrences in the guess as yellow, even though only the first should be. But I'd be happy to be proven wrong!
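To make that scenario concrete, here is a conventional two-pass computation. It is a sketch, not the repository's actual code: it first marks greens, then walks the guess left to right and marks a letter Misplaced only while unclaimed copies of it remain in the answer. The enum and `compute` are standalone stand-ins assuming five-letter lowercase ASCII words.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Correctness {
    Wrong,
    Misplaced,
    Correct,
}

/// Two-pass reference: greens first, then Misplaced for the leftmost
/// guess occurrences while unused copies remain in the answer.
fn compute(answer: &str, guess: &str) -> [Correctness; 5] {
    let (a, g) = (answer.as_bytes(), guess.as_bytes());
    let mut mask = [Correctness::Wrong; 5];
    // unused[l] = copies of letter l in the answer not yet claimed
    let mut unused = [0u8; 26];
    for i in 0..5 {
        if a[i] == g[i] {
            mask[i] = Correctness::Correct;
        } else {
            unused[usize::from(a[i] - b'a')] += 1;
        }
    }
    for i in 0..5 {
        let left = &mut unused[usize::from(g[i] - b'a')];
        if mask[i] == Correctness::Wrong && *left > 0 {
            *left -= 1; // this copy is now used up
            mask[i] = Correctness::Misplaced;
        }
    }
    mask
}
```

With answer "abbbb" and guess "caacc" only the first 'a' (position 1) is marked Misplaced; the second 'a' finds no unused copy left and stays Wrong.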

@FrancescoRuta
Author

I wrote this test, and it passes:

#[test]
fn my_test() {
	assert_eq!(Correctness::compute("abbbb", "caacc"), mask![W M W W W]);
}

I cloned the repo, replaced the Correctness::compute function, and ran cargo test; all tests passed.
I think the case you described is handled by the following lines:

// setting used letters from guess to 11111
guess |= n;
// setting used letters from answer to 00000
answer &= !n;
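The effect of those two lines can be checked in isolation. A minimal sketch, assuming lowercase five-letter words and using hypothetical helper names (`pack`, `rotate_lanes`, `count_misplaced`), that repeats the green step and the four rotations of the branchless `compute` and counts how many guess letters get flagged, with and without the "mark as used" step:

```rust
const LOW_BITS: u32 = 0b00001_00001_00001_00001_00001;
const MASK25: u32 = (1 << 25) - 1;

/// Hypothetical helper: 5 bits per letter, 'a' = 1 (assumes lowercase ASCII).
fn pack(word: &str) -> u32 {
    word.bytes()
        .take(5)
        .enumerate()
        .fold(0, |acc, (i, b)| acc | (u32::from(b - b'a' + 1) << (5 * i)))
}

/// 0b11111 in every lane where `a` and `b` are equal, as in `check_for_eq` above.
fn check_for_eq(a: u32, b: u32) -> u32 {
    let mut r = !(a ^ b);
    r &= (r >> 1) & (r >> 2) & (r >> 3) & (r >> 4) & LOW_BITS;
    r | (r << 1) | (r << 2) | (r << 3) | (r << 4)
}

/// Hypothetical helper: lane i receives lane i + 1; lane 4 receives lane 0.
fn rotate_lanes(x: u32) -> u32 {
    ((x >> 5) | (x << 20)) & MASK25
}

/// Hypothetical helper: returns how many guess letters end up flagged
/// Misplaced. With `consume == false` the two "mark as used" lines are skipped.
fn count_misplaced(answer: &str, guess: &str, consume: bool) -> u32 {
    let (mut a, mut g) = (pack(answer), pack(guess));
    let green = check_for_eq(a, g);
    g |= green; // green guess letters become 11111
    a &= !green; // green answer letters become 00000
    let mut flagged = 0;
    for _ in 0..4 {
        g = rotate_lanes(g);
        let n = check_for_eq(a, g);
        // one low bit per matching lane
        flagged += (n & LOW_BITS).count_ones();
        if consume {
            g |= n; // used guess letter -> 11111, can no longer match
            a &= !n; // used answer letter -> 00000, can no longer match
        }
    }
    flagged
}
```

For answer "abbbb" and guess "caacc", `count_misplaced` returns 1 with the consume step and 2 without it: skipping those lines would flag both 'a's, which is exactly the case raised above. Since letters are encoded as 1..=26, the sentinels 00000 and 11111 can never match a real letter.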

@jonhoo
Owner

jonhoo commented Mar 9, 2022

Neat! Care to submit a PR? That should make it a little easier to review this change and potentially merge it!

@FrancescoRuta
Author

I was going to do so, but I found out that my implementation diverges from yours in some cases. For example:
assert_eq!(Correctness::compute("duchy", "which"), mask![W W W M M]);
I wasn't worried about it because that mask should still be correct, but it turns out that this doesn't play well with Guess::matches:
Guess::matches(&Self { word: "which", mask: mask![W W W M M] }, "duchy") returns false
Guess::matches(&Self { word: "which", mask: mask![W M W M W] }, "duchy") returns true
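Why those two calls disagree can be reproduced with a small sketch. It assumes, as the thread explains further down, that compatibility is checked by recomputing the mask and comparing it for equality; `matches` is a hypothetical free function standing in for `Guess::matches`, and `compute` is a leftmost-convention stand-in, not the repository's code.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Correctness {
    Wrong,
    Misplaced,
    Correct,
}

/// Two-pass reference (leftmost occurrence marked Misplaced).
/// Assumes five-letter lowercase ASCII words.
fn compute(answer: &str, guess: &str) -> [Correctness; 5] {
    let (a, g) = (answer.as_bytes(), guess.as_bytes());
    let mut mask = [Correctness::Wrong; 5];
    let mut unused = [0u8; 26];
    for i in 0..5 {
        if a[i] == g[i] {
            mask[i] = Correctness::Correct;
        } else {
            unused[usize::from(a[i] - b'a')] += 1;
        }
    }
    for i in 0..5 {
        let left = &mut unused[usize::from(g[i] - b'a')];
        if mask[i] == Correctness::Wrong && *left > 0 {
            *left -= 1;
            mask[i] = Correctness::Misplaced;
        }
    }
    mask
}

/// Hypothetical `matches`: `word` is compatible with (`guess`, `mask`) only if
/// recomputing the mask for `word` reproduces `mask` exactly.
fn matches(guess: &str, mask: [Correctness; 5], word: &str) -> bool {
    compute(word, guess) == mask
}
```

Under this scheme a mask that is informationally equivalent but uses a different convention (W W W M M instead of W M W M W) is rejected, because the comparison is exact.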

@jonhoo
Copy link
Owner

jonhoo commented Mar 11, 2022

Hmm, that suggests there's a bug in your implementation, since the computation of a correctness mask should be deterministic, so any optimization should still yield exactly the same result.

@FrancescoRuta
Author

Yeah, I tried to solve this problem, but every change I made led to a different case diverging from your implementation.
I take this opportunity to thank you for your wonderful streams!

@niewerth

The underlying problem is as follows: for some combinations of word and guess, there is more than one correct mask, namely when a misplaced character occurs at two locations in the guess but only at one location in the word. Either of the two locations can be marked as misplaced while the other one is marked as wrong; you get exactly the same information.

Now, the implemented algorithm requires masks to be normalized. This is a bit more than mask computation just being deterministic. Deterministic only says that for every combination of word and guess, you always produce the same mask. This holds true for both algorithms, the one in the repo and the proposed one. But the proposed algorithm has a problem: if two words w1 and w2 both contain the character that occurs at two locations in the guess, then it is possible that for w1 the mask computation marks the first location as misplaced and for w2 the second. This is a big problem for the algorithm, which relies on both masks being identical in order to determine that w1 and w2 are compatible with the guess and the mask. The requirement comes from the optimization that reuses the mask computation to check whether a word is compatible with a guess and mask.

The mask computation in the repo satisfies this need, as it always marks the leftmost location in the guess as misplaced whenever there is a choice. But the fancy rotation and bit-operation stuff in the branchless version has the problem that the location marked as misplaced depends on the position of the misplaced character in the word.
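One possible way to reconcile the two approaches, offered only as a sketch: post-process the branchless mask into the leftmost convention. For a given word and guess, every correct mask carries the same number of Misplaced marks per letter; only their positions differ. So the marks can be counted per letter and reassigned to the earliest non-Correct occurrences of that letter in the guess. `normalize` is a hypothetical helper (not from the repo), and its data-dependent branches may give back part of the speedup.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Correctness {
    Wrong,
    Misplaced,
    Correct,
}

/// Rewrites a correct mask into the leftmost-Misplaced convention: per letter,
/// the number of Misplaced marks stays the same, but they move to the earliest
/// non-Correct occurrences of that letter. Assumes a lowercase ASCII guess.
fn normalize(guess: &str, mask: [Correctness; 5]) -> [Correctness; 5] {
    let g = guess.as_bytes();
    // misplaced[l] = how many occurrences of letter l are marked Misplaced
    let mut misplaced = [0u8; 26];
    for i in 0..5 {
        if mask[i] == Correctness::Misplaced {
            misplaced[usize::from(g[i] - b'a')] += 1;
        }
    }
    let mut out = mask;
    for i in 0..5 {
        // Correct positions are fixed; every other position is re-decided
        if out[i] != Correctness::Correct {
            let left = &mut misplaced[usize::from(g[i] - b'a')];
            out[i] = if *left > 0 {
                *left -= 1;
                Correctness::Misplaced
            } else {
                Correctness::Wrong
            };
        }
    }
    out
}
```

Applied to the mask reported above, `normalize("which", [W, W, W, M, M])` yields `[W, M, W, M, W]`, and masks already in the leftmost convention are returned unchanged.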


3 participants