Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Nf4 a NF4 Tensor subclass #18

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 50 additions & 53 deletions benchmarks/qlora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import csv
import itertools

Expand All @@ -12,6 +11,8 @@

import transformer_nuggets as nugs
import transformer_nuggets.quant.qlora as qlora
from jsonargparse import CLI
from tabulate import tabulate
from tqdm import tqdm
from transformer_nuggets.quant import NF4Tensor

Expand Down Expand Up @@ -57,15 +58,17 @@ def linear_experiment(config: ExperimentConfig) -> ExperimentResult:
config.device,
)
qlora_weight = NF4Tensor.from_tensor(input_weight.clone())
bnb_linear = qlora.build_bitsandbytes_linear(input_weight, config.device)
compiled_qlora_linear = torch.compile(qlora.linear_nf4, fullgraph=True, dynamic=config.dynamic)
if bnb_available:
bnb_linear = qlora.build_bitsandbytes_linear(input_weight, config.device)

# warmup
for _ in range(3):
F.linear(sample_input, input_weight)
qlora.linear_nf4(sample_input, qlora_weight)
compiled_qlora_linear(sample_input, qlora_weight)
bnb_linear(sample_input)
if bnb_available:
bnb_linear(sample_input)

linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
F.linear, sample_input, input_weight
Expand All @@ -76,7 +79,12 @@ def linear_experiment(config: ExperimentConfig) -> ExperimentResult:
compiled_qlora_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
compiled_qlora_linear, sample_input, qlora_weight
)
bnb_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_linear, sample_input)
if bnb_available:
bnb_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
bnb_linear, sample_input
)
else:
bnb_linear_time = -1.0

return ExperimentResult(
linear_time, qlora_linear_time, compiled_qlora_linear_time, bnb_linear_time
Expand All @@ -94,21 +102,26 @@ def mlp_experiment(config: ExperimentConfig) -> ExperimentResult:
mlp = qlora.MLP(*weights)
nf4_mlp = qlora.NF4MLP(*weights)
compiled_qlora_mlp = torch.compile(nf4_mlp, fullgraph=True, dynamic=config.dynamic)
bnb_mlp = qlora.BnbQloraMLP(*weights, config.device)
if bnb_available:
bnb_mlp = qlora.BnbQloraMLP(*weights, config.device)

# warmup
for _ in range(3):
mlp(sample_input)
nf4_mlp(sample_input)
compiled_qlora_mlp(sample_input)
bnb_mlp(sample_input)
if bnb_available:
bnb_mlp(sample_input)

mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(mlp, sample_input)
qlora_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(nf4_mlp, sample_input)
compiled_qlora_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(
compiled_qlora_mlp, sample_input
)
bnb_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_mlp, sample_input)
if bnb_available:
bnb_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_mlp, sample_input)
else:
bnb_mlp_time = -1.0

return ExperimentResult(mlp_time, qlora_mlp_time, compiled_qlora_mlp_time, bnb_mlp_time)

Expand Down Expand Up @@ -137,22 +150,34 @@ def gen_configs() -> List[ExperimentConfig]:


def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: bool):
"""Run experiments and output results to file

Args:
output_path (Optional[Path]): Path to write out CSV file for experiment results.
profile_path (Optional[Path]): Path to write out json chrome trace file for an experiment.
dynamic (bool): Compile with Dynamic shapes
"""

results = []
for experiment_config in tqdm(gen_configs()):
# Since we are changing between dynamic and not
import torch._dynamo # noqa: F402

torch._dynamo.reset()
experiment = experiment_types[experiment_config.op]
experiment_result = experiment(experiment_config)
merged = asdict(experiment_config) | asdict(experiment_result)
results.append(merged)

if output_path is not None:
results = []
for experiment_config in tqdm(gen_configs()):
# Since we are changing between dynamic and not
import torch._dynamo # noqa: F402

torch._dynamo.reset()
experiment = experiment_types[experiment_config.op]
experiment_result = experiment(experiment_config)
merged = asdict(experiment_config) | asdict(experiment_result)
results.append(merged)

with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=results[0].keys())
writer.writeheader()
writer.writerows(results)
with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=results[0].keys())
writer.writeheader()
writer.writerows(results)
else:
headers = results[0].keys()
rows = [list(r.values()) for r in results]
print(tabulate(rows, headers=headers))

if profile_path is not None:
profile_experiment = ExperimentConfig(4096, 8, 128, torch.device("cuda:0"), "mlp", dynamic)
Expand All @@ -169,7 +194,7 @@ def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: boo

qlora_mlp = qlora.NF4MLP(*weights)
compiled_qlora_mlp = torch.compile(qlora_mlp, fullgraph=True, dynamic=dynamic)
print("dynamic = ", dynamic)
logging.info("Running torch.compile with dynamic = %s", dynamic)
profile_config = nugs.utils.ProfileConfig(
str(profile_path), "qlora_mlp", iters=5, warmup_iters=3, sync=True
)
Expand All @@ -183,34 +208,6 @@ def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: boo
if __name__ == "__main__":
"""Sample usage:
# Running sweep
python benchmarks/qlora.py -o benchmarks/data/qlora_sweep.csv
python benchmarks/qlora.py -p benchmarks/data/4096_8_128_qlora.json
python benchmarks/qlora.py false --output_path benchmarks/data/qlora_sweep.csv
"""
parser = argparse.ArgumentParser(description="Run experiments and output results to file")
parser.add_argument(
"-o",
"--output_file",
type=str,
help="Path to write out CSV file for experiment results.",
default=None,
)
parser.add_argument(
"-p",
"--profile_path",
type=str,
help="Path to write out json chrome trace file for an experiment.",
default=None,
)
parser.add_argument(
"--dynamic_shapes", action="store_true", help="Compile with Dynamic shapes"
)

args = parser.parse_args()
output_path = None
profile_path = None
if args.output_file is not None:
output_path = Path(args.output_file)
if args.profile_path is not None:
profile_path = Path(args.profile_path)

main(output_path, profile_path, args.dynamic_shapes)
CLI(main)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ dev = [
"pytest",
"flake8==6.1.0",
"flake8-pyproject",
"jsonargparse",
"docstring-parser"
]

qlora = ['bitsandbytes']
Expand Down
6 changes: 4 additions & 2 deletions test/test_qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch.nn.functional as F

import transformer_nuggets.quant.qlora as qlora
from transformer_nuggets.quant import linear_nf4, NF4Tensor
from transformer_nuggets.quant import linear_nf4
from transformer_nuggets.quant.nf4_tensor import NF4Tensor
from transformer_nuggets.quant.qlora_debug import NF4TensorDebug

bnb_available = False
Expand Down Expand Up @@ -91,8 +92,9 @@ def test_binning_distribution(embed_dim: int):
@pytest.mark.parametrize("embed_dim", [256, 4096, 5120, 6656, 8192])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("requires_grad", [True, False])
@pytest.mark.xfail(reason="TORCH COMPILE No longer works here")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_autograd_func_to_eager(embed_dim: int, compile: bool, requires_grad: bool):
torch._dynamo.reset()
torch.manual_seed(0)
device = "cuda"
input_weight = qlora.build_input_weight(embed_dim, device)
Expand Down
3 changes: 2 additions & 1 deletion transformer_nuggets/quant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from transformer_nuggets.quant.qlora import get_block_absmax, linear_nf4, NF4Tensor
from transformer_nuggets.quant.nf4_tensor import get_block_absmax, NF4Tensor
from transformer_nuggets.quant.qlora import linear_nf4
from transformer_nuggets.quant.qlora_debug import NF4TensorDebug
Loading
Loading