Skip to content

Commit

Permalink
Pass wasm_api::Tensor into TensorList::push by-reference
Browse files Browse the repository at this point in the history
Work around a bug [1] in wasm-bindgen when passing structs by-value into
methods, by passing `wasm_api::Tensor`s by-reference. In order to do this
without copying the underlying tensor, make `wasm_api::Tensor` use an Rc to
manage its output reference.

Alternative solutions would be to fix the issue in wasm-bindgen upstream or
to patch generated JS code snippets like this:

```
var ptr0 = someArg.ptr;
somePtr.ptr = 0;
```

To be like this instead:

```
var ptr0 = someArg.__destroy_into_raw();
```

[1] rustwasm/wasm-bindgen#2677
  • Loading branch information
robertknight committed Dec 9, 2022
1 parent 503fec6 commit f992cca
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/wasm_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use wasm_bindgen::prelude::*;

use std::collections::VecDeque;
use std::iter::zip;
use std::rc::Rc;

use crate::model;
use crate::ops::{Input, Output};
Expand Down Expand Up @@ -38,15 +39,15 @@ impl Model {
) -> Result<TensorList, String> {
let inputs: Vec<(usize, Input)> = zip(
input_ids.iter().copied(),
input.tensors.iter().map(|tensor| (&tensor.data).into()),
input.tensors.iter().map(|tensor| (&*tensor.data).into()),
)
.collect();
let result = self.model.run(&inputs[..], output_ids, None);
match result {
Ok(outputs) => {
let mut list = TensorList::new();
for output in outputs.into_iter() {
list.push(Tensor::from_output(output));
list.push(&Tensor::from_output(output));
}
Ok(list)
}
Expand All @@ -57,49 +58,54 @@ impl Model {

/// A wrapper around a multi-dimensional array model input or output.
#[wasm_bindgen]
#[derive(Clone)]
pub struct Tensor {
data: Output,
data: Rc<Output>,
}

#[wasm_bindgen]
impl Tensor {
#[wasm_bindgen(js_name = floatTensor)]
pub fn float_tensor(shape: &[usize], data: &[f32]) -> Tensor {
let data: Output = tensor::Tensor::from_data(shape.into(), data.into()).into();
Tensor { data }
Tensor {
data: Rc::new(data),
}
}

#[wasm_bindgen(js_name = intTensor)]
pub fn int_tensor(shape: &[usize], data: &[i32]) -> Tensor {
let data: Output = tensor::Tensor::from_data(shape.into(), data.into()).into();
Tensor { data }
Tensor {
data: Rc::new(data),
}
}

pub fn shape(&self) -> Vec<usize> {
match self.data {
match *self.data {
Output::IntTensor(ref t) => t.shape().into(),
Output::FloatTensor(ref t) => t.shape().into(),
}
}

#[wasm_bindgen(js_name = floatData)]
pub fn float_data(&self) -> Option<Vec<f32>> {
match self.data {
match *self.data {
Output::FloatTensor(ref t) => Some(t.elements_vec()),
_ => None,
}
}

#[wasm_bindgen(js_name = intData)]
pub fn int_data(&self) -> Option<Vec<i32>> {
match self.data {
match *self.data {
Output::IntTensor(ref t) => Some(t.elements_vec()),
_ => None,
}
}

fn from_output(out: Output) -> Tensor {
Tensor { data: out }
Tensor { data: Rc::new(out) }
}
}

Expand All @@ -125,8 +131,8 @@ impl TensorList {
}

/// Add a new tensor to the end of the list.
pub fn push(&mut self, tensor: Tensor) {
self.tensors.push_back(tensor);
pub fn push(&mut self, tensor: &Tensor) {
self.tensors.push_back(tensor.clone());
}

/// Remove and return the first tensor from this list.
Expand Down

0 comments on commit f992cca

Please sign in to comment.