Skip to content

Add simple op implementations for CPU #1602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
import ctypes as ct
from typing import Optional

Expand Down Expand Up @@ -119,6 +120,10 @@ def _(
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()

Expand All @@ -140,3 +145,39 @@ def _(
packed = packed.squeeze().view(quant_storage).unsqueeze(1)

return packed, absmax.float()


@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)

# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)

# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)

# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]

# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])

return blocks.to(dtype)
2 changes: 1 addition & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def forward(self, x: torch.Tensor):

bias = None if self.bias is None else self.bias.to(self.compute_dtype)

return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


class LinearFP4(Linear4bit):
Expand Down
6 changes: 5 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.skip("CPU implementation is not available")
if quant_type != "nf4":
pytest.skip("CPU implementation is only available for nf4")

if storage_dtype != torch.uint8:
pytest.skip("CPU implementation only supports uint8 storage")

shape = (128, 128)

Expand Down