From 29c9a8ca60fd215a7b84eb418874715fb509f38a Mon Sep 17 00:00:00 2001 From: Dangyi Liu Date: Wed, 30 Apr 2025 12:42:49 -0700 Subject: [PATCH] Use native int2 types. PiperOrigin-RevId: 753273544 --- README.md | 4 ++-- qwix/core/numerics.py | 4 ++-- tests/core/numerics_test.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b705740..b934f1e 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ targets (LiteRT). converter could produce full integer models. * LoRA/QLoRA: this mode enables LoRA and QLoRA on a model. * Supported numerics: - * Native: `int4`, `int8`, `fp8`. - * Emulated: `int1` to `int7`, `nf4`. + * Native: `int2`, `int4`, `int8`, `fp8`. + * Emulated: `int3` to `int7`, `nf4`. * Supported array calibration methods: * `absmax`: symmetric quantization using maximum absolute value. * `minmax`: asymmetric quantization using minimum and maximum values. diff --git a/qwix/core/numerics.py b/qwix/core/numerics.py index 2d89504..de8cd4c 100644 --- a/qwix/core/numerics.py +++ b/qwix/core/numerics.py @@ -43,7 +43,7 @@ def get_symmetric_bound(qtype: jax.typing.DTypeLike) -> float: match qtype: case 'nf4': return 1.0 - case 'int2' | 'int3' | 'int5' | 'int6' | 'int7': + case 'int3' | 'int5' | 'int6' | 'int7': # The bound is extended to qmax + 0.5 so that we have a better utilization # of the whole range. This is more important for fewer bits of int. return 2 ** (int(qtype[3:]) - 1) - 0.5 @@ -63,7 +63,7 @@ def convert_to(x: jax.Array, qtype: jax.typing.DTypeLike) -> jax.Array: match qtype: case 'nf4': return fp_to_nf4(x) - case 'int2' | 'int3' | 'int5' | 'int6' | 'int7': + case 'int3' | 'int5' | 'int6' | 'int7': bits = int(qtype[3:]) qmin = -(2 ** (bits - 1)) qmax = 2 ** (bits - 1) - 1 diff --git a/tests/core/numerics_test.py b/tests/core/numerics_test.py index 634c062..7f28da4 100644 --- a/tests/core/numerics_test.py +++ b/tests/core/numerics_test.py @@ -39,6 +39,10 @@ def test_convert_to(self): numerics.convert_to(jnp.array([1.2, 3.5, 8, -1300]), jnp.int4), jnp.array([1, 4, 7, -8], jnp.int4), ) + self._assert_equal( + numerics.convert_to(jnp.array([1.2, 3.5, 8, -1300]), jnp.int2), + jnp.array([1, 1, 1, -2], jnp.int2), + ) def test_inf(self): self._assert_equal( @@ -52,7 +56,15 @@ def test_arbitrary_integer_dtype(self): numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), "int6"), jnp.array([1, 4, 31, -32], jnp.int8), ) - # jnp.int4 and "int4" should be the same. + # jnp.int* and "int*" should be the same. + self._assert_equal( + numerics.get_symmetric_bound("int2"), + numerics.get_symmetric_bound(jnp.int2), + ) + self._assert_equal( + numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), "int2"), + numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), jnp.int2), + ) self._assert_equal( numerics.get_symmetric_bound("int4"), numerics.get_symmetric_bound(jnp.int4),