Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 49 additions & 65 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import skip_if_rocm, skip_if_xpu
from torchao.utils import (
get_available_devices,
get_current_accelerator_device,
is_ROCM,
is_sm_at_least_89,
Expand Down Expand Up @@ -118,25 +120,6 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
return model


class ToyLinearModel(torch.nn.Module):
    """Minimal two-layer linear stack used as a fixture in quantization tests.

    Shapes flow m -> n -> k; both layers are created in float32.
    """

    def __init__(self, m=64, n=32, k=64, bias=False):
        """Build the two stacked ``Linear`` layers (``m -> n`` and ``n -> k``)."""
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        """Return a 1-tuple with one random batch sized to linear1's input width."""
        width = self.linear1.in_features
        sample = torch.randn(batch_size, width, dtype=dtype, device=device)
        return (sample,)

    def forward(self, x):
        """Apply both linear layers in sequence and return the result."""
        return self.linear2(self.linear1(x))


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
Expand All @@ -159,14 +142,11 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):


class TestQuantFlow(TestCase):
GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
["xpu"] if torch.xpu.is_available() else []
)

def test_dynamic_quant_gpu_singleline(self):
if is_ROCM():
self.skipTest("Don't test CPU for ROCM version of torch")
m = ToyLinearModel().eval()
@common_utils.parametrize("device", get_available_devices())
def test_dynamic_quant_gpu_singleline(self, device):
if is_ROCM() and device == "cuda":
self.skipTest("Skip on ROCM")
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device=device).eval()
example_inputs = m.example_inputs()
quantize_(m, Int8DynamicActivationInt8WeightConfig())
m(*example_inputs)
Expand All @@ -180,7 +160,7 @@ def test_dynamic_quant_gpu_singleline(self):
@unittest.skip("skipping for now due to torch.compile error")
def test_dynamic_quant_gpu_unified_api_unified_impl(self):
quantizer = XNNPackDynamicQuantizer()
m = ToyLinearModel().eval()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu").eval()
example_inputs = m.example_inputs()
m = quantizer.prepare(m)
m = quantizer.convert(m)
Expand All @@ -197,7 +177,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
)
def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
quantizer = TorchCompileDynamicQuantizer()
m = ToyLinearModel().eval()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu").eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
quantized = m(*example_inputs)
Expand All @@ -207,7 +187,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_int8_wo_quant_save_load(self):
m = ToyLinearModel().eval().cpu()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu").eval()

def api(model):
quantize_(model, Int8WeightOnlyConfig())
Expand All @@ -222,7 +202,7 @@ def api(model):
f.seek(0)
state_dict = torch.load(f)

m2 = ToyLinearModel().eval().cpu()
m2 = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu").eval()
api(m2)

m2.load_state_dict(state_dict)
Expand All @@ -240,7 +220,7 @@ def test_8da4w_quantizer(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
m = ToyLinearModel().eval()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu").eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
Expand All @@ -252,7 +232,9 @@ def test_8da4w_quantizer_linear_bias(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
m = ToyLinearModel(bias=True).eval()
m = ToyTwoLinearModel(
64, 32, 64, dtype=torch.float, device="cpu", has_bias=True
).eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
Expand All @@ -261,9 +243,9 @@ def test_8da4w_quantizer_linear_bias(self):

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_quantized_tensor_subclass_save_load(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cpu").eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)
example_inputs = m.example_inputs()

quantize_(m, Int8WeightOnlyConfig())
ref = m(*example_inputs)
Expand All @@ -279,8 +261,8 @@ def test_quantized_tensor_subclass_save_load(self):

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_int8wo_quantized_model_to_device(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cpu").eval()
example_inputs = m.example_inputs()

quantize_(m, Int8WeightOnlyConfig())
ref = m(*example_inputs)
Expand All @@ -294,8 +276,8 @@ def test_int8wo_quantized_model_to_device(self):
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_quantized_tensor_subclass_save_load_map_location(self):
device = get_current_accelerator_device()
m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device=device)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device=device).eval()
example_inputs = m.example_inputs()

quantize_(m, Int8WeightOnlyConfig())
ref = m(*example_inputs)
Expand All @@ -305,7 +287,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
state_dict = torch.load(f.name, map_location="cpu", mmap=True)

with torch.device("meta"):
m_copy = ToyLinearModel().eval()
m_copy = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="meta").eval()

m_copy.load_state_dict(state_dict, assign=True)
m_copy.to(dtype=torch.bfloat16, device=device)
Expand All @@ -324,13 +306,13 @@ def reset_memory():
device_module.reset_peak_memory_stats()

reset_memory()
m = ToyLinearModel()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu")
quantize_(m.to(device=device), Int8WeightOnlyConfig())
memory_baseline = device_module.max_memory_allocated()

del m
reset_memory()
m = ToyLinearModel()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu")
quantize_(m, Int8WeightOnlyConfig(), device=device)
memory_streaming = device_module.max_memory_allocated()

Expand Down Expand Up @@ -392,8 +374,8 @@ def test_module_fqn_to_config_default(self):
config2 = Int8WeightOnlyConfig()
config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
device = get_current_accelerator_device()
model = ToyLinearModel().to(device).to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device=device, dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device=device)
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Float8Tensor)
Expand All @@ -406,8 +388,8 @@ def test_module_fqn_to_config_module_name(self):
config2 = Int8WeightOnlyConfig()
config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
device = get_current_accelerator_device()
model = ToyLinearModel().to(device).to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device=device, dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device=device)
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Float8Tensor)
Expand All @@ -419,8 +401,8 @@ def test_module_fqn_to_config_regex_basic(self):
group_size=32, int4_packing_format="tile_packed_to_4d"
)
config = ModuleFqnToConfig({"re:linear.": config1})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda")
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
Expand All @@ -436,8 +418,8 @@ def test_module_fqn_to_config_regex_precedence(self):
)
config2 = IntxWeightOnlyConfig()
config = ModuleFqnToConfig({"linear1": config1, "re:linear.": config2})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda")
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
Expand All @@ -455,8 +437,8 @@ def test_module_fqn_to_config_regex_precedence2(self):
)
config2 = IntxWeightOnlyConfig()
config = ModuleFqnToConfig({"re:linear.": config2, "linear1": config1})
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda")
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
Expand Down Expand Up @@ -544,8 +526,8 @@ def test_module_fqn_to_config_skip(self):
config1 = Float8DynamicActivationFloat8WeightConfig()
config = ModuleFqnToConfig({"_default": config1, "linear2": None})
device = get_current_accelerator_device()
model = ToyLinearModel().to(device).to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device=device, dtype=torch.bfloat16)
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device=device)
example_inputs = model.example_inputs()
quantize_(model, config, filter_fn=None)
model(*example_inputs)
assert isinstance(model.linear1.weight, Float8Tensor)
Expand Down Expand Up @@ -588,7 +570,9 @@ def __init__(self):
assert "PerTensor()" in str(custom_module)

