Deprecate obcq path in favor of sparsegpt path #1148

Open · wants to merge 2 commits into main
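For anyone migrating, the user-facing change in this PR is the import path; the modifier itself is essentially unchanged. A minimal before/after sketch (the old path keeps working through the deprecation shim added below, but now emits a `DeprecationWarning`):

```python
# Before this PR (deprecated path, still re-exported by the shim):
# from llmcompressor.modifiers.obcq import SparseGPTModifier

# After this PR, SparseGPTModifier lives under the pruning namespace:
from llmcompressor.modifiers.pruning import SparseGPTModifier
# or, fully qualified:
# from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier
```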
4 changes: 2 additions & 2 deletions examples/finetuning/example_alternating_recipe.yaml
@@ -1,6 +1,6 @@
initial_sparsity_stage:
run_type: oneshot
obcq_modifiers:
modifiers:
SparseGPTModifier:
sparsity: 0.5
block_size: 128
@@ -16,7 +16,7 @@ initial_training_stage:
start: 0
next_sparsity_stage:
run_type: oneshot
obcq_modifiers:
modifiers:
SparseGPTModifier:
sparsity: 0.7
block_size: 128
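The only change in this recipe is the stage key rename from `obcq_modifiers` to `modifiers`. A hedged sketch of applying a recipe that uses the renamed key through the `oneshot` entrypoint shown elsewhere in this diff; the model stub, dataset, and calibration settings are placeholders, and the recipe could equally be a path to a YAML file:

```python
from llmcompressor.transformers import oneshot

# Recipe string using the renamed "modifiers" key (previously "obcq_modifiers").
recipe = """
test_stage:
  modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      targets: ['Linear']
      ignore: ['re:.*lm_head']
"""

oneshot(
    model="MODEL_STUB",            # placeholder model id
    dataset="open_platypus",       # placeholder calibration dataset
    recipe=recipe,
    num_calibration_samples=64,    # illustrative setting
)
```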
3 changes: 1 addition & 2 deletions examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py
@@ -3,8 +3,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.obcq import SparseGPTModifier
from llmcompressor.modifiers.pruning import ConstantPruningModifier
from llmcompressor.modifiers.pruning import ConstantPruningModifier, SparseGPTModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

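With the consolidated import, the pruning-related modifiers for this example all come from `llmcompressor.modifiers.pruning`. A hedged sketch of the kind of staged setup the example builds; the argument values and target patterns below are illustrative placeholders, not the example's actual settings:

```python
from llmcompressor.modifiers.pruning import ConstantPruningModifier, SparseGPTModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

# Stage 1: prune Linear layers to 2:4 sparsity with SparseGPT.
sparsify = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",
    targets=["Linear"],
    ignore=["re:.*lm_head"],
)

# Stage 2: keep the pruned mask fixed during any subsequent finetuning.
hold_mask = ConstantPruningModifier(targets=["re:.*weight"])  # placeholder target pattern

# Stage 3: quantize to FP8 (weights static, activations dynamic).
quantize = QuantizationModifier(scheme="FP8_DYNAMIC", targets="Linear", ignore=["lm_head"])
```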
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/README.md
@@ -8,7 +8,7 @@ are relevant only during training. Below is a summary of the key modifiers avail

Modifiers that introduce sparsity into a model

### [SparseGPT](./obcq/base.py)
### [SparseGPT](./pruning/sparsegpt/base.py)
One-shot algorithm that uses calibration data to introduce unstructured or structured
sparsity into weights. Implementation based on [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774). A small amount of calibration data is used
to calculate a Hessian for each layer's input activations; this Hessian is then used to
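For context on the description above: the per-layer Hessian is, up to a constant factor, the damped Gram matrix of the calibration inputs seen by that layer. A minimal sketch of that computation using a hypothetical `input_hessian` helper; this is not the library's `accumulate_hessian`/`sparsify_weight` implementation, just the quantity it builds on:

```python
import torch

def input_hessian(calib_inputs: torch.Tensor, dampening_frac: float = 0.01) -> torch.Tensor:
    """calib_inputs: (num_samples, in_features) activations feeding a Linear layer."""
    X = calib_inputs.float()
    H = X.T @ X                                   # Gram matrix of the layer's inputs
    damp = dampening_frac * torch.mean(torch.diag(H))
    H += damp * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)  # dampening term
    return H
```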
165 changes: 7 additions & 158 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,163 +1,12 @@
import contextlib
from typing import Dict, Optional, Tuple
import warnings

import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier

Collaborator review comment: nice!
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
accumulate_hessian,
make_empty_hessian,
sparsify_weight,
warnings.warn(
"llmcompressor.modifiers.obcq has been moved to "
"llmcompressor.modifiers.pruning.sparsegpt. Please update your paths",
DeprecationWarning,
)
from llmcompressor.utils.metric_logging import CompressionLogger

__all__ = ["SparseGPTModifier"]


class SparseGPTModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot SparseGPT algorithm to a model

| Sample yaml:
| test_stage:
| obcq_modifiers:
| SparseGPTModifier:
| sparsity: 0.5
| mask_structure: "2:4"
| dampening_frac: 0.001
| block_size: 128
| targets: ['Linear']
| ignore: ['re:.*lm_head']

Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
- remove_hooks()

:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL), more information can be found
in the paper https://arxiv.org/pdf/2310.05175
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param owl_m: Number of outliers to use for OWL
:param owl_lmbda: Lambda value to use for OWL
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
during when applying sparsegpt, this becomes useful when starting from a
previously pruned model, defaults to False.
:param offload_hessians: Set to True for decreased memory usage but increased
runtime.
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# modifier arguments
block_size: int = 128
dampening_frac: Optional[float] = 0.01
preserve_sparsity_mask: bool = False
offload_hessians: bool = False

# private variables
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
# Assume that the first argument is the input
inp = args[0]

# Initialize hessian if not present
if module not in self._num_samples:
device = get_execution_device(module)
self._hessians[module] = make_empty_hessian(module, device=device)
self._num_samples[module] = 0

# Accumulate hessian with input with optional offloading
with self._maybe_onload_hessian(module):
self._hessians[module], self._num_samples[module] = accumulate_hessian(
inp,
module,
self._hessians[module],
self._num_samples[module],
)

def on_sequential_batch_end(self):
"""
Sparsify modules
TODO: implement with event callback
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
CompressionLogger(module) as comp_logger,
):
loss, sparsified_weight = sparsify_weight(
module=module,
hessians_dict=self._hessians,
sparsity=sparsity,
prune_n=self._prune_n,
prune_m=self._prune_m,
block_size=self.block_size,
dampening_frac=self.dampening_frac,
preserve_sparsity_mask=self.preserve_sparsity_mask,
)
comp_logger.set_loss(loss)

update_offload_parameter(module, "weight", sparsified_weight)

# self._hessians[module] already deleted by sparsify_weight
del self._num_samples[module]

@contextlib.contextmanager
def _maybe_onload_hessian(self, module: torch.nn.Module):
if self.offload_hessians:
device = get_execution_device(module)
self._hessians[module] = self._hessians[module].to(device=device)

yield

if self.offload_hessians:
if module in self._hessians: # may have been deleted in context
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
self._module_sparsities = dict()

return True
__all__ = ["SparseGPTModifier"]
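As a sanity check on the shim above, the old module path should still resolve to the same class while emitting a `DeprecationWarning`. A hedged sketch, run in a fresh interpreter since the module-level warning only fires the first time the module is imported:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Importing the deprecated shim module triggers its module-level warning.
    from llmcompressor.modifiers.obcq.base import SparseGPTModifier as OldPath

from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier as NewPath

assert OldPath is NewPath                     # the shim re-exports the same class
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```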
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/pruning/__init__.py
@@ -2,4 +2,5 @@

from .constant import *
from .magnitude import *
from .sparsegpt import *
from .wanda import *
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
163 changes: 163 additions & 0 deletions src/llmcompressor/modifiers/pruning/sparsegpt/base.py
@@ -0,0 +1,163 @@
import contextlib
from typing import Dict, Optional, Tuple

import torch
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_sparsify import (
accumulate_hessian,
make_empty_hessian,
sparsify_weight,
)
from llmcompressor.utils.metric_logging import CompressionLogger

__all__ = ["SparseGPTModifier"]


class SparseGPTModifier(SparsityModifierMixin, Modifier):
"""
Modifier for applying the one-shot SparseGPT algorithm to a model

| Sample yaml:
| test_stage:
| modifiers:
| SparseGPTModifier:
| sparsity: 0.5
| mask_structure: "2:4"
| dampening_frac: 0.001
| block_size: 128
| targets: ['Linear']
| ignore: ['re:.*lm_head']

Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
- remove_hooks()

:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL); more information can be found
in the paper https://arxiv.org/pdf/2310.05175
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param owl_m: Number of outliers to use for OWL
:param owl_lmbda: Lambda value to use for OWL
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
when applying SparseGPT; this is useful when starting from a
previously pruned model. Defaults to False.
:param offload_hessians: Set to True for decreased memory usage but increased
runtime.
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# modifier arguments
block_size: int = 128
dampening_frac: Optional[float] = 0.01
preserve_sparsity_mask: bool = False
offload_hessians: bool = False

# private variables
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
# Assume that the first argument is the input
inp = args[0]

# Initialize hessian if not present
if module not in self._num_samples:
device = get_execution_device(module)
self._hessians[module] = make_empty_hessian(module, device=device)
self._num_samples[module] = 0

# Accumulate hessian with input with optional offloading
with self._maybe_onload_hessian(module):
self._hessians[module], self._num_samples[module] = accumulate_hessian(
inp,
module,
self._hessians[module],
self._num_samples[module],
)

def on_sequential_batch_end(self):
"""
Sparsify modules
TODO: implement with event callback
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
num_samples = self._num_samples[module]

logger.info(f"Sparsifying {name} using {num_samples} samples")
with (
torch.no_grad(),
align_module_device(module),
CompressionLogger(module) as comp_logger,
):
loss, sparsified_weight = sparsify_weight(
module=module,
hessians_dict=self._hessians,
sparsity=sparsity,
prune_n=self._prune_n,
prune_m=self._prune_m,
block_size=self.block_size,
dampening_frac=self.dampening_frac,
preserve_sparsity_mask=self.preserve_sparsity_mask,
)
comp_logger.set_loss(loss)

update_offload_parameter(module, "weight", sparsified_weight)

# self._hessians[module] already deleted by sparsify_weight
del self._num_samples[module]

@contextlib.contextmanager
def _maybe_onload_hessian(self, module: torch.nn.Module):
if self.offload_hessians:
device = get_execution_device(module)
self._hessians[module] = self._hessians[module].to(device=device)

yield

if self.offload_hessians:
if module in self._hessians: # may have been deleted in context
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
self._module_sparsities = dict()

return True
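The `mask_structure` parameter documented above uses the N:M convention: within every contiguous group of M weights along the input dimension, only a fixed number stay non-zero (two of every four for "2:4"). A minimal magnitude-based illustration with a hypothetical `two_of_four_mask` helper; SparseGPT itself selects which weights to drop using the Hessian, not raw magnitude:

```python
import torch

def two_of_four_mask(weight: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude weights in each group of 4."""
    out_features, in_features = weight.shape
    groups = weight.abs().reshape(out_features, in_features // 4, 4)
    keep = groups.topk(k=2, dim=-1).indices       # indices of the 2 kept weights per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, keep, True)
    return mask.reshape(out_features, in_features)

w = torch.randn(8, 16)
mask = two_of_four_mask(w)
assert mask.sum(dim=-1).eq(8).all()               # exactly 50% of each row is kept
```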
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -11,7 +11,7 @@

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import (
accumulate_row_scalars,
make_empty_row_scalars,
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -42,7 +42,7 @@ class GPTQModifier(Modifier, HooksMixin):

| Sample yaml:
| test_stage:
| obcq_modifiers:
| modifiers:
| GPTQModifier:
| block_size: 128
| dampening_frac: 0.001