Skip to content

Commit

Permalink
Reorder imports with isort (#326)
Browse files Browse the repository at this point in the history
* Apply isort everywhere

* Config isort in CI

* Update pyproject.toml

Co-authored-by: Max Ryabinin <[email protected]>
  • Loading branch information
yhn112 and mryab authored Jul 29, 2021
1 parent 0774937 commit bedfa6e
Show file tree
Hide file tree
Showing 68 changed files with 179 additions and 165 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/check-style.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# GitHub Actions workflow: enforce code style on every push.
# NOTE(review): indentation below is reconstructed — the scraped diff view had
# flattened all YAML nesting, which would be invalid if applied verbatim.
name: Check style

on: [ push ]

jobs:
  # Formatting check: black in --check mode fails the job on any diff.
  black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: psf/black@stable
        with:
          options: "--check --diff"
          version: "21.6b0"
  # Import-order check: isort-action runs `isort --check` against the repo,
  # picking up configuration from pyproject.toml.
  isort:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - uses: isort/isort-action@master
13 changes: 0 additions & 13 deletions .github/workflows/check_style.yml

This file was deleted.

5 changes: 3 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ with the following rules:

## Code style

* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
run `black .` in the root of the repository.
* The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
cannot be longer than 119 characters.
* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for
import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
repository.
* We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
* Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
output/error streams.
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/benchmark_tensor_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.logging import get_logger


logger = get_logger(__name__)


Expand Down
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
# sys.path.insert(0, os.path.abspath('.'))
import sys

from recommonmark.transform import AutoStructify
from recommonmark.parser import CommonMarkParser

from recommonmark.transform import AutoStructify

# -- Project information -----------------------------------------------------
src_path = "../hivemind"
Expand Down
2 changes: 1 addition & 1 deletion examples/albert/arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, List
from typing import List, Optional

from transformers import TrainingArguments

Expand Down
6 changes: 3 additions & 3 deletions examples/albert/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process
Expand All @@ -21,7 +21,7 @@
from hivemind.utils.compression import CompressionType

import utils
from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

logger = logging.getLogger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
Expand Down
4 changes: 2 additions & 2 deletions examples/albert/run_training_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import torch
import wandb
from torch_optimizer import Lamb
from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

import hivemind
from hivemind.utils.compression import CompressionType

import utils
from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

logger = logging.getLogger(__name__)

Expand Down
1 change: 0 additions & 1 deletion examples/albert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger


logger = get_logger(__name__)


Expand Down
8 changes: 4 additions & 4 deletions hivemind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
from hivemind.dht import DHT
from hivemind.moe import (
ExpertBackend,
Server,
register_expert_class,
RemoteExpert,
RemoteMixtureOfExperts,
RemoteSwitchMixtureOfExperts,
Server,
register_expert_class,
)
from hivemind.optim import (
CollaborativeAdaptiveOptimizer,
DecentralizedOptimizerBase,
CollaborativeOptimizer,
DecentralizedAdam,
DecentralizedOptimizer,
DecentralizedOptimizerBase,
DecentralizedSGD,
DecentralizedAdam,
)
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
from hivemind.utils import *
Expand Down
8 changes: 4 additions & 4 deletions hivemind/averaging/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import torch

from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
from hivemind.utils import get_logger
from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.proto import averaging_pb2
from hivemind.utils import get_logger
from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

# flavour types
GroupID = bytes
Expand Down
14 changes: 7 additions & 7 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
import weakref
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import asdict
from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2, runtime_pb2
from hivemind.utils import MPFuture, get_logger, TensorDescriptor
from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.grpc import split_for_streaming, combine_from_streaming
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

# flavour types
GatheredData = Any
Expand Down
6 changes: 3 additions & 3 deletions hivemind/averaging/key_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio
import re
import random
from typing import Optional, List, Tuple
import re
from typing import List, Optional, Tuple

import numpy as np

from hivemind.averaging.group_info import GroupInfo
from hivemind.dht import DHT
from hivemind.p2p import PeerID
from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger

GroupKey = str
GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$") # e.g. bert_exp4_averaging.0b01001101
Expand Down
3 changes: 2 additions & 1 deletion hivemind/averaging/load_balancing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, Optional, Tuple
from typing import Optional, Sequence, Tuple

import numpy as np
import scipy.optimize

Expand Down
6 changes: 3 additions & 3 deletions hivemind/averaging/matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type

from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
from hivemind.dht import DHT, DHTID, DHTExpiration
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
from hivemind.utils.asyncio import anext
from hivemind.proto import averaging_pb2
from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
from hivemind.utils.asyncio import anext

logger = get_logger(__name__)

Expand Down
7 changes: 3 additions & 4 deletions hivemind/averaging/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
Auxiliary data structures for AllReduceRunner
"""
import asyncio
from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
from collections import deque
from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union

import torch
import numpy as np
import torch

from hivemind.proto.runtime_pb2 import CompressionType, Tensor
from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
from hivemind.utils.asyncio import amap_in_executor

from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor

T = TypeVar("T")
DEFAULT_PART_SIZE_BYTES = 2 ** 19
Expand Down
6 changes: 3 additions & 3 deletions hivemind/averaging/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from itertools import chain
from threading import Lock, Event
from typing import Sequence, Dict, Iterator, Optional
from threading import Event, Lock
from typing import Dict, Iterator, Optional, Sequence

import torch

from hivemind.averaging import DecentralizedAverager
from hivemind.utils import nested_flatten, nested_pack, get_logger
from hivemind.utils import get_logger, nested_flatten, nested_pack

logger = get_logger(__name__)

Expand Down
1 change: 0 additions & 1 deletion hivemind/dht/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from hivemind.utils import MSGPackSerializer, get_logger
from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey


logger = get_logger(__name__)


Expand Down
6 changes: 3 additions & 3 deletions hivemind/dht/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import dataclasses
import random
from collections import defaultdict, Counter
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from functools import partial
from typing import (
Expand All @@ -27,11 +27,11 @@

from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
from hivemind.dht.routing import DHTID, BinaryDHTValue, DHTKey, DHTValue, Subkey, get_dht_time
from hivemind.dht.storage import DictionaryDHTValue
from hivemind.dht.traverse import traverse_dht
from hivemind.p2p import P2P, PeerID
from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
from hivemind.utils import MSGPackSerializer, SerializerBase, get_logger
from hivemind.utils.auth import AuthorizerBase
from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration

Expand Down
12 changes: 6 additions & 6 deletions hivemind/dht/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
from __future__ import annotations

import asyncio
from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union

from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
from hivemind.proto import dht_pb2
from hivemind.utils import get_logger, MSGPackSerializer
from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
from hivemind.utils import MSGPackSerializer, get_logger
from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
from hivemind.utils.timed_storage import (
DHTExpiration,
get_dht_time,
MAX_DHT_TIME_DISCREPANCY_SECONDS,
DHTExpiration,
ValueWithExpiration,
get_dht_time,
)

logger = get_logger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion hivemind/dht/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import random
from collections.abc import Iterable
from itertools import chain
from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

from hivemind.p2p import PeerID
from hivemind.utils import MSGPackSerializer, get_dht_time

Expand Down
2 changes: 1 addition & 1 deletion hivemind/dht/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
from hivemind.utils.serializer import MSGPackSerializer
from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType


@MSGPackSerializer.ext_serializable(0x50)
Expand Down
2 changes: 1 addition & 1 deletion hivemind/dht/traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import heapq
from collections import Counter
from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple

from hivemind.dht.routing import DHTID

Expand Down
4 changes: 2 additions & 2 deletions hivemind/hivemind_cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import configargparse
import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.moe.server import Server
from hivemind.moe.server.layers import schedule_name_to_scheduler
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger
from hivemind.moe.server.layers import schedule_name_to_scheduler

logger = get_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion hivemind/moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts
from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
Loading

0 comments on commit bedfa6e

Please sign in to comment.