Skip to content

Commit

Permalink
Normalize to positive indexing; .shape is now the unpacked shape; fix…
Browse files Browse the repository at this point in the history
… indexing bugs
  • Loading branch information
amirebrahimi committed Jan 10, 2025
1 parent 57bf589 commit 8d2ad26
Showing 1 changed file with 81 additions and 45 deletions.
126 changes: 81 additions & 45 deletions src/galois/_fields/_gf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import operator
from functools import reduce
from math import ceil, floor
from typing import Sequence, Final

import numpy as np
Expand Down Expand Up @@ -166,18 +167,18 @@ class multiply_ufunc_bitpacked(multiply_ufunc):

def __call__(self, ufunc, method, inputs, kwargs, meta):
is_outer_product = method == "outer"
if is_outer_product and np.all(len(i.unpacked_shape) == 1 for i in inputs):
result_shape = reduce(operator.add, (i.unpacked_shape for i in inputs))
if is_outer_product and np.all(len(i.shape) == 1 for i in inputs):
result_shape = reduce(operator.add, (i.shape for i in inputs))
else:
result_shape = np.broadcast_shapes(*(i.unpacked_shape for i in inputs))
result_shape = np.broadcast_shapes(*(i.shape for i in inputs))

if is_outer_product and len(inputs) == 2:
a = np.unpackbits(inputs[0])
output = a[:, np.newaxis].view(np.ndarray) * inputs[1].view(np.ndarray)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)

assert len(output.shape) == len(result_shape)
assert len(output.view(np.ndarray).shape) == len(result_shape)
# output = output.view(np.ndarray)
# if output.shape != result_shape:
# for axis, shape in enumerate(zip(output.shape, result_shape)):
Expand Down Expand Up @@ -225,22 +226,25 @@ def __call__(self, ufunc, method, inputs, kwargs, meta):
b._axis_count = row_axis_count

# Make sure the inner dimensions match (e.g. (M, N) x (N, P) -> (M, P))
assert a.shape[-1] == b.shape[0]
if len(b.shape) == 1:
final_shape = (a.shape[0],)
a_packed_shape = a.view(np.ndarray).shape
b_packed_shape = b.view(np.ndarray).shape
assert a_packed_shape[-1] == b_packed_shape[0]
if len(b_packed_shape) == 1:
final_shape = (a_packed_shape[0],)
else:
final_shape = (a.shape[0], b.shape[-1])
final_shape = (a_packed_shape[0], b_packed_shape[-1])

if len(b.shape) == 1:
# matrix-vector multiplication
output = np.bitwise_xor.reduce(np.unpackbits((a & b).view(np.ndarray), axis=-1), axis=-1)
else:
# matrix-matrix multiplication
output = GF2.Zeros(final_shape)
for i in range(b.shape[-1]):
for i in range(b_packed_shape[-1]):
# TODO: Include alternate path for numpy < v2
# output[:, i] = np.bitwise_xor.reduce(np.unpackbits((a & b[:, i]).view(np.ndarray), axis=-1), axis=-1)
output[:, i] = np.bitwise_xor.reduce(np.bitwise_count((a & b[:, i]).view(np.ndarray)), axis=-1) % 2
output[:, i] = np.bitwise_xor.reduce(
np.bitwise_count((a & b.view(np.ndarray)[:, i]).view(np.ndarray)), axis=-1) % 2
output = field._view(np.packbits(output.view(np.ndarray), axis=-1))
output._axis_count = final_shape[-1]

Expand Down Expand Up @@ -275,7 +279,6 @@ def __call__(self, A: Array) -> Array:

# Concatenate A and I to get the matrix AI = [A | I]
AI = np.concatenate((A, I), axis=-1)
AI[0]

# Perform Gaussian elimination to get the reduced row echelon form AI_rre = [I | A^-1]
AI_rre, _ = row_reduce_jit(self.field)(AI, ncols=n)
Expand Down Expand Up @@ -492,40 +495,67 @@ def Identity(cls, size: int, dtype: DTypeLike | None = None) -> Self:
return np.packbits(array)

@staticmethod
def _normalize_indexing_to_tuple(index, ndim):
def _normalize_indexing_to_tuple(index, shape, axis = 0):
"""
Normalize indexing into a tuple of slices, integers, and/or new axes.
Normalize indexing into a tuple of positive-only slices, integers, and/or new axes.
NOTE: Ellipsis indexing is converted to slice indexing.
Args:
index: The indexing expression (int, slice, list, etc.).
ndim: The number of dimensions of the array being indexed into.
shape: Tuple of integers representing the shape of the object being indexed.
Returns:
A tuple of integers, slices, and/or new axes.
A tuple of positive integers, slices, and/or new axes.
"""
if not isinstance(shape, tuple):
raise TypeError("Shape must be a tuple of integers.")

ndim = len(shape)

if isinstance(index, int):
if index < 0:
index += shape[axis]
return (index,)
elif isinstance(index, slice):
return (index,)
start, stop, step = index.start, index.stop, index.step
step = step if step is not None else 1

if step > 0:
start = start if start is not None else 0
stop = stop if stop is not None else shape[axis]

# Adjust negative start/stop values
if start < 0:
start += shape[axis]
if stop < 0:
stop += shape[axis]
else:
start = start if start is not None else shape[axis] - 1
stop = stop if stop is not None else -shape[axis] - 1

return (slice(start, stop, step),)
elif isinstance(index, list):
# Lists cannot be directly converted to slices, so leave as-is.
for axis, i in enumerate(index):
if i < 0:
index[axis] += shape[axis]
return (index,)
elif index is np.newaxis:
return (index,)
elif isinstance(index, tuple):
normalized = []

if any(i is Ellipsis for i in index):
num_explicit_dims = sum(1 for i in index if i is not Ellipsis)
for i in index:
if i is Ellipsis:
normalized.extend([slice(None)] * (ndim - num_explicit_dims))
else:
normalized.append(i)
else:
for i in index:
normalized.extend(GF2BP._normalize_indexing_to_tuple(i, ndim))
num_explicit_dims = sum(1 for i in index if i is not Ellipsis)
for i in index:
if i is Ellipsis:
span = ndim - num_explicit_dims
expanded_dims = [slice(None)] * span
for e_axis, e in enumerate(expanded_dims):
expanded_dims[e_axis] = GF2BP._normalize_indexing_to_tuple(e, shape, axis)[0]
axis += 1
normalized.extend(expanded_dims)
else:
normalized.extend(GF2BP._normalize_indexing_to_tuple(i, shape, axis))
axis += 1

return tuple(normalized)
elif isinstance(index, (Sequence, np.ndarray)):
Expand All @@ -535,11 +565,12 @@ def _normalize_indexing_to_tuple(index, ndim):
raise TypeError(f"Unsupported indexing type: {type(index)}")

def get_index_parameters(self, index):
normalized_index = self._normalize_indexing_to_tuple(index, self.ndim)
normalized_index = self._normalize_indexing_to_tuple(index, self.shape)

assert isinstance(normalized_index, tuple)

bit_width: Final[int] = self.BIT_WIDTH
packed_shape = self.view(np.ndarray).shape
packed_index = tuple()
unpacked_index = tuple()
shape = tuple()
Expand All @@ -556,27 +587,32 @@ def get_index_parameters(self, index):
shape += (1,)
else:
packed_index += (i // bit_width,)
unpacked_index += (i,)
unpacked_index += (i % bit_width,)
if axes_in_index > 1:
shape += (1,)
else:
shape += (self.shape[axis],)
shape += (packed_shape[axis],)
elif isinstance(i, slice):
if is_unpacked_axis:
packed_index += (i,)
# the packed index will already filter, so we just select everything after
unpacked_index += (slice(None),)
else:
packed_index += (slice(i.start // bit_width if i.start is not None else i.start,
max(i.stop // bit_width, 1) if i.stop is not None else i.stop,
max(i.step // bit_width, 1) if i.step is not None else i.step),)
unpacked_index += (i,)
if i.step > 0:
packed_index += (slice(i.start // bit_width,
max(int(ceil(i.stop / bit_width)), 1),
max(i.step // bit_width, 1)),)
unpacked_index += (slice(i.start % bit_width, i.start % bit_width + i.stop - i.start, i.step),)
else:
packed_index += (slice(i.start // bit_width,
max(int(floor(i.stop / bit_width)), -packed_shape[axis] -1),
min(i.step // bit_width, -1)),)
unpacked_index += (slice(i.start % bit_width, i.start % bit_width + i.stop - i.start, i.step),)


packed_slice = packed_index[-1]
packed_slice = slice(0 if packed_slice.start is None else packed_slice.start,
self.shape[axis] if packed_slice.stop is None else packed_slice.stop,
1 if packed_slice.step is None else packed_slice.step)
slice_size = max(0, (packed_slice.stop - packed_slice.start + packed_slice.step - 1) // packed_slice.step)
abs_step = abs(packed_slice.step)
slice_size = max(0, (packed_slice.stop - packed_slice.start + abs_step - 1) // abs_step)
shape += (slice_size,)
elif isinstance(i, (Sequence, np.ndarray)):
if is_unpacked_axis:
Expand All @@ -585,7 +621,7 @@ def get_index_parameters(self, index):
unpacked_index += (slice(None),)
else:
if isinstance(index, np.ndarray) and index.dtype == np.bool:
mask_packed = [False] * self.shape[axis]
mask_packed = [False] * packed_shape[axis]
for j, value in enumerate(i):
mask_packed[j // bit_width] |= True
packed_index = mask_packed
Expand All @@ -594,7 +630,7 @@ def get_index_parameters(self, index):
else:
# adjust indexing for this packed axis
data = np.array([s // bit_width for s in i], dtype=self.dtype)
# remove duplicate entries, including for nested arrays
# remove duplicate entries, including nested arrays
if data.ndim > 1:
rows = []
for j, row_data in enumerate(data):
Expand All @@ -608,9 +644,9 @@ def get_index_parameters(self, index):
data = data[np.sort(unique_indices)]

packed_index += (data,)
shape += (self.shape[axis],)
shape += (packed_shape[axis],)
if axes_in_index == 1:
unpacked_index += (i,)
unpacked_index += ([s % bit_width for s in i],)
elif i is np.newaxis:
packed_index += (i,)
unpacked_index += (i,)
Expand All @@ -624,10 +660,10 @@ def get_index_parameters(self, index):

return packed_index, unpacked_index, shape

# TODO: Should this be the default shape returned and a cast (i.e. .view(np.ndarray) is needed to get the packed shape?
@property
def unpacked_shape(self):
return self.shape[:-1] + (self._axis_count,)
def shape(self):
# A cast to np.ndarray is needed to get the packed shape
return self.view(np.ndarray).shape[:-1] + (self._axis_count,)

def get_unpacked_value(self, index):
# Numpy indexing is handled primarily in https://github.com/numpy/numpy/blob/maintenance/1.26.x/numpy/core/src/multiarray/mapping.c#L1435
Expand Down

0 comments on commit 8d2ad26

Please sign in to comment.