diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 9e02fc30..8f2ebedf 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -11,6 +11,7 @@ import torch import torch.backends import transformers +import wandb from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import ( @@ -29,8 +30,6 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from tqdm import tqdm -import wandb - from .args import _INVERSE_DTYPE_MAP, Args, validate_args from .constants import ( FINETRAINERS_LOG_LEVEL,