Commit f868eb8

[Bugfix] Fix transformers import order in megatron scripts (#5)
szhengac authored Jan 19, 2023
1 parent 353d6a7 commit f868eb8
Showing 6 changed files with 14 additions and 12 deletions.
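The same fix is applied in every script: the `from transformers import ...` statement that previously sat inside `get_model` is hoisted to module scope, ahead of the `megatron` imports. A minimal sketch of the resulting layout, using the ALBERT script as the example (the other scripts differ only in the imported model class; `model_name` and `padded_vocab_size` are shown as parameters here because the body uses them, but the full `get_model` signature is not visible in the diff and is abbreviated):

import torch
import torch.nn.functional as F

# transformers is now imported at module level, before any megatron import
from transformers import AutoConfig, AlbertModel

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers


def get_model(
    model_name,
    padded_vocab_size=None,
    impl="slapo",
    delay_init=True,
):
    # the function-local transformers import that used to live here is removed
    config = AutoConfig.from_pretrained(model_name)
    if padded_vocab_size is not None:
        config.vocab_size = padded_vocab_size
    ...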
examples/albert/megatron_hf.py (2 additions, 2 deletions)
@@ -12,6 +12,8 @@
 import torch
 import torch.nn.functional as F
 
+from transformers import AutoConfig, AlbertModel
+
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
@@ -35,8 +37,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, AlbertModel
-
     config = AutoConfig.from_pretrained(model_name)
     if padded_vocab_size is not None:
         config.vocab_size = padded_vocab_size
examples/bert/megatron_hf.py (2 additions, 2 deletions)
@@ -12,6 +12,8 @@
 import torch
 import torch.nn.functional as F
 
+from transformers import AutoConfig, BertModel
+
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
@@ -35,8 +37,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, BertModel
-
     config = AutoConfig.from_pretrained(model_name)
     if padded_vocab_size is not None:
         config.vocab_size = padded_vocab_size
examples/gpt/megatron_hf.py (3 additions, 2 deletions)
@@ -7,6 +7,9 @@
 import os
 
 import torch
+
+from transformers import AutoConfig, GPTNeoModel
+
 from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
@@ -30,8 +33,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, GPTNeoModel
-
     config = AutoConfig.from_pretrained(model_name)
     if padded_vocab_size is not None:
         config.vocab_size = padded_vocab_size
examples/opt/megatron_hf.py (3 additions, 2 deletions)
@@ -7,6 +7,9 @@
 import os
 
 import torch
+
+from transformers import AutoConfig, OPTModel
+
 from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
@@ -30,8 +33,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, OPTModel
-
     config = AutoConfig.from_pretrained(model_name)
     if padded_vocab_size is not None:
         config.vocab_size = padded_vocab_size
examples/roberta/megatron_hf.py (2 additions, 2 deletions)
@@ -9,6 +9,8 @@
 import torch
 import torch.nn.functional as F
 
+from transformers import AutoConfig, RobertaModel
+
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
@@ -32,8 +34,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, RobertaModel
-
     config = AutoConfig.from_pretrained(model_name)
     if padded_vocab_size is not None:
         config.vocab_size = padded_vocab_size
examples/t5/megatron_hf.py (2 additions, 2 deletions)
@@ -10,6 +10,8 @@
 
 import torch
 
+from transformers import AutoConfig, T5Model
+
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
@@ -58,8 +60,6 @@ def get_model(
     impl="slapo",
     delay_init=True,
 ):
-    from transformers import AutoConfig, T5Model
-
     config = AutoConfig.from_pretrained(model_name)
     config.vocab_size = padded_vocab_size
     config.use_cache = False