def test_fqn_to_config_repr_linear(self):
linear_model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
linear_model = ToyTwoLinearModel(
64, 32, 64, dtype=torch.bfloat16, device="cuda"
).eval()
linear_quant_config = FqnToConfig(
{
"linear1.weight": Float8DynamicActivationFloat8WeightConfig(
Expand Down Expand Up @@ -699,7 +683,7 @@ def test_quantize_param_fqn_regex(self):
assert isinstance(model.experts.gate_up_proj, Float8Tensor)

def test_quantize_fqn_precedence_param_over_module(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -714,7 +698,7 @@ def test_quantize_fqn_precedence_param_over_module(self):
assert model.linear1.weight.scale.numel() == 1

def test_quantize_fqn_precedence_param_over_module_regex(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -729,7 +713,7 @@ def test_quantize_fqn_precedence_param_over_module_regex(self):
assert model.linear1.weight.scale.numel() == 1

def test_quantize_fqn_precedence_param_regex_over_module_regex(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -744,7 +728,7 @@ def test_quantize_fqn_precedence_param_regex_over_module_regex(self):
assert model.linear1.weight.scale.numel() == 1

def test_quantize_fqn_precedence_module_over_param_regex(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -760,7 +744,7 @@ def test_quantize_fqn_precedence_module_over_param_regex(self):
assert not isinstance(model.linear2.weight, Float8Tensor)

def test_quantize_fqn_precedence_param_over_default(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -776,7 +760,7 @@ def test_quantize_fqn_precedence_param_over_default(self):
assert not isinstance(model.linear2.weight, Float8Tensor)

def test_quantize_fqn_precedence_param_regex_over_default(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()

quant_config = FqnToConfig(
{
Expand All @@ -791,7 +775,7 @@ def test_quantize_fqn_precedence_param_regex_over_default(self):
assert not isinstance(model.linear1.weight, Float8Tensor)

def test_quantize_model_same_module_different_param(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()
model.linear1.register_parameter(
"weight2", torch.nn.Parameter(model.linear1.weight.clone())
)
Expand All @@ -817,7 +801,7 @@ def test_quantize_model_same_module_different_param(self):
assert model.linear1.weight2.scale.numel() == 32

def test_quantize_model_same_module_different_param_regex(self):
model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
model = ToyTwoLinearModel(64, 32, 64, dtype=torch.bfloat16, device="cuda").eval()
quant_config = FqnToConfig(
{
"re:.*weight": Float8DynamicActivationFloat8WeightConfig(
Expand Down Expand Up @@ -941,13 +925,13 @@ def reset_memory():

quant_config = FqnToConfig({"_default": Int8WeightOnlyConfig()})
reset_memory()
m = ToyLinearModel()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu")
quantize_(m.to(device=device), quant_config, filter_fn=None)
memory_baseline = device_module.max_memory_allocated()

del m
reset_memory()
m = ToyLinearModel()
m = ToyTwoLinearModel(64, 32, 64, dtype=torch.float, device="cpu")
quantize_(m, quant_config, device=device, filter_fn=None)
memory_streaming = device_module.max_memory_allocated()

Expand Down