@@ -23,7 +23,7 @@
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Adam, SGD, lr_scheduler
+from torch.optim import SGD, Adam, lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -37,19 +37,20 @@
 from models.yolo import Model
 from utils.autoanchor import check_anchors
 from utils.autobatch import check_train_batch_size
+from utils.callbacks import Callbacks
 from utils.datasets import create_dataloader
-from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
-    strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
-    check_file, check_yaml, check_suffix, print_args, print_mutation, one_cycle, colorstr, methods, LOGGER
 from utils.downloads import attempt_download
-from utils.loss import ComputeLoss
-from utils.plots import plot_labels, plot_evolve
-from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
-    torch_distributed_zero_first
+from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
+                           check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
+                           labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
+                           print_mutation, strip_optimizer)
+from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
+from utils.loss import ComputeLoss
 from utils.metrics import fitness
-from utils.loggers import Loggers
-from utils.callbacks import Callbacks
+from utils.plots import plot_evolve, plot_labels
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
+                               torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))