Added fast_init option
ankane committed Aug 30, 2024
1 parent ebb50b5 commit a29ed5c
Showing 4 changed files with 50 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
 ## 0.1.1 (unreleased)

+- Added `fast_init` option
 - Improved performance of loading models
 - Fixed error with `aggregation_strategy` option

5 changes: 5 additions & 0 deletions lib/transformers.rb
@@ -97,4 +97,9 @@ def message
"not implemented yet"
end
end

class << self
attr_accessor :fast_init
end
self.fast_init = false
end
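For reference, a minimal usage sketch of the new flag (the pipeline task here is illustrative; any code path that loads a pretrained model is affected). The skipped random initialization is never observed, since the checkpoint weights overwrite it during loading:

    # opt in once at boot; the flag is global and, per the note in
    # modeling_utils.rb below, not thread-safe
    Transformers.fast_init = true
    classifier = Transformers.pipeline("sentiment-analysis")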
43 changes: 42 additions & 1 deletion lib/transformers/modeling_utils.rb
@@ -14,6 +14,47 @@
 # limitations under the License.

 module Transformers
+  module ModelingUtils
+    TORCH_INIT_FUNCTIONS = {
+      "uniform!" => Torch::NN::Init.method(:uniform!),
+      "normal!" => Torch::NN::Init.method(:normal!),
+      # "trunc_normal!" => Torch::NN::Init.method(:trunc_normal!),
+      "constant!" => Torch::NN::Init.method(:constant!),
+      "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
+      "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
+      "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
+      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+      # "uniform" => Torch::NN::Init.method(:uniform),
+      # "normal" => Torch::NN::Init.method(:normal),
+      # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
+      # "xavier_normal" => Torch::NN::Init.method(:xavier_normal),
+      # "kaiming_uniform" => Torch::NN::Init.method(:kaiming_uniform),
+      # "kaiming_normal" => Torch::NN::Init.method(:kaiming_normal)
+    }
+
+    # private
+    # note: this improves loading time significantly, but is not thread-safe!
+    def self.no_init_weights
+      return yield unless Transformers.fast_init
+
+      _skip_init = lambda do |*args, **kwargs|
+        # pass
+      end
+      # Replace the initialization functions with no-ops (originals are saved in TORCH_INIT_FUNCTIONS)
+      TORCH_INIT_FUNCTIONS.each do |name, init_func|
+        Torch::NN::Init.singleton_class.undef_method(name)
+        Torch::NN::Init.define_singleton_method(name, &_skip_init)
+      end
+      yield
+    ensure
+      # Restore the original initialization functions
+      TORCH_INIT_FUNCTIONS.each do |name, init_func|
+        Torch::NN::Init.singleton_class.undef_method(name)
+        Torch::NN::Init.define_singleton_method(name, init_func)
+      end
+    end
+  end
+
   module ModuleUtilsMixin
     def get_extended_attention_mask(
       attention_mask,
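The hot swap above works by undefining each singleton method and redefining it as a no-op, then restoring the saved Method objects in an ensure block so the initializers come back even if loading raises. A standalone sketch of the same pattern, using hypothetical Greeter/shout names (no Torch required):

    module Greeter
      def self.shout(msg)
        "#{msg.upcase}!"
      end
    end

    # save the original Method objects up front, as TORCH_INIT_FUNCTIONS does
    SAVED_METHODS = { "shout" => Greeter.method(:shout) }

    def quietly
      noop = lambda { |*args, **kwargs| } # accepts any arguments, does nothing
      SAVED_METHODS.each_key do |name|
        Greeter.singleton_class.undef_method(name)
        Greeter.define_singleton_method(name, &noop)
      end
      yield
    ensure
      # restore the saved methods even if the block raises
      SAVED_METHODS.each do |name, meth|
        Greeter.singleton_class.undef_method(name)
        Greeter.define_singleton_method(name, meth)
      end
    end

    quietly { Greeter.shout("hi") } # => nil while swapped
    Greeter.shout("hi")             # => "HI!" after restore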
@@ -519,7 +560,7 @@ def from_pretrained(
         config.name_or_path = pretrained_model_name_or_path

         # Instantiate model.
-        model = new(config, *model_args, **model_kwargs)
+        model = ModelingUtils.no_init_weights { new(config, *model_args, **model_kwargs) }

         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
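A rough way to observe the difference is a hypothetical benchmark like the one below (the task name is assumed; run it once beforehand so the model download is cached, and expect timings to vary by machine):

    require "benchmark"

    Transformers.fast_init = false
    default_time = Benchmark.realtime { Transformers.pipeline("sentiment-analysis") }

    Transformers.fast_init = true
    fast_time = Benchmark.realtime { Transformers.pipeline("sentiment-analysis") }

    puts format("default: %.2fs, fast_init: %.2fs", default_time, fast_time)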
2 changes: 2 additions & 0 deletions test/test_helper.rb
@@ -6,6 +6,8 @@
   Transformers.logger.level = Logger::ERROR
 end

+Transformers.fast_init = true
+
 class Minitest::Test
   def assert_elements_in_delta(expected, actual)
     assert_equal expected.size, actual.size
