diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index b7513c4d3..0da9eac94 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import ctypes as ct from typing import Optional @@ -119,6 +120,10 @@ def _( ) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) n = A.numel() @@ -140,3 +145,73 @@ def _( packed = packed.squeeze().view(quant_storage).unsqueeze(1) return packed, absmax.float() + + +@register_kernel("bitsandbytes::dequantize_4bit", "cpu") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + torch._check( + A.dtype == torch.uint8, + lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}", + ) + + A = A.view(-1, 1) + + # Grab upper and lower nibbles. Using int64 for indexing in the LUT. + upper = (A >> 4).to(torch.int64) + lower = (A & 0x0F).to(torch.int64) + + # Expand to blocks + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + + # Dequantize + blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] + + # Reshape to original shape + blocks = blocks.reshape(-1, *shape[1:]) + + return blocks.to(dtype) + + +@register_kernel("bitsandbytes::gemv_4bit", "cpu") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> torch.Tensor: + # TODO: We need to determine whether `code` is NF4, FP4, or other. + # Right now we assume NF4, as this is the only one supported on CPU. + + B_dq = torch.ops.bitsandbytes.dequantize_4bit.default( + B, + absmax, + blocksize, + "nf4", + shape=shapeB, + dtype=A.dtype, + ) + + # User called gemv with B.t(), so we need to transpose it back. 
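+    # NOTE: the transpose is currently left disabled; F.linear below consumes
+    # B_dq in the shape given by shapeB, so no extra transpose is applied here.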
+ # if B.shape[0] == 1: + # B_dq = B_dq.t() + + return torch.nn.functional.linear( + A, + B_dq, + bias=None, + ) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 5ffcdb767..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): _int8_linear_matmul_impl(A, B, out) -@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda") -def _( - A: torch.Tensor, - CA: torch.Tensor, - CB: torch.Tensor, - SCA: torch.Tensor, - SCB: torch.Tensor, - outlier_cols: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - subB = None - - if outlier_cols is not None and outlier_cols.numel(): - # Extract the inputs with outliers in original precision - subA = A[:, outlier_cols].contiguous() - - # Dequantize the corresponding weight columns - subB = ( - torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB) - .to(A.dtype) - .t() - ) - - # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t() - - else: - # Needed for torch.compile when there are no outliers. - subA = torch.empty(0, device=A.device, dtype=A.dtype) - - # Int8 Matmul + Dequant + Bias - output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype) - - if subB is not None: - # Add the outlier columns back to the output - output = output.addmm(subA, subB) - - return output, subA - - def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): A, B = B, A diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 6e581038d..653f87659 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,3 +1,4 @@ +from math import prod from typing import Optional import torch @@ -5,6 +6,45 @@ from ..._ops import register_kernel +@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default") +def _( + A: torch.Tensor, + CA: torch.Tensor, + CB: torch.Tensor, + SCA: torch.Tensor, + SCB: torch.Tensor, + outlier_cols: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + subB = None + + if outlier_cols is not None and outlier_cols.numel(): + # Extract the inputs with outliers in original precision + subA = A[:, outlier_cols].contiguous() + + # Dequantize the corresponding weight columns + subB = ( + torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB) + .to(A.dtype) + .t() + ) + + # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t() + + else: + # Needed for torch.compile when there are no outliers. 
+ subA = torch.empty(0, device=A.device, dtype=A.dtype) + + # Int8 Matmul + Dequant + Bias + output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype) + + if subB is not None: + # Add the outlier columns back to the output + output = output.addmm(subA, subB) + + return output, subA + + @register_kernel("bitsandbytes::int8_scaled_mm", "default") def _( A: torch.Tensor, @@ -41,3 +81,33 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor if out is not None: result = out.copy_(result) return result + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "default") +def _(A: torch.Tensor, threshold=0.0): + rows = prod(A.shape[:-1]) + outlier_cols = None + + if threshold > 0.0: + outliers = A.abs() >= threshold + + if outliers.any(): + # Determine which columns contain outliers, and zero out the + # outliers ahead of quantization. + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + A[outliers] = 0 + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + # Get absmax for each row. + row_stats = torch.max(A.abs(), dim=1).values.float() + + # Quantize row-wise to int8. + out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8) + + # Zero out values from outlier columns across all rows. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c9341230f..d17ff2e88 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -779,7 +779,7 @@ def quantize_blockwise( state2=state2, ) else: - quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype) + quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) # TODO(matthewdouglas): Deprecate out kwarg out = out.copy_(_out) if out is not None else _out diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ea5451502..74277f65e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -486,7 +486,7 @@ def forward(self, x: torch.Tensor): bias = None if self.bias is None else self.bias.to(self.compute_dtype) - return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) class LinearFP4(Linear4bit): @@ -585,19 +585,28 @@ def __new__( obj.has_fp16_weights = has_fp16_weights return obj - def cuda(self, device): + def _quantize(self, device): if self.has_fp16_weights: - return super().cuda(device) - else: - # We quantize the weight and store in 8bit row-major - B = self.data.contiguous().half().cuda(device) - CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) - self.data = CB - self.CB = CB - self.SCB = SCB + return super().to(device) + + # We quantize the weight and store in 8bit row-major + B = self.data.contiguous().to(device=device, dtype=torch.float16) + CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) + self.data = CB + self.CB = CB + self.SCB = SCB return self + def cpu(self): + return self.to(device="cpu") + + def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + + def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: 
bool = False): + return self.to(device="xpu" if device is None else device, non_blocking=non_blocking) + def __deepcopy__(self, memo): # adjust this if new arguments are added to the constructor new_instance = type(self).__new__( @@ -627,8 +636,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and self.data.device.type == "cpu": - return self.cuda(device) + if device is not None and device.type != "meta" and self.data.device.type == "cpu": + return self._quantize(device) else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking), diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 7c43cab80..b6ba284c9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -32,9 +32,15 @@ def test_matmullt( device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias ): - if device != "cuda" and funcs[1] == bnb.research.switchback_bnb: - # TODO: Deprecate/remove? - pytest.skip("switchback_bnb only works on CUDA.") + if device != "cuda": + if funcs[1] == bnb.research.switchback_bnb: + # TODO: Deprecate/remove? + pytest.skip("switchback_bnb only works on CUDA.") + + if req_grad[1]: + # This will be deprecated for CUDA in the future. We don't expect + # this to work on any other device. + pytest.skip("Deprecated feature with CUDA support only.") dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) @@ -171,7 +177,7 @@ def test_matmul_4bit( quant_type, ): if device == "cpu" and quant_type == "fp4": - pytest.skip("Only nf4 is supported on CPU") + pytest.xfail("Only nf4 is supported on CPU") dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) diff --git a/tests/test_functional.py b/tests/test_functional.py index 5b9038288..ee2b52429 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -186,7 +186,7 @@ def test_few_bit_quant(self, device, bits, method): code = F.create_dynamic_map(True, bits - 0, bits).to(device) elif method == "quantile": if device != "cuda": - pytest.xfail("Quantile map only works on CUDA") + pytest.skip("Quantile map only works on CUDA") values = torch.randn(2048, 2048, device="cuda") code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero @@ -593,7 +593,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): A = A.view(-1, A.shape[-1]) - CA, _, statsA, _, _ = F.int8_double_quant(A) + CA, statsA, _ = F.int8_vectorwise_quant(A) CB, statsB, _ = F.int8_vectorwise_quant(B) output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) @@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) def test_4bit_quant(self, device, dtype, quant_type, blocksize): + if device == "cpu" and quant_type != "nf4": + pytest.xfail("fp4 quantization is not supported on CPU") + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) @@ -1134,6 +1137,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): 
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): + if device == "cpu" and quant_type != "nf4": + pytest.xfail("fp4 quantization is not supported on CPU") + errs1 = [] errs2 = [] for i in range(10): @@ -1206,6 +1212,12 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): + if device == "cpu": + if storage_type != "nf4": + pytest.xfail("fp4 quantization is not supported on CPU") + if quant_storage != torch.uint8: + pytest.xfail("Only uint8 storage is supported on CPU") + errs1 = [] errs2 = [] errs3 = [] @@ -1216,7 +1228,11 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double max_errs2 = [] max_errs3 = [] - for i in range(100): + # Large number of iterations is excessive and slow on CPU. + # Keep for CUDA for now. + iters = 100 if device == "cuda" else 10 + + for i in range(iters): if kind == "fc1": A = torch.randn(1, dim, dtype=dtype, device=device) B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim) @@ -1337,6 +1353,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): + if device == "cpu" and storage_type != "nf4": + pytest.xfail("fp4 quantization is not supported on CPU") + dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 669319298..67b61cb05 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -25,7 +25,10 @@ @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): if device == "cpu": - pytest.xfail("Dequantization is not yet implemented for CPU") + if quant_type == "fp4": + pytest.xfail("FP4 is not supported for CPU") + if quant_storage != "uint8": + pytest.xfail("Only uint8 storage is supported for CPU") original_dtype = torch.float16 compute_dtype = None @@ -144,8 +147,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua linear_q3 = torch_load_from_buffer(bytes_4bit) # Test moving to CPU and back to GPU - linear_q2.to("cpu") - linear_q2.to(device) + if device != "cpu": + linear_q2.to("cpu") + linear_q2.to(device) d = linear_qs(x) assert c.dtype == d.dtype assert c.device == d.device diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 53a566cb9..8c08cfa2c 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -22,9 +22,6 @@ # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py @pytest.mark.parametrize("device", get_available_devices()) def test_linear_no_igemmlt(device): - if device == "cpu": - pytest.xfail("Not yet implemented on CPU") - linear = torch.nn.Linear(1024, 3072) x = torch.randn(3, 1024, dtype=torch.half) linear_custom = Linear8bitLt( @@ -81,8 +78,8 @@ def test_linear_serialization( save_before_forward, load_before_cuda, ): 
- if device == "cpu": - pytest.xfail("Not yet implemented on CPU") + if device != "cuda" and has_fp16_weights: + pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated") linear = torch.nn.Linear(32, 96) # TODO: Fallback for bad shapes @@ -111,7 +108,7 @@ def test_linear_serialization( if save_before_forward: bytes_8bit = torch_save_to_buffer(linear_custom) - x_first = x.clone().cuda().requires_grad_(True) + x_first = x.clone().to(device).requires_grad_(True) fx_first = linear_custom(x_first).float() grad_proj = torch.randn_like(fx_first) (fx_first * grad_proj).mean().backward() @@ -157,11 +154,11 @@ def test_linear_serialization( if not load_before_cuda: new_linear_custom2 = torch_load_from_buffer(bytes_8bit) - x_second = x.clone().cuda().requires_grad_(True) + x_second = x.clone().to(device).requires_grad_(True) fx_second = new_linear_custom(x_second).float() (fx_second * grad_proj).mean().backward() - x_third = x.clone().cuda().requires_grad_(True) + x_third = x.clone().to(device).requires_grad_(True) fx_third = new_linear_custom2(x_third).float() (fx_third * grad_proj).mean().backward() diff --git a/tests/test_modules.py b/tests/test_modules.py index 8ef0890ec..dc1d60e6c 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -55,9 +55,6 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) def test_linear8bitlt_inference(device, threshold): - if device == "cpu": - pytest.xfail("Not yet implemented on CPU") - l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() assert l1.weight.device.type == device assert l1.weight.dtype == torch.int8 @@ -120,9 +117,6 @@ def test_linear8bitlt_accumulated_gradient(device): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("threshold", [0.0, 2.0]) def test_linear8bitlt_no_fp16_weights(device, threshold): - if device == "cpu": - pytest.xfail("Not yet supported on CPU") - l1 = ( bnb.nn.Linear8bitLt( 32, @@ -211,7 +205,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): has_fp16_weights=False, ) w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization, - mlp = mlp.cuda().half() # and this line triggers quantization + mlp = mlp.to(device).half() # and this line triggers quantization for i in range(100): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) @@ -253,9 +247,6 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): ids=["Int8Lt", "NF4"], ) def test_linear_kbit_fp32_bias(device, module): - if device == "cpu": - pytest.xfail("Not yet implemented on CPU") - # casts model to fp16 -> int8 automatically l1 = module(32, 64).to(device) assert l1.weight.dtype in [torch.int8, torch.uint8] @@ -295,7 +286,7 @@ def test_linear_kbit_fp32_bias(device, module): @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(device, module): if device == "cpu": - pytest.xfail("Not yet implemented on CPU") + pytest.xfail("Test is not yet supported on CPU") b = 16 dim1 = 36 @@ -401,7 +392,10 @@ def test_fp8linear(): ) def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage): if device == "cpu": - pytest.xfail("Not yet supported on CPU") + if embedding_class is bnb.nn.EmbeddingFP4: + pytest.xfail("FP4 is not supported for CPU") + if quant_storage 
is not None and quant_storage != torch.uint8: + pytest.xfail("CPU only supports uint8 storage for 4bit") num_embeddings = 128 @@ -449,7 +443,10 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, ) def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage): if device == "cpu": - pytest.xfail("Not yet supported on CPU") + if embedding_class is bnb.nn.EmbeddingFP4: + pytest.xfail("FP4 is not supported for CPU") + if quant_storage is not None and quant_storage != torch.uint8: + pytest.xfail("CPU only supports uint8 storage for 4bit") is_8bit = embedding_class is bnb.nn.Embedding8bit @@ -486,7 +483,7 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu @pytest.mark.parametrize("device", get_available_devices()) def test_4bit_linear_warnings(device): if device == "cpu": - pytest.xfail("Not yet implemented on CPU") + pytest.xfail("gemv_4bit op is not yet implemented on CPU") dim1 = 64 @@ -525,9 +522,6 @@ def test_4bit_linear_warnings(device): @pytest.mark.parametrize("device", get_available_devices()) def test_4bit_embedding_warnings(device): - if device == "cpu": - pytest.xfail("Not yet implemented on CPU") - num_embeddings = 128 default_block_size = 64 diff --git a/tests/test_ops.py b/tests/test_ops.py index 9869f51ef..ea448f99b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -37,9 +37,6 @@ def test_int8_linear_matmul_out(self, device): @pytest.mark.parametrize("threshold", [0.0, 6.0]) @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): - if device == "cpu": - pytest.skip("CPU implementation is not available") - A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -147,7 +144,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": - pytest.skip("CPU implementation is only available for nf4") + pytest.xfail("CPU implementation is only available for nf4") if storage_dtype != torch.uint8: pytest.xfail("Known issue with storage_dtype != uint8") @@ -171,7 +168,11 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": - pytest.skip("CPU implementation is not available") + if quant_type != "nf4": + pytest.xfail("CPU implementation is only available for nf4") + + if storage_dtype != torch.uint8: + pytest.xfail("CPU implementation only supports uint8 storage") shape = (128, 128) @@ -204,7 +205,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": - pytest.skip("CPU implementation is not available") + pytest.xfail("CPU implementation is not available") out_features = 1024 in_features = 256
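
A minimal sketch, assuming this branch is installed, of the NF4 round trip the new CPU kernels enable. It uses only the public bitsandbytes.functional calls already exercised in tests/test_functional.py (test_4bit_quant / test_dequantize_4bit); parameter values are illustrative.

# Illustrative only: blockwise NF4 quantize/dequantize on CPU via the public API.
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float32, device="cpu")
packed, quant_state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
A_dq = F.dequantize_4bit(packed, quant_state, blocksize=64, quant_type="nf4")

# Blockwise NF4 is lossy; expect a small mean absolute error rather than equality.
print((A - A_dq).abs().mean())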