Skip to content

Commit

Permalink
Always return FieldArray from FEC encode/decode methods
Browse files Browse the repository at this point in the history
Fixes #394
  • Loading branch information
mhostetter committed Jul 29, 2022
1 parent bf1275a commit d0495cf
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 90 deletions.
20 changes: 10 additions & 10 deletions galois/_codes/_bch.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __str__(self) -> str:

return string

def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) -> Union[np.ndarray, GF2]:
def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) -> GF2:
r"""
Encodes the message :math:`\mathbf{m}` into the BCH codeword :math:`\mathbf{c}`.
Expand All @@ -314,9 +314,8 @@ def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) ->
Returns
-------
:
The codeword as either a :math:`n`-length vector or :math:`(N, n)` matrix. The return type matches the
message type. If `parity_only=True`, the parity bits are returned as either a :math:`n - k`-length vector or
:math:`(N, n-k)` matrix.
The codeword as either a :math:`n`-length vector or :math:`(N, n)` matrix. If `parity_only=True`, the parity
bits are returned as either a :math:`n - k`-length vector or :math:`(N, n-k)` matrix.
Notes
-----
Expand Down Expand Up @@ -421,13 +420,14 @@ def encode(self, message: Union[np.ndarray, GF2], parity_only: bool = False) ->

if parity_only:
parity = message.view(GF2) @ self.G[-ks:, self.k:]
return parity.view(type(message))
return parity
elif self.is_systematic:
parity = message.view(GF2) @ self.G[-ks:, self.k:]
return np.hstack((message, parity)).view(type(message))
codeword = np.hstack((message, parity))
return codeword
else:
codeword = message.view(GF2) @ self.G
return codeword.view(type(message))
return codeword

def detect(self, codeword: Union[np.ndarray, GF2]) -> Union[np.bool_, np.ndarray]:
r"""
Expand Down Expand Up @@ -584,10 +584,10 @@ def detect(self, codeword: Union[np.ndarray, GF2]) -> Union[np.bool_, np.ndarray
return detected

@overload
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[False] = False) -> Union[np.ndarray, GF2]:
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[False] = False) -> GF2:
...
@overload
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[True]) -> Tuple[Union[np.ndarray, GF2], Union[np.integer, np.ndarray]]:
def decode(self, codeword: Union[np.ndarray, GF2], errors: Literal[True]) -> Tuple[GF2, Union[np.integer, np.ndarray]]:
...
def decode(self, codeword, errors=False):
r"""
Expand Down Expand Up @@ -786,7 +786,7 @@ def decode(self, codeword, errors=False):
message = dec_codeword[:, 0:ks]
else:
message, _ = divmod_jit(GF2)(dec_codeword[:, 0:ns].view(GF2), self.generator_poly.coeffs)
message = message.view(type(codeword)) # TODO: Remove this
message = message.view(GF2)

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down
20 changes: 10 additions & 10 deletions galois/_codes/_reed_solomon.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __str__(self) -> str:

return string

def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = False) -> Union[np.ndarray, FieldArray]:
def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = False) -> FieldArray:
r"""
Encodes the message :math:`\mathbf{m}` into the Reed-Solomon codeword :math:`\mathbf{c}`.
Expand All @@ -199,9 +199,8 @@ def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = Fal
Returns
-------
:
The codeword as either a :math:`n`-length vector or :math:`(N, n)` matrix. The return type matches the
message type. If `parity_only=True`, the parity symbols are returned as either a :math:`n - k`-length vector or
:math:`(N, n-k)` matrix.
The codeword as either a :math:`n`-length vector or :math:`(N, n)` matrix. If `parity_only=True`, the parity
symbols are returned as either a :math:`n - k`-length vector or :math:`(N, n-k)` matrix.
Notes
-----
Expand Down Expand Up @@ -306,13 +305,14 @@ def encode(self, message: Union[np.ndarray, FieldArray], parity_only: bool = Fal

if parity_only:
parity = message.view(self.field) @ self.G[-ks:, self.k:]
return parity.view(type(message))
return parity
elif self.is_systematic:
parity = message.view(self.field) @ self.G[-ks:, self.k:]
return np.hstack((message, parity)).view(type(message))
codeword = np.hstack((message, parity))
return codeword
else:
codeword = message.view(self.field) @ self.G
return codeword.view(type(message))
return codeword

def detect(self, codeword: Union[np.ndarray, FieldArray]) -> Union[np.bool_, np.ndarray]:
r"""
Expand Down Expand Up @@ -471,10 +471,10 @@ def detect(self, codeword: Union[np.ndarray, FieldArray]) -> Union[np.bool_, np.
return detected

@overload
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[False] = False) -> Union[np.ndarray, FieldArray]:
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[False] = False) -> FieldArray:
...
@overload
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[True]) -> Tuple[Union[np.ndarray, FieldArray], Union[np.integer, np.ndarray]]:
def decode(self, codeword: Union[np.ndarray, FieldArray], errors: Literal[True]) -> Tuple[FieldArray, Union[np.integer, np.ndarray]]:
...
def decode(self, codeword, errors=False):
r"""
Expand Down Expand Up @@ -676,7 +676,7 @@ def decode(self, codeword, errors=False):
message = dec_codeword[:, 0:ks]
else:
message, _ = divmod_jit(self.field)(dec_codeword[:, 0:ns].view(self.field), self.generator_poly.coeffs)
message = message.view(type(codeword)) # TODO: Remove this
message = message.view(self.field)

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down
24 changes: 12 additions & 12 deletions tests/codes/test_bch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def test_all_correctable(self, size):
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

Expand All @@ -98,11 +98,11 @@ def test_some_uncorrectable(self, size):
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

Expand Down Expand Up @@ -130,11 +130,11 @@ def test_all_correctable(self, size):
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

Expand Down Expand Up @@ -162,11 +162,11 @@ def test_some_uncorrectable(self, size):
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

Expand All @@ -192,11 +192,11 @@ def test_all_correctable(self, size):
assert np.array_equal(N_corr, N_errors)

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M, M)
assert np.array_equal(N_corr, N_errors)

Expand All @@ -222,10 +222,10 @@ def test_some_uncorrectable(self, size):
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])

DEC_M = bch.decode(R.view(np.ndarray))
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])

DEC_M, N_corr = bch.decode(R.view(np.ndarray), errors=True)
assert type(DEC_M) is np.ndarray
assert type(DEC_M) is galois.GF2
assert np.array_equal(DEC_M[corr_idxs,:], M[corr_idxs,:])
assert np.array_equal(N_corr[corr_idxs], N_errors[corr_idxs])
38 changes: 19 additions & 19 deletions tests/codes/test_bch_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def test_systematic(size):
assert np.array_equal(c, c_truth[k:])

c = bch.encode(m.view(np.ndarray))
assert type(c) is np.ndarray
assert type(c) is galois.GF2
assert np.array_equal(c, c_truth)

c = bch.encode(m.view(np.ndarray), parity_only=True)
assert type(c) is np.ndarray
assert type(c) is galois.GF2
assert np.array_equal(c, c_truth[k:])


Expand All @@ -95,7 +95,7 @@ def test_non_systematic(size):
c = bch.encode(m, parity_only=True)

c = bch.encode(m.view(np.ndarray))
assert type(c) is np.ndarray
assert type(c) is galois.GF2
assert np.array_equal(c, c_truth)

with pytest.raises(ValueError):
Expand Down Expand Up @@ -145,11 +145,11 @@ def test_default(self):
assert np.array_equal(C, C_truth[:, self.k:])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, self.k:])

def test_diff_primitive_poly(self):
Expand Down Expand Up @@ -181,11 +181,11 @@ def test_diff_primitive_poly(self):
assert np.array_equal(C, C_truth[:, self.k:])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, self.k:])


Expand Down Expand Up @@ -233,11 +233,11 @@ def test_default(self):
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

def test_diff_primitive_poly(self):
Expand Down Expand Up @@ -269,11 +269,11 @@ def test_diff_primitive_poly(self):
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])


Expand Down Expand Up @@ -320,11 +320,11 @@ def test_default(self):
assert np.array_equal(C, C_truth[:, self.k:])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, self.k:])

def test_diff_primitive_poly(self):
Expand Down Expand Up @@ -356,11 +356,11 @@ def test_diff_primitive_poly(self):
assert np.array_equal(C, C_truth[:, self.k:])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, self.k:])


Expand Down Expand Up @@ -408,11 +408,11 @@ def test_default(self):
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

def test_diff_primitive_poly(self):
Expand Down Expand Up @@ -444,9 +444,9 @@ def test_diff_primitive_poly(self):
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])

C = bch.encode(self.M.view(np.ndarray))
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth)

C = bch.encode(self.M.view(np.ndarray), parity_only=True)
assert type(C) is np.ndarray
assert type(C) is galois.GF2
assert np.array_equal(C, C_truth[:, -(self.n - self.k):])
Loading

0 comments on commit d0495cf

Please sign in to comment.