# Some changes to Sampling Op (#14218)
### Description

1. Add an optional input for passing in a seed.
2. Add two unit tests: one for top_p=0.5 and one for top_p=0.01 (which reproduces the greedy search result; see convert_generation.py).
3. Fix a bug in the CPU kernel.
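
Below is a minimal sketch of how the new `seed` input might be exercised once a model has been exported with it, assuming an ONNX Runtime Python install; the model path, prompt, and exact set of graph inputs are hypothetical and depend on the export flags:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical model exported by convert_generation.py with --seed enabled.
sess = ort.InferenceSession("gpt2_sampling.onnx", providers=["CPUExecutionProvider"])

feeds = {
    "input_ids": np.array([[464, 3290, 318]], dtype=np.int32),
    "max_length": np.array([20], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "seed": np.array([42], dtype=np.int32),  # the new optional input, shape (1)
}

# With a fixed seed, repeated runs should produce identical sampled sequences.
first = sess.run(None, feeds)[0]
second = sess.run(None, feeds)[0]
assert np.array_equal(first, second)
```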


Co-authored-by: Ubuntu <[email protected]>
wangyems and Ubuntu authored Jan 12, 2023
1 parent 3898b22 commit c9a53c9
Showing 14 changed files with 236 additions and 23 deletions.
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -3945,7 +3945,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

- #### Inputs (2 - 8)
+ #### Inputs (2 - 9)

<dl>
<dt><tt>input_ids</tt> : I</dt>
@@ -3964,6 +3964,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>presence_mask</tt> (optional) : I</dt>
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
+ <dt><tt>seed</tt> (optional) : I</dt>
+ <dd>Seed for random number generator. Shape is (1)</dd>
</dl>

#### Outputs (1 - 2)
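
For illustration, here is how the positional input list of the updated `com.microsoft.Sampling` node might be assembled with `onnx.helper`; skipped optional inputs are passed as empty strings so that `seed` lands at index 8 (the tensor names are assumptions):

```python
from onnx import helper

# Sketch of a Sampling node using the new 9-input layout.
node = helper.make_node(
    "Sampling",
    inputs=[
        "input_ids",           # 0, required
        "max_length",          # 1, required
        "min_length",          # 2, optional
        "repetition_penalty",  # 3, optional
        "",                    # 4, vocab_mask (skipped)
        "",                    # 5, prefix_vocab_mask (skipped)
        "attention_mask",      # 6, optional
        "",                    # 7, presence_mask (skipped)
        "seed",                # 8, the new optional seed input
    ],
    outputs=["sequences"],
    domain="com.microsoft",
)
```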
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -451,7 +451,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
- |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
+ |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -814,7 +814,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
- |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
+ |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
9 changes: 9 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
@@ -105,6 +105,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
DumpCpuTensor<MLFloat16>(name, tensor, dim0, dim1);
}

+ void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+ if (!is_enabled_)
+ return;
+ DumpCpuTensor<size_t>(name, tensor, dim0, dim1);
+ }

void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
if (!is_enabled_)
return;
@@ -180,6 +186,9 @@ void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const {
void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
}

+ void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+ }

void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
}

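
The `size_t` overloads are added so that index buffers such as `sorted_indices` can be dumped under `DEBUG_GENERATION`. A rough Python analogue of the enabled-guarded dumper pattern, with illustrative names only:

```python
import os
import numpy as np

DUMP_ENABLED = bool(os.environ.get("DEBUG_GENERATION"))  # stand-in for is_enabled_

def dump_tensor(name, tensor, dim0, dim1):
    # No-op when dumping is disabled, mirroring the C++ early return.
    if not DUMP_ENABLED:
        return
    print(name)
    print(np.asarray(tensor).reshape(dim0, dim1))
```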
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
@@ -16,6 +16,7 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
virtual ~CpuTensorConsoleDumper() {}
void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+ void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
@@ -172,6 +172,7 @@ class IConsoleDumper {
bool IsEnabled() const { return is_enabled_; }
virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
+ virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0;
38 changes: 21 additions & 17 deletions onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
@@ -10,9 +10,10 @@ template <typename T>
void filter_scores(std::vector<size_t>& sorted_indice,
gsl::span<T>& next_token_score,
const transformers::IGenerationParameters* parameters,
- size_t index) {
- size_t real_index = sorted_indice[index];
- next_token_score[real_index] = (T)parameters->filter_value;
+ size_t chunk_offset,
+ size_t offset) {
+ size_t real_index = sorted_indice[chunk_offset + offset];
+ next_token_score[chunk_offset + real_index] = (T)parameters->filter_value;
}

template <typename T>
@@ -23,12 +24,12 @@ void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
size_t offset = i * parameters->vocab_size;
if (cumulative_probs[offset] > parameters->top_p) {
- filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
+ filter_scores(sorted_indices, next_token_scores, parameters, offset, 1);
}
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
if (cumulative_probs[j + offset] > parameters->top_p) {
- filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
+ filter_scores(sorted_indices, next_token_scores, parameters, offset, j + 1);
}
}
}
@@ -42,12 +43,12 @@ void cumulate_and_filter(gsl::span<T>& next_token_scores,
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
size_t offset = i * parameters->vocab_size;
if (cumulative_probs[offset] <= 1 - parameters->top_p) {
- filter_scores(sorted_indices, next_token_scores, parameters, offset);
+ filter_scores(sorted_indices, next_token_scores, parameters, offset, 0);
}
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
- filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
+ filter_scores(sorted_indices, next_token_scores, parameters, offset, j);
}
}
}
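
The CPU bug fixed above is an indexing one: `sorted_indices` hold row-local positions, so a write back into the flat score buffer must be re-based by the row's `chunk_offset`; previously, batches after the first filtered positions in the wrong row. A numpy sketch of the corrected per-row top-p filtering (simplified: the kernel keeps separate custom and non-custom paths plus a `min_tokens_to_keep` guard):

```python
import numpy as np

FILTER_VALUE = -float("inf")

def top_p_filter(next_token_scores, top_p):
    """Row-wise top-p filtering over a (batch_size, vocab_size) score matrix."""
    batch_size, vocab_size = next_token_scores.shape
    for i in range(batch_size):
        row = next_token_scores[i]  # one "chunk" of the flat C++ buffer
        order = np.argsort(-row, kind="stable")  # row-local sorted indices
        probs = np.exp(row[order] - row[order].max())
        probs /= probs.sum()
        cumulative = np.cumsum(probs)
        # Filter tokens once the cumulative mass exceeds top_p,
        # always keeping the highest-probability token.
        to_filter = cumulative > top_p
        to_filter[0] = False
        # `order` is row-local; writing through `row` confines the update to
        # row i, which is what the explicit chunk_offset achieves in C++.
        row[order[to_filter]] = FILTER_VALUE
    return next_token_scores
```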
@@ -78,17 +79,23 @@ Status Sample(AllocatorPtr& allocator,
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
+ gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters->vocab_size, parameters->vocab_size);
std::iota(indices_begin, indices_end, 0);
std::sort(indices_begin, indices_end,
- [&next_token_scores, &predicator](size_t i1, size_t i2) {
- return !predicator(next_token_scores[i1], next_token_scores[i2]);
+ [&next_token_score, &predicator](size_t i1, size_t i2) {
+ return predicator(next_token_score[i1], next_token_score[i2]);
});

std::sort(sorted_scores.begin() + i * parameters->vocab_size,
sorted_scores.begin() + (i + 1) * parameters->vocab_size,
predicator);
}

+ #ifdef DEBUG_GENERATION
+ dumper->Print("sorted_scores", sorted_scores.data(), parameters->batch_size, parameters->vocab_size);
+ dumper->Print("sorted_indices", sorted_indices.data(), parameters->batch_size, parameters->vocab_size);
+ #endif

gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;

ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
@@ -104,13 +111,10 @@ Status Sample(AllocatorPtr& allocator,
cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
}

- gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
- ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
- parameters->vocab_size,
- next_token_scores.data(),
- next_token_probs.data(),
- false,
- thread_pool));
+ #ifdef DEBUG_GENERATION
+ dumper->Print("cumulative_probs after filtering", cumulative_probs.data(), parameters->batch_size, parameters->vocab_size);
+ dumper->Print("next_token_scores after filtering", next_token_scores.data(), parameters->batch_size, parameters->vocab_size);
+ #endif

// torch.multinomial()
int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
@@ -119,7 +123,7 @@
OrtValue next_token_probs_value;
Tensor::InitOrtValue(element_type,
next_token_probs_shape,
- next_token_probs.data(),
+ next_token_scores.data(),
allocator->Info(),
next_token_probs_value);
const Tensor& input = next_token_probs_value.Get<Tensor>();
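
After filtering, the kernel wraps the surviving scores in an `OrtValue` and hands them to the multinomial step; note the change above also drops the second softmax and feeds `next_token_scores` directly. A seeded numpy sketch of that `torch.multinomial()`-style draw, assuming normalization happens at this point:

```python
import numpy as np

def sample_next_tokens(filtered_logits, seed=0):
    """Draw one token id per row from (batch_size, vocab_size) filtered logits."""
    rng = np.random.default_rng(seed)  # seeded like the op's new `seed` input
    batch_size, vocab_size = filtered_logits.shape
    # Softmax over each row; filtered entries (-inf) get probability zero.
    shifted = filtered_logits - filtered_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(vocab_size, p=probs[i]) for i in range(batch_size)])
```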
@@ -21,6 +21,14 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
}

+ void SamplingParameters::ParseFromInputs(OpKernelContext* context) {
+ this->GreedySearchParameters::ParseFromInputs(context);
+
+ auto* seed_tensor = context->Input<Tensor>(8);
+ seed = seed_tensor ? static_cast<int>(*seed_tensor->Data<int32_t>()) : 0;
+ ORT_ENFORCE(seed >= 0, "Seed must be >= 0");
+ }

} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
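
In Python terms, the parsing above amounts to reading an optional trailing input and defaulting to 0 when it is absent (0 presumably selects the kernel's default, non-fixed seeding); a sketch with an assumed dict of optional inputs:

```python
def parse_seed(optional_inputs):
    # `seed` is input index 8 of the Sampling op and may be absent.
    seed_tensor = optional_inputs.get("seed")
    seed = int(seed_tensor[0]) if seed_tensor is not None else 0
    if seed < 0:
        raise ValueError("Seed must be >= 0")
    return seed
```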
@@ -12,6 +12,8 @@ namespace transformers {

struct SamplingParameters : public GreedySearchParameters {
void ParseFromAttributes(const OpKernelInfo& info);

+ void ParseFromInputs(OpKernelContext* context);
};

} // namespace transformers
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -145,6 +145,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, true);
}

+ void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+ if (is_enabled_)
+ DumpGpuTensor<size_t>(name, tensor, dim0, dim1, true);
+ }

void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, true);
@@ -212,6 +217,9 @@ void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const {
void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
}

+ void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+ }

void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
}

@@ -19,6 +19,7 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::IConsoleDumper {
virtual ~CudaTensorConsoleDumper() {}
void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+ void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1155,6 +1155,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
.Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(8, "seed", "Seed for random number generator. Shape is (1)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
.Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
27 changes: 24 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -273,6 +273,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
)
model_group.set_defaults(presence_mask=False)

+ model_group.add_argument(
+ "--seed",
+ required=False,
+ action="store_true",
+ help="Random seed for sampling op",
+ )
+ model_group.set_defaults(seed=False)

beam_parameters_group = parser.add_argument_group(
"Beam search parameters not stored in the output model, for testing parity and performance"
)
@@ -1531,6 +1539,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType

if is_sampling and args.custom and args.presence_mask:
inputs.append("presence_mask")
+ else:
+ inputs.append("")
+
+ if is_sampling and args.seed:
+ inputs.append("seed")

outputs = ["sequences"]
if args.output_sequences_scores:
@@ -1709,6 +1722,10 @@ def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType
)
graph_inputs.append(presence_mask)

+ if is_sampling and args.seed:
+ seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
+ graph_inputs.append(seed)

# graph outputs
sequences = None
if is_beamsearch:
@@ -2278,9 +2295,13 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
if args.model_type == "gpt2" and is_greedy:
if args.top_p > 0.0 and args.top_p < 1.0:
convert_generation_model(args, GenerationType.SAMPLING)
logger.info("The test for gpt2_sampling onnx model is not implemented yet")
return
convert_generation_model(args, GenerationType.GREEDYSEARCH)
logger.info(
"The test for gpt2_sampling onnx model is limited to non-custom model with small top_p(e.g <=0.01) value. The result should be the same as gpt2 greedy search."
)
if args.top_p > 0.01 or args.custom or args.seed:
return
else:
convert_generation_model(args, GenerationType.GREEDYSEARCH)
else:
convert_generation_model(args)

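
The second unit test's idea, sketched: with `top_p` as small as 0.01, nucleus filtering keeps only the most probable token, so the sampling model should reproduce the greedy-search result. Model paths and feed names below are hypothetical:

```python
import numpy as np
import onnxruntime as ort

feeds = {
    "input_ids": np.array([[464, 3290, 318]], dtype=np.int32),
    "max_length": np.array([20], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}

sampling = ort.InferenceSession("gpt2_sampling_top_p_0.01.onnx", providers=["CPUExecutionProvider"])
greedy = ort.InferenceSession("gpt2_greedy.onnx", providers=["CPUExecutionProvider"])

# top_p <= 0.01 leaves only the argmax token after filtering,
# so sampling and greedy search should emit the same sequence.
assert np.array_equal(sampling.run(None, feeds)[0], greedy.run(None, feeds)[0])
```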