[Unified Checkpoint] Fix tie_weights save and load #8137

Merged · merged 7 commits · Mar 19, 2024 · Changes from 5 commits
65 changes: 64 additions & 1 deletion paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -226,6 +226,10 @@
     model_state_dict = get_expected_state_dict(model)
     expected_keys = set(list(model_state_dict.keys()))
+    if hasattr(model, "_tied_weights_keys") and model._tied_weights_keys is not None:
+        for key in model._tied_weights_keys:
+            expected_keys.remove(key)

Collaborator (suggested change):

    if hasattr(model, "_tied_weights_keys") and model._tied_weights_keys is not None:
        for key in model._tied_weights_keys:
            expected_keys.remove(key)
    if model._keys_to_ignore_on_save is not None:
        for key in model._keys_to_ignore_on_save:
            expected_keys.remove(key)

Collaborator:

    _keys_to_ignore_on_save = None
    _tied_weights_keys = None

Our models should have this attribute.

Collaborator: Using _keys_to_ignore_on_save might be a bit better here.


     missing_keys = expected_keys - set(loaded_keys)

     if len(missing_keys) > 0:
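For readers outside the thread: weight tying makes two state-dict keys point at one tensor, so checkpoints store that tensor only once. The guard above drops the tied alias from expected_keys so the load path does not report it as missing. A minimal, framework-free sketch of that logic (ToyModel and its key names are illustrative, not from PaddleNLP):

    class ToyModel:
        _tied_weights_keys = ["lm_head.weight"]  # alias of the embedding weight

        def state_dict(self):
            shared = [0.0] * 4  # stands in for the shared tensor
            return {"embeddings.weight": shared, "lm_head.weight": shared}

    model = ToyModel()
    expected_keys = set(model.state_dict().keys())
    if getattr(model, "_tied_weights_keys", None) is not None:
        for key in model._tied_weights_keys:
            expected_keys.remove(key)

    loaded_keys = {"embeddings.weight"}  # the checkpoint stores the weight once
    assert not (expected_keys - loaded_keys)  # no spurious missing-key warning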
@@ -607,6 +611,12 @@
     static2struct_name_mappings = {}
     state_dict = get_expected_state_dict(model)
     for k, v in state_dict.items():
+        if (
+            hasattr(model, "_tied_weights_keys")
+            and model._tied_weights_keys is not None
+            and k in model._tied_weights_keys
+        ):
+            continue
         static2struct_name_mappings[v.name] = k

     # rename optimizer param
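The same six-line guard recurs at every save and load site in this file; factoring it into a helper would keep the sites from drifting apart. A possible shape (the helper name is hypothetical, not part of this PR):

    def _is_tied_weight_key(model, key):
        # True if `key` names a tied weight that save/load should skip.
        tied = getattr(model, "_tied_weights_keys", None)
        return tied is not None and key in tied

Each loop below would then reduce to `if _is_tied_weight_key(model, k): continue`.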
@@ -739,6 +749,12 @@
     need_files = set()
     state_dict = get_expected_state_dict(model)
     for key in state_dict.keys():
+        if (
+            hasattr(model, "_tied_weights_keys")
+            and model._tied_weights_keys is not None
+            and key in model._tied_weights_keys
+        ):
+            continue
         filename = index["weight_map"][key]
         need_files.add(filename)
     diff_filelist = list(need_files.difference(set(existed_files)))
@@ -829,6 +845,12 @@
     need_files = set()
     state_dict = get_expected_state_dict(model)
     for key in state_dict.keys():
+        if (
+            hasattr(model, "_tied_weights_keys")
+            and model._tied_weights_keys is not None
+            and key in model._tied_weights_keys
+        ):
+            continue
         if sharding_group.nranks > 1:
             static_name = struct2static_name_mappings.get(key, None)
             param_rank = param2rank.get(static_name, None)
@@ -893,6 +915,12 @@
     index_weight_file = {}
     total_size = 0
     for key, weight in state_dict.items():
+        if (
+            hasattr(model_to_save, "_tied_weights_keys")
+            and model_to_save._tied_weights_keys is not None
+            and key in model_to_save._tied_weights_keys
+        ):
+            continue
         index_weight_file[key] = weight_filename
         total_size += weight.numel().item() * dtype_byte_size(weight.dtype)
     sharded_index_json = {}
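For reference, the size accounting above is element count times per-element byte width, and skipping tied keys keeps total_size consistent with what is actually written to disk. A quick sanity check of the arithmetic (dtype_byte_size here is a pure-Python stand-in for the PaddleNLP helper of the same name):

    def dtype_byte_size(dtype):
        # Stand-in: maps a dtype name to its per-element byte width.
        return {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]

    numel = 4096 * 4096                        # one 4096x4096 weight matrix
    print(numel * dtype_byte_size("float16"))  # 33554432 bytes, i.e. 32 MiB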
@@ -926,6 +954,12 @@
     static2struct_name_mappings = {}
     state_dict = get_expected_state_dict(model)
     for k, v in state_dict.items():
+        if (
+            hasattr(model, "_tied_weights_keys")
+            and model._tied_weights_keys is not None
+            and k in model._tied_weights_keys
+        ):
+            continue
         static2struct_name_mappings[v.name] = k

     # rename optimizer param
@@ -1023,6 +1057,12 @@
     if args.dataset_rank == 0:
         state_dict = get_expected_state_dict(model)
         for (k, v) in state_dict.items():
+            if (
+                hasattr(model, "_tied_weights_keys")
+                and model._tied_weights_keys is not None
+                and k in model._tied_weights_keys
+            ):
+                continue
             if hasattr(v, "is_distributed") and v.is_distributed:
                 recv_table[k] = [(dist.get_rank(), tp_rank)]
             else:
@@ -1069,6 +1109,12 @@
     if args.data_parallel_rank == 0:
         state_dict = get_expected_state_dict(model)
         for (k, v) in state_dict.items():
+            if (
+                hasattr(model, "_tied_weights_keys")
+                and model._tied_weights_keys is not None
+                and k in model._tied_weights_keys
+            ):
+                continue
             if sharding_group.nranks > 1:
                 static_name = struct2static_name_mappings[k]
                 param_rank = param2rank.get(static_name, None)
@@ -1196,7 +1242,15 @@
         _, typename = key.split("/")
         typename_set.add(typename)
     struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
-    static2struct_name_mappings = {v.name: k for k, v in get_expected_state_dict(model).items()}
+    static2struct_name_mappings = {}
+    for k, v in get_expected_state_dict(model).items():
+        if (
+            hasattr(model, "_tied_weights_keys")
+            and model._tied_weights_keys is not None
+            and k in model._tied_weights_keys
+        ):
+            continue
+        static2struct_name_mappings[v.name] = k

     # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors.
     send_table, recv_table = create_optimizer_dispatch_table(
         args,
@@ -1349,6 +1403,9 @@
     loaded_keys = sharded_metadata["all_checkpoint_keys"]
     model_state_dict = get_expected_state_dict(model)
     expected_keys = set(list(model_state_dict.keys()))
+    if hasattr(model, "_tied_weights_keys") and model._tied_weights_keys is not None:
+        for key in model._tied_weights_keys:
+            expected_keys.remove(key)

     missing_keys = expected_keys - set(loaded_keys)

     if len(missing_keys) > 0:
@@ -1656,6 +1713,12 @@
     tensor_bytes_dict = {}
     model_state_dict = get_expected_state_dict(model_to_save)
     for (k, v) in state_dict.items():
+        if (
+            hasattr(model_to_save, "_tied_weights_keys")
+            and model_to_save._tied_weights_keys is not None
+            and k in model_to_save._tied_weights_keys
+        ):
+            continue
         model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v
         if hasattr(model_v, "is_distributed") and model_v.is_distributed:
             tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype)
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2457,6 +2457,7 @@
             # Load in optimizer and scheduler states
             self.optimizer.set_state_dict(opt_state_dict)
         else:
+            optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

         if not self.args.ignore_load_lr_and_optim:
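A hedged reading of this one-line change, based only on the diff: on the else branch, optimizer_name was interpolated into the error message before being bound, so the intended ValueError would surface as a NameError. A minimal reproduction of that failure mode (the path and filename are illustrative):

    import os

    checkpoint = "./checkpoint-100"  # illustrative path
    opt_state_dict = None
    try:
        if opt_state_dict is not None:
            optimizer_name = "optimizer.pdopt"
        else:
            # optimizer_name is unbound on this branch without the fix
            raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
    except NameError as err:
        print("NameError instead of ValueError:", err)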
2 changes: 1 addition & 1 deletion paddlenlp/transformers/chatglm/modeling.py
@@ -789,7 +789,7 @@ def forward(self, hidden_states):

 class ChatGLMForCausalLM(ChatGLMPretrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder_weight"]
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.decoder_weight"]

     def __init__(self, config: ChatGLMConfig):
         super(ChatGLMForCausalLM, self).__init__(config)
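This fix makes _tied_weights_keys name the parameter that actually exists in ChatGLM's state dict (lm_head.decoder_weight). With the old value, expected_keys.remove("lm_head.weight") in the loader above would raise KeyError, since set.remove is strict about membership. A toy check of that contract (the embedding key is illustrative, not taken from the model):

    expected_keys = {"transformer.word_embeddings.weight", "lm_head.decoder_weight"}

    try:
        expected_keys.remove("lm_head.weight")      # old, mismatched key
    except KeyError:
        print("KeyError: 'lm_head.weight' is not a state-dict key")

    expected_keys.remove("lm_head.decoder_weight")  # fixed key succeeds
    print(expected_keys)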