Skip to content

Commit

Permalink
Improved Bit constructor for uint8 NumPy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Feb 11, 2025
1 parent 1676e3e commit ac9fd53
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
11 changes: 4 additions & 7 deletions pgvector/bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ def __init__(self, value):
if isinstance(value, str):
self._value = self.from_text(value)._value
else:
# TODO change in 0.4.0
# TODO raise if dtype not bool or uint8
# if isinstance(value, np.ndarray) and value.dtype == np.uint8:
# value = np.unpackbits(value)
# else:
# value = np.asarray(value, dtype=bool)

value = np.asarray(value, dtype=bool)
if isinstance(value, np.ndarray) and value.dtype == np.uint8:
value = np.unpackbits(value)
else:
value = np.asarray(value, dtype=bool)

if value.ndim != 1:
raise ValueError('expected ndim to be 1')
Expand Down
4 changes: 1 addition & 3 deletions tests/test_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def test_str(self):

def test_ndarray_uint8(self):
arr = np.array([254, 7, 0], dtype=np.uint8)
# TODO change in 0.4.0
# assert Bit(arr).to_text() == '111111100000011100000000'
assert Bit(arr).to_text() == '110'
assert Bit(arr).to_text() == '111111100000011100000000'

def test_ndarray_same_object(self):
arr = np.array([True, False, True])
Expand Down

0 comments on commit ac9fd53

Please sign in to comment.