Move mixin to core #8396

Merged: 3 commits, Jul 19, 2021
pytorch_lightning/core/datamodule.py (2 changes: 1 addition & 1 deletion)

@@ -20,9 +20,9 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset

 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
+from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin


 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
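Only the import path changes here; the mixin's behavior is untouched. As a hedged sketch of what `HyperparametersMixin` provides to a `LightningDataModule` under the new path (the `MyDataModule` class and its `data_dir`/`batch_size` arguments are hypothetical, not from this PR):

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        # Provided by HyperparametersMixin (now in pytorch_lightning.core.mixins):
        # records the __init__ arguments under self.hparams.
        self.save_hyperparameters()


dm = MyDataModule(batch_size=64)
print(dm.hparams.batch_size)  # 64

Calling save_hyperparameters() in __init__ collects the constructor arguments into self.hparams, which is exactly the capability the class line above inherits.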
pytorch_lightning/core/lightning.py (3 changes: 1 addition & 2 deletions)

@@ -34,16 +34,15 @@
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin
 from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
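LightningModule consumes both mixins, which is what backs the familiar self.hparams, self.device, and self.dtype attributes on every model. A minimal sketch under the new import layout (the MyModel class and its lr argument are hypothetical):

import torch

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()         # HyperparametersMixin
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        # DeviceDtypeModuleMixin keeps self.device / self.dtype in sync
        # with .to() / .cuda() / .half() calls on the module.
        return self.layer(x.to(self.device))


model = MyModel(lr=3e-4)
print(model.hparams.lr, model.device)  # 0.0003 cpu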
pytorch_lightning/core/mixins/__init__.py (16 changes: 16 additions & 0 deletions, new file)

@@ -0,0 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin  # noqa: F401
+from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin  # noqa: F401
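This new package makes pytorch_lightning.core.mixins the canonical home for both mixins. As a hedged illustration of what DeviceDtypeModuleMixin adds to a plain torch.nn.Module (the Buffer class is hypothetical):

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class Buffer(DeviceDtypeModuleMixin, torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("values", torch.zeros(8))


buf = Buffer()
print(buf.device, buf.dtype)  # cpu torch.float32
buf = buf.to(torch.float64)   # the mixin intercepts .to() and records the new dtype
print(buf.dtype)              # torch.float64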
pytorch_lightning/overrides/base.py (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@
 from torch.nn.parallel import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


 class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -19,10 +19,10 @@
 import torch
 from torchmetrics import Metric

+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.data import extract_batch_size
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
tests/utilities/test_dtype_device_mixin.py (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@
 import torch.nn as nn

 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
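The updated test keeps exercising the same behavior from the new import location: device and dtype changes on a parent module propagate to nested mixin modules. A hedged, standalone sketch of that property (hypothetical classes, not the actual test code):

import torch
import torch.nn as nn

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class SubModule(DeviceDtypeModuleMixin, nn.Module):
    pass


class TopModule(DeviceDtypeModuleMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.child = SubModule()


model = TopModule().to(torch.double)
# the mixin applies the update to every DeviceDtypeModuleMixin submodule
assert model.dtype == torch.float64
assert model.child.dtype == torch.float64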