Support int tensors in ArgMin and ArgMax ops
This is used by OpenAI CLIP models, for example (https://huggingface.co/openai/clip-vit-base-patch32).
robertknight committed Aug 25, 2024
1 parent 5ce30bc commit fee3f09
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions src/ops/reduce.rs
@@ -56,6 +56,40 @@ fn select_max_index<T, Cmp: Fn(&T, &T) -> std::cmp::Ordering>(
Ok(reduced)
}

/// Dispatch a reduction over multiple axes.
macro_rules! dispatch_reduce_op {
($pool:expr, $input:expr, $reduce_op:ident, $axes:expr, $keep_dims:expr) => {
match $input {
Input::FloatTensor(input) => $reduce_op(
$pool,
input,
$axes.as_ref().map(|axis| &axis[..]),
$keep_dims,
)
.into_op_result(),
Input::IntTensor(input) => $reduce_op(
$pool,
input,
$axes.as_ref().map(|axis| &axis[..]),
$keep_dims,
)
.into_op_result(),
}
};
}

/// Dispatch a reduction over a single axis.
macro_rules! dispatch_single_axis_reduce_op {
($pool:expr, $input:expr, $reduce_op:ident, $axis:expr, $keep_dims:expr) => {
match $input {
Input::FloatTensor(input) => {
$reduce_op($pool, input, $axis, $keep_dims).into_op_result()
}
Input::IntTensor(input) => $reduce_op($pool, input, $axis, $keep_dims).into_op_result(),
}
};
}

/// Return the index of the maximum value along a given axis.
///
/// NaN values are propagated by treating NaNs as greater than other values.
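
The doc comment above describes the NaN handling for arg_max. As an illustration only (this snippet is not part of the diff), a comparator with that NaN-as-greater behaviour might look like the following, using plain f32 values rather than rten tensors:

use std::cmp::Ordering;

// Hypothetical helper: order NaN above every other value, so a NaN element
// wins in an arg_max-style reduction.
fn cmp_nan_greater(a: &f32, b: &f32) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => a.partial_cmp(b).unwrap(),
    }
}

fn main() {
    let vals = [1.0_f32, f32::NAN, 5.0];
    let max_idx = vals
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| cmp_nan_greater(a, b))
        .map(|(i, _)| i);
    assert_eq!(max_idx, Some(1)); // the NaN at index 1 is selected
}
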
@@ -80,8 +114,8 @@ impl Operator for ArgMax {
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require_as::<f32>(0)?;
arg_max(pool, input, self.axis, self.keep_dims).into_op_result()
let input = inputs.require(0)?;
dispatch_single_axis_reduce_op!(pool, input, arg_max, self.axis, self.keep_dims)
}
}

@@ -114,8 +148,8 @@ impl Operator for ArgMin {
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require_as::<f32>(0)?;
arg_min(pool, input, self.axis, self.keep_dims).into_op_result()
let input = inputs.require(0)?;
dispatch_single_axis_reduce_op!(pool, input, arg_min, self.axis, self.keep_dims)
}
}
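
With these changes both operators accept either element type: the macro expands to a match on the Input variant that calls the same generic reduction in each arm. The sketch below is a self-contained illustration of that dispatch pattern, with Vec<T> standing in for rten tensors and a simplified arg_max helper (the names here are stand-ins, not rten's actual API):

use std::cmp::Ordering;

// Stand-in for rten's Input enum: one variant per supported element type.
enum Input {
    FloatTensor(Vec<f32>),
    IntTensor(Vec<i32>),
}

// Simplified generic reduction over any ordered element type.
fn arg_max<T: PartialOrd>(data: &[T]) -> Option<usize> {
    data.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Less))
        .map(|(i, _)| i)
}

// Same shape as the macro expansion: match on the element type, then call
// the identical generic function in each arm.
fn run(input: &Input) -> Option<usize> {
    match input {
        Input::FloatTensor(data) => arg_max(data),
        Input::IntTensor(data) => arg_max(data),
    }
}

fn main() {
    assert_eq!(run(&Input::IntTensor(vec![3, 9, 1])), Some(1));
    assert_eq!(run(&Input::FloatTensor(vec![0.5, 0.25])), Some(0));
}
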

@@ -451,27 +485,6 @@ impl Operator for ReduceL2 {
}
}

macro_rules! dispatch_reduce_op {
($pool:expr, $input:expr, $reduce_op:ident, $axes:expr, $keep_dims:expr) => {
match $input {
Input::FloatTensor(input) => $reduce_op(
$pool,
input,
$axes.as_ref().map(|axis| &axis[..]),
$keep_dims,
)
.into_op_result(),
Input::IntTensor(input) => $reduce_op(
$pool,
input,
$axes.as_ref().map(|axis| &axis[..]),
$keep_dims,
)
.into_op_result(),
}
};
}

fn is_nan<T: PartialOrd>(a: &T) -> bool {
a.partial_cmp(a).is_none()
}
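
The is_nan helper above works because NaN is the only float value that has no ordering relative to itself, while integers always compare equal to themselves, which is why the same generic reductions can serve both element types. A small standalone check of that property (not part of this commit):

// is_nan copied from the diff context above, plus a quick demonstration.
fn is_nan<T: PartialOrd>(a: &T) -> bool {
    a.partial_cmp(a).is_none()
}

fn main() {
    assert!(is_nan(&f32::NAN));  // NaN != NaN, so no ordering with itself
    assert!(!is_nan(&1.0_f32));  // ordinary floats are ordered with themselves
    assert!(!is_nan(&3_i32));    // integers can never be NaN
}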