diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c52785df..d643f2c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,17 +14,14 @@ jobs: steps: - name: Checkout uses: actions/checkout@v3 - - name: Setup rust + - name: Install Rust wasm toolchain run: rustup target add wasm32-unknown-unknown if: ${{ matrix.os == 'ubuntu-latest' }} - name: Cache uses: actions/cache@v3 with: path: | - ~/.cargo/bin/ - ~/.cargo/registry/index/ - ~/.cargo/registry/cache/ - ~/.cargo/git/db/ + ~/.cargo/ target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - name: Install wasm-bindgen diff --git a/src/gemm.rs b/src/gemm.rs index 3a6d1a85..43de552c 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -1460,6 +1460,12 @@ mod tests { n: 128, k: 512, }, + // Vector-matrix. This is common in transformer decoders for example. + Case { + m: 1, + n: 4096, + k: 512, + }, ]; println!("Testing kernel {}", GemmExecutor::new().kernel_name()); @@ -1473,6 +1479,10 @@ mod tests { let target_ops: u64 = 512 * 512 * 512 * 1000; let iters = target_ops / (m * n * k) as u64; + // Cap the number of iterations, for cases where the equal-efficiency + // assumption is untrue. + let iters = iters.min(1000); + let mut rng = XorShiftRng::new(1234); let mut result = Tensor::zeros(&[m, n]); let a = Tensor::rand(&[m, k], &mut rng);