Use matmul decomposition for batch size 1 #8

Open · wants to merge 2 commits into base: main
12 changes: 7 additions & 5 deletions benchmark/ffn.py
@@ -45,7 +45,7 @@ def forward(self, x):


def run_benchmark(
-use_q, d_model, dim_feedforward, batch_size, seq_len, minimize_error=True
+use_q, d_model, dim_feedforward, batch_size, seq_len, minimize_error=True, use_int_mm=True,
):
inp = torch.randn(batch_size, seq_len, d_model)
inp = inp.half().cuda()
@@ -58,8 +58,8 @@ def run_benchmark(
ffn = ffn.half().cuda().eval()
fp16_ref = ffn(inp).detach().clone().float()
if use_q:
-ffn.linear1 = protoquant.qlinear_from_linear(ffn.linear1, minimize_error)
-ffn.linear2 = protoquant.qlinear_from_linear(ffn.linear2, minimize_error)
+ffn.linear1 = protoquant.qlinear_from_linear(ffn.linear1, minimize_error, use_int_mm)
+ffn.linear2 = protoquant.qlinear_from_linear(ffn.linear2, minimize_error, use_int_mm)
ffn = torch.compile(ffn, options={"max-autotune": True})
fp8_ref = ffn(inp).detach().clone().float()
torch.testing.assert_close(fp16_ref, fp8_ref, atol=3e-2, rtol=3e-2)
@@ -144,6 +144,7 @@ def get_opt_shapes():
"with_q(μs)",
"without_q(μs)",
"minimize_error",
"use_int_mm",
"speedup",
]
shape_gen = get_default_shapes
@@ -155,9 +156,9 @@ def get_opt_shapes():
bs = int(args.batchsize)
seq_len = int(args.seq_len)
for d_model, dim_feedforward, annotation in shape_gen():
-for minimize_error in [True, False]:
+for (minimize_error, use_int_mm) in itertools.product([True, False], [False]):
with_q = run_benchmark(
-True, d_model, dim_feedforward, bs, seq_len, minimize_error
+True, d_model, dim_feedforward, bs, seq_len, minimize_error, use_int_mm
)
without_q = run_benchmark(False, d_model, dim_feedforward, bs, seq_len)
print(
@@ -173,6 +174,7 @@ def get_opt_shapes():
f"{with_q:.0f}",
f"{without_q:.0f}",
minimize_error,
+use_int_mm,
f"{without_q / with_q:.2f}",
],
)
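Note: the updated sweep assumes itertools is imported at the top of benchmark/ffn.py (the import is not part of this diff), and it pins use_int_mm to [False], so only the decomposition path is benchmarked against the unquantized baseline. A minimal sketch of the grid the loop above walks:

import itertools

# (minimize_error, use_int_mm) pairs covered by the benchmark loop;
# the aten._int_mm path (use_int_mm=True) is not part of this sweep.
for minimize_error, use_int_mm in itertools.product([True, False], [False]):
    print(minimize_error, use_int_mm)
# prints: True False, then False False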
19 changes: 12 additions & 7 deletions protoquant/qlinear.py
@@ -5,7 +5,7 @@


class QLinear(torch.nn.Module):
-def __init__(self, qweight, wparams, bias, minimize_error):
+def __init__(self, qweight, wparams, bias, minimize_error, use_int_mm):
super(QLinear, self).__init__()
assert isinstance(bias, Parameter)
# Need to store in transposed form due to cuBLAS
@@ -15,17 +15,22 @@ def __init__(self, qweight, wparams, bias, minimize_error):
self.in_features = qweight.size(1)
self.out_features = qweight.size(1)
self.minimize_error = minimize_error
+self.use_int_mm = use_int_mm

def forward(self, inp: torch.Tensor) -> torch.Tensor:
assert inp.dim() == 3
inp_size0 = inp.size(0)
inp_size1 = inp.size(1)
inp_size2 = inp.size(2)
inp = inp.reshape(inp_size0 * inp_size1, inp_size2)
-qinp, iparams = (qntz)(inp, is_a=True, minimize_error=self.minimize_error)
-d = torch.ops.aten._int_mm(qinp, self.qweight.t())
-# d = matmul_int8(qinp, self.qweight.t())
-return (dqntz)(d, iparams, self.wparams, self.bias).view(
+qinp, iparams = qntz(
+    inp, is_a=True, minimize_error=self.minimize_error
+)
+if self.use_int_mm:
+    d = torch.ops.aten._int_mm(qinp, self.qweight.t())
+else:
+    d = (qinp.unsqueeze(-1).to(torch.int32) * self.qweight.t().to(torch.int32)).sum(1).to(torch.int32)
+return dqntz(d, iparams, self.wparams, self.bias).view(
inp_size0, inp_size1, -1
)
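The else branch above replaces the fused int8 GEMM with a broadcasted multiply followed by a reduction over the shared K dimension: qinp has shape (M, K) and the weight is stored transposed as (N, K), so qinp.unsqueeze(-1) against qweight.t() broadcasts to (M, K, N), and sum over dim 1 yields the (M, N) int32 accumulator. A small self-contained check of that equivalence, using a plain int32 matmul as the reference instead of aten._int_mm and illustrative shapes for the batch-size-1 case:

import torch

# Single row of int8 activations (batch size 1, seq_len 1) against a weight
# stored transposed as (N, K), mirroring QLinear.forward above.
M, K, N = 1, 64, 32
qinp = torch.randint(-128, 128, (M, K), dtype=torch.int8)
qweight = torch.randint(-128, 128, (N, K), dtype=torch.int8)

# Reference: ordinary int32 matmul of the same operands.
ref = qinp.to(torch.int32) @ qweight.t().to(torch.int32)

# Decomposition taken when use_int_mm is False: broadcast (M, K, 1) against
# (K, N) and reduce over K.
dec = (qinp.unsqueeze(-1).to(torch.int32) * qweight.t().to(torch.int32)).sum(1).to(torch.int32)

print(torch.equal(ref, dec))  # True: identical int32 accumulators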

@@ -36,7 +41,7 @@ def extra_repr(self) -> str:


def qlinear_from_linear(
-linear: torch.nn.Module, minimize_error=True
+linear: torch.nn.Module, minimize_error=True, use_int_mm=True,
) -> torch.nn.Module:
import protoquant

@@ -45,4 +50,4 @@ def qlinear_from_linear(
qweight, wparams = qw.wrapped_qntzd, qw.wrapped_params
assert linear.weight.dtype == torch.float16
assert linear.bias.dtype == torch.float16
-return QLinear(qweight, wparams, linear.bias, minimize_error=minimize_error)
+return QLinear(qweight, wparams, linear.bias, minimize_error=minimize_error, use_int_mm=use_int_mm)
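A hypothetical usage sketch of the extended signature (assumes this repo's protoquant package and a CUDA device; the layer sizes are illustrative):

import torch
import protoquant

# qlinear_from_linear asserts fp16 weight and bias, so start from a half() layer.
linear = torch.nn.Linear(4096, 4096).half().cuda()

# use_int_mm=False routes forward() through the multiply-and-reduce
# decomposition instead of torch.ops.aten._int_mm.
qlinear = protoquant.qlinear_from_linear(linear, minimize_error=True, use_int_mm=False)

x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")  # batch size 1
y = qlinear(x)  # shape (1, 1, 4096)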
4 changes: 2 additions & 2 deletions protoquant/quantization.py
@@ -35,7 +35,7 @@ def qntz(
pad_0 = pad(m) - m if do_pad else 0
pad_1 = pad(n) - n if do_pad else 0

-if rowwise and not transpose and pad_0 == 0 and pad_1 == 0:
+if rowwise and not transpose: # and pad_0 == 0 and pad_1 == 0:
mins, maxs, scales, zeros, sums, out = quant(input, 1, minimize_error)
params = QParams(scales, zeros, sums, rowwise, transpose, dtype, pad_0, pad_1)
return (out, params)
@@ -90,7 +90,7 @@ def dqntz(
)
return out

-if mat1_params.pad_0 == 0 and mat2_params.pad_1 == 0:
+if True: # mat1_params.pad_0 == 0 and mat2_params.pad_1 == 0:
return dequant(
input,
other,
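Both changes above relax the padding guards so the rowwise quant/dequant fast paths are taken regardless of pad_0/pad_1, presumably because the decomposition path does not need padded operands the way aten._int_mm does. As an illustration only, and not protoquant's actual quant/dequant kernels, a minimal symmetric per-row int8 quantize/dequantize round trip that works for a single unpadded row:

import torch

# Illustrative symmetric per-row int8 quantization (not protoquant's quant(),
# which also tracks zeros/sums and supports a minimize_error mode).
def rowwise_quant(x):
    scales = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales

a = torch.randn(1, 64)   # a single activation row, no padding
w = torch.randn(32, 64)  # weight as (out_features, in_features)
qa, sa = rowwise_quant(a)
qw, sw = rowwise_quant(w)

acc = qa.to(torch.int32) @ qw.t().to(torch.int32)  # int32 accumulator
approx = acc.to(torch.float32) * sa * sw.t()       # dequantize with both row scales
print((approx - a @ w.t()).abs().max())            # small quantization error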