-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecate obcq path in favor of sparsegpt path #1148
Open
kylesayrs
wants to merge
2
commits into
main
Choose a base branch
from
kylesayrs/move-sparsegptq-2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,163 +1,12 @@ | ||
import contextlib | ||
from typing import Dict, Optional, Tuple | ||
import warnings | ||
|
||
import torch | ||
from compressed_tensors.utils import ( | ||
align_module_device, | ||
get_execution_device, | ||
update_offload_parameter, | ||
) | ||
from loguru import logger | ||
from pydantic import PrivateAttr | ||
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier | ||
|
||
from llmcompressor.core import State | ||
from llmcompressor.modifiers import Modifier | ||
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin | ||
from llmcompressor.modifiers.obcq.sgpt_sparsify import ( | ||
accumulate_hessian, | ||
make_empty_hessian, | ||
sparsify_weight, | ||
warnings.warn( | ||
"llmcompressor.modifiers.obcq has been moved to " | ||
"llmcompressor.modifiers.pruning.sparsegpt Please update your paths", | ||
DeprecationWarning, | ||
) | ||
from llmcompressor.utils.metric_logging import CompressionLogger | ||
|
||
__all__ = ["SparseGPTModifier"] | ||
|
||
|
||
class SparseGPTModifier(SparsityModifierMixin, Modifier): | ||
""" | ||
Modifier for applying the one-shot SparseGPT algorithm to a model | ||
|
||
| Sample yaml: | ||
| test_stage: | ||
| obcq_modifiers: | ||
| SparseGPTModifier: | ||
| sparsity: 0.5 | ||
| mask_structure: "2:4" | ||
| dampening_frac: 0.001 | ||
| block_size: 128 | ||
| targets: ['Linear'] | ||
| ignore: ['re:.*lm_head'] | ||
|
||
Lifecycle: | ||
- on_initialize | ||
- register_hook(module, calibrate_module, "forward") | ||
- run_sequential / run_layer_sequential / run_basic | ||
- make_empty_hessian | ||
- accumulate_hessian | ||
- on_sequential_batch_end | ||
- sparsify_weight | ||
- on_finalize | ||
- remove_hooks() | ||
|
||
:param sparsity: Sparsity to compress model to | ||
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed | ||
Layerwise Sparsity (OWL), more information can be found | ||
in the paper https://arxiv.org/pdf/2310.05175 | ||
:param mask_structure: String to define the structure of the mask to apply. | ||
Must be of the form N:M where N, M are integers that define a custom block | ||
shape. Defaults to 0:0 which represents an unstructured mask. | ||
:param owl_m: Number of outliers to use for OWL | ||
:param owl_lmbda: Lambda value to use for OWL | ||
:param block_size: Used to determine number of columns to compress in one pass | ||
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the | ||
diagonal norm | ||
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask | ||
during when applying sparsegpt, this becomes useful when starting from a | ||
previously pruned model, defaults to False. | ||
:param offload_hessians: Set to True for decreased memory usage but increased | ||
runtime. | ||
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__' | ||
to compress every layer in the model. Alias for `targets` | ||
:param targets: list of layer names to compress during OBCQ, or '__ALL__' | ||
to compress every layer in the model. Alias for `sequential_targets` | ||
:param ignore: optional list of module class names or submodule names to not | ||
quantize even if they match a target. Defaults to empty list. | ||
""" | ||
|
||
# modifier arguments | ||
block_size: int = 128 | ||
dampening_frac: Optional[float] = 0.01 | ||
preserve_sparsity_mask: bool = False | ||
offload_hessians: bool = False | ||
|
||
# private variables | ||
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) | ||
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) | ||
|
||
def calibrate_module( | ||
self, | ||
module: torch.nn.Module, | ||
args: Tuple[torch.Tensor, ...], | ||
_output: torch.Tensor, | ||
): | ||
# Assume that the first argument is the input | ||
inp = args[0] | ||
|
||
# Initialize hessian if not present | ||
if module not in self._num_samples: | ||
device = get_execution_device(module) | ||
self._hessians[module] = make_empty_hessian(module, device=device) | ||
self._num_samples[module] = 0 | ||
|
||
# Accumulate hessian with input with optional offloading | ||
with self._maybe_onload_hessian(module): | ||
self._hessians[module], self._num_samples[module] = accumulate_hessian( | ||
inp, | ||
module, | ||
self._hessians[module], | ||
self._num_samples[module], | ||
) | ||
|
||
def on_sequential_batch_end(self): | ||
""" | ||
Sparsify modules | ||
TODO: implement with event callback | ||
""" | ||
for module in list(self._num_samples.keys()): | ||
name = self._module_names[module] | ||
sparsity = self._module_sparsities[module] | ||
num_samples = self._num_samples[module] | ||
|
||
logger.info(f"Sparsifying {name} using {num_samples} samples") | ||
with ( | ||
torch.no_grad(), | ||
align_module_device(module), | ||
CompressionLogger(module) as comp_logger, | ||
): | ||
loss, sparsified_weight = sparsify_weight( | ||
module=module, | ||
hessians_dict=self._hessians, | ||
sparsity=sparsity, | ||
prune_n=self._prune_n, | ||
prune_m=self._prune_m, | ||
block_size=self.block_size, | ||
dampening_frac=self.dampening_frac, | ||
preserve_sparsity_mask=self.preserve_sparsity_mask, | ||
) | ||
comp_logger.set_loss(loss) | ||
|
||
update_offload_parameter(module, "weight", sparsified_weight) | ||
|
||
# self._hessians[module] already deleted by sparsify_weight | ||
del self._num_samples[module] | ||
|
||
@contextlib.contextmanager | ||
def _maybe_onload_hessian(self, module: torch.nn.Module): | ||
if self.offload_hessians: | ||
device = get_execution_device(module) | ||
self._hessians[module] = self._hessians[module].to(device=device) | ||
|
||
yield | ||
|
||
if self.offload_hessians: | ||
if module in self._hessians: # may have been deleted in context | ||
self._hessians[module] = self._hessians[module].to(device="cpu") | ||
|
||
def on_finalize(self, state: State, **kwargs) -> bool: | ||
self.remove_hooks() | ||
self._hessians = dict() | ||
self._num_samples = dict() | ||
self._module_names = dict() | ||
self._module_sparsities = dict() | ||
|
||
return True | ||
__all__ = ["SparseGPTModifier"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,5 @@ | |
|
||
from .constant import * | ||
from .magnitude import * | ||
from .sparsegpt import * | ||
from .wanda import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .base import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import contextlib | ||
from typing import Dict, Optional, Tuple | ||
|
||
import torch | ||
from compressed_tensors.utils import ( | ||
align_module_device, | ||
get_execution_device, | ||
update_offload_parameter, | ||
) | ||
from loguru import logger | ||
from pydantic import PrivateAttr | ||
|
||
from llmcompressor.core import State | ||
from llmcompressor.modifiers import Modifier | ||
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_mixin import SparsityModifierMixin | ||
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_sparsify import ( | ||
accumulate_hessian, | ||
make_empty_hessian, | ||
sparsify_weight, | ||
) | ||
from llmcompressor.utils.metric_logging import CompressionLogger | ||
|
||
__all__ = ["SparseGPTModifier"] | ||
|
||
|
||
class SparseGPTModifier(SparsityModifierMixin, Modifier): | ||
""" | ||
Modifier for applying the one-shot SparseGPT algorithm to a model | ||
|
||
| Sample yaml: | ||
| test_stage: | ||
| modifiers: | ||
| SparseGPTModifier: | ||
| sparsity: 0.5 | ||
| mask_structure: "2:4" | ||
| dampening_frac: 0.001 | ||
| block_size: 128 | ||
| targets: ['Linear'] | ||
| ignore: ['re:.*lm_head'] | ||
|
||
Lifecycle: | ||
- on_initialize | ||
- register_hook(module, calibrate_module, "forward") | ||
- run_sequential / run_layer_sequential / run_basic | ||
- make_empty_hessian | ||
- accumulate_hessian | ||
- on_sequential_batch_end | ||
- sparsify_weight | ||
- on_finalize | ||
- remove_hooks() | ||
|
||
:param sparsity: Sparsity to compress model to | ||
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed | ||
Layerwise Sparsity (OWL), more information can be found | ||
in the paper https://arxiv.org/pdf/2310.05175 | ||
:param mask_structure: String to define the structure of the mask to apply. | ||
Must be of the form N:M where N, M are integers that define a custom block | ||
shape. Defaults to 0:0 which represents an unstructured mask. | ||
:param owl_m: Number of outliers to use for OWL | ||
:param owl_lmbda: Lambda value to use for OWL | ||
:param block_size: Used to determine number of columns to compress in one pass | ||
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the | ||
diagonal norm | ||
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask | ||
during when applying sparsegpt, this becomes useful when starting from a | ||
previously pruned model, defaults to False. | ||
:param offload_hessians: Set to True for decreased memory usage but increased | ||
runtime. | ||
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__' | ||
to compress every layer in the model. Alias for `targets` | ||
:param targets: list of layer names to compress during OBCQ, or '__ALL__' | ||
to compress every layer in the model. Alias for `sequential_targets` | ||
:param ignore: optional list of module class names or submodule names to not | ||
quantize even if they match a target. Defaults to empty list. | ||
""" | ||
|
||
# modifier arguments | ||
block_size: int = 128 | ||
dampening_frac: Optional[float] = 0.01 | ||
preserve_sparsity_mask: bool = False | ||
offload_hessians: bool = False | ||
|
||
# private variables | ||
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) | ||
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) | ||
|
||
def calibrate_module( | ||
self, | ||
module: torch.nn.Module, | ||
args: Tuple[torch.Tensor, ...], | ||
_output: torch.Tensor, | ||
): | ||
# Assume that the first argument is the input | ||
inp = args[0] | ||
|
||
# Initialize hessian if not present | ||
if module not in self._num_samples: | ||
device = get_execution_device(module) | ||
self._hessians[module] = make_empty_hessian(module, device=device) | ||
self._num_samples[module] = 0 | ||
|
||
# Accumulate hessian with input with optional offloading | ||
with self._maybe_onload_hessian(module): | ||
self._hessians[module], self._num_samples[module] = accumulate_hessian( | ||
inp, | ||
module, | ||
self._hessians[module], | ||
self._num_samples[module], | ||
) | ||
|
||
def on_sequential_batch_end(self): | ||
""" | ||
Sparsify modules | ||
TODO: implement with event callback | ||
""" | ||
for module in list(self._num_samples.keys()): | ||
name = self._module_names[module] | ||
sparsity = self._module_sparsities[module] | ||
num_samples = self._num_samples[module] | ||
|
||
logger.info(f"Sparsifying {name} using {num_samples} samples") | ||
with ( | ||
torch.no_grad(), | ||
align_module_device(module), | ||
CompressionLogger(module) as comp_logger, | ||
): | ||
loss, sparsified_weight = sparsify_weight( | ||
module=module, | ||
hessians_dict=self._hessians, | ||
sparsity=sparsity, | ||
prune_n=self._prune_n, | ||
prune_m=self._prune_m, | ||
block_size=self.block_size, | ||
dampening_frac=self.dampening_frac, | ||
preserve_sparsity_mask=self.preserve_sparsity_mask, | ||
) | ||
comp_logger.set_loss(loss) | ||
|
||
update_offload_parameter(module, "weight", sparsified_weight) | ||
|
||
# self._hessians[module] already deleted by sparsify_weight | ||
del self._num_samples[module] | ||
|
||
@contextlib.contextmanager | ||
def _maybe_onload_hessian(self, module: torch.nn.Module): | ||
if self.offload_hessians: | ||
device = get_execution_device(module) | ||
self._hessians[module] = self._hessians[module].to(device=device) | ||
|
||
yield | ||
|
||
if self.offload_hessians: | ||
if module in self._hessians: # may have been deleted in context | ||
self._hessians[module] = self._hessians[module].to(device="cpu") | ||
|
||
def on_finalize(self, state: State, **kwargs) -> bool: | ||
self.remove_hooks() | ||
self._hessians = dict() | ||
self._num_samples = dict() | ||
self._module_names = dict() | ||
self._module_sparsities = dict() | ||
|
||
return True |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!