Skip to content

Commit

Permalink
fix: precision issue in centroid computation
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Dec 7, 2023
1 parent b6b30c2 commit 303f9f3
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions cc3d.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def _statistics(
cdef cnp.ndarray[uint32_t] counts = np.zeros(N + 1, dtype=np.uint32)
cdef cnp.ndarray[uint16_t] bounding_boxes = np.zeros(6 * (N + 1), dtype=np.uint16)
cdef cnp.ndarray[float] centroids = np.zeros(3 * (N + 1), dtype=np.float32)
cdef cnp.ndarray[double] centroids = np.zeros(3 * (N + 1), dtype=np.float64)
cdef uint16_t x = 0
cdef uint16_t y = 0
Expand All @@ -683,9 +683,9 @@ def _statistics(
bounding_boxes[6 * label + 3] = <uint16_t>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <uint16_t>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <uint16_t>max(bounding_boxes[6 * label + 5], z)
centroids[3 * label + 0] += <float>x
centroids[3 * label + 1] += <float>y
centroids[3 * label + 2] += <float>z
centroids[3 * label + 0] += <double>x
centroids[3 * label + 1] += <double>y
centroids[3 * label + 2] += <double>z
else:
for x in range(sx):
for y in range(sy):
Expand All @@ -698,14 +698,19 @@ def _statistics(
bounding_boxes[6 * label + 3] = <uint16_t>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <uint16_t>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <uint16_t>max(bounding_boxes[6 * label + 5], z)
centroids[3 * label + 0] += <float>x
centroids[3 * label + 1] += <float>y
centroids[3 * label + 2] += <float>z
centroids[3 * label + 0] += <double>x
centroids[3 * label + 1] += <double>y
centroids[3 * label + 2] += <double>z
for label in range(N+1):
centroids[3 * label + 0] /= <float>counts[label]
centroids[3 * label + 1] /= <float>counts[label]
centroids[3 * label + 2] /= <float>counts[label]
if <double>counts[label] == 0:
centroids[3 * label + 0] = float('NaN')
centroids[3 * label + 1] = float('NaN')
centroids[3 * label + 2] = float('NaN')
else:
centroids[3 * label + 0] /= <double>counts[label]
centroids[3 * label + 1] /= <double>counts[label]
centroids[3 * label + 2] /= <double>counts[label]
bbxes = bounding_boxes.reshape((N+1,6))
Expand Down

0 comments on commit 303f9f3

Please sign in to comment.