Skip to content

Commit

Permalink
Merge pull request #56 from dkazanc/tests
Browse files Browse the repository at this point in the history
unittests added
  • Loading branch information
dkazanc authored Jun 25, 2020
2 parents f098813 + bf15510 commit 23805fc
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 76 deletions.
19 changes: 9 additions & 10 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
language: python

matrix:
include:
- python: 2.7
env: NUMPY=1.16

- python: 2.7
env: NUMPY=1.15

- python: 3.5
env: NUMPY=1.15

Expand All @@ -16,7 +9,13 @@ matrix:

- python: 3.6
env: NUMPY=1.16


- python: 3.7
env: NUMPY=1.15

- python: 3.7
env: NUMPY=1.16

os:
- linux

Expand All @@ -41,10 +40,10 @@ install:
- export VERSION=`date +%Y.%m`
- conda build conda-recipe --numpy=$NUMPY --python=$TRAVIS_PYTHON_VERSION
- conda install --channel /home/travis/miniconda/envs/test-environment/conda-bld/ tomobar --offline --override-channels

after_success:
- chmod +x src/Python/conda-recipe/conda_upload.sh
- test $TRAVIS_BRANCH = "master" && bash conda-recipe/conda_upload.sh

script:
- python tests/test.py
- python test/test_tomobarCPU_DIR.py
22 changes: 11 additions & 11 deletions Demos/Python/DemoFISTA_artifacts2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
plt.title('Misaligned noisy FBP Reconstruction')
plt.show()

plt.figure()
plt.figure()
plt.imshow(abs(FBPrec_ideal-FBPrec_error), vmin=0, vmax=2, cmap="gray")
plt.colorbar(ticks=[0, 0.5, 2], orientation='vertical')
plt.title('FBP reconstruction differences')
Expand All @@ -133,7 +133,7 @@

FBPrec_misalign = Rectools.FBP(sino_misalign) # reconstruction with misalignment

plt.figure()
plt.figure()
plt.imshow(FBPrec_misalign, vmin=0, vmax=1, cmap="gray")
plt.title('FBP reconstruction of misaligned data using known exact shifts')
#%%
Expand All @@ -153,7 +153,7 @@
# data dictionary
_data_ = {'projection_norm_data' : sino_artifacts,
'projection_raw_data' : sino_artifacts_raw/np.max(sino_artifacts_raw),
'OS_number' : 10}
'OS_number' : 10}
lc = RectoolsIR.powermethod(_data_) # calculate Lipschitz constant (run once to initialise)

_algorithm_ = {'iterations' : 30,
Expand All @@ -169,7 +169,7 @@
print("Run FISTA reconstrucion algorithm with regularisation...")
RecFISTA_LS_reg = RectoolsIR.FISTA(_data_, _algorithm_, _regularisation_)

# adding Huber data fidelity threshold
# adding Huber data fidelity threshold
_data_.update({'huber_threshold' : 7.0})
print(" Run FISTA reconstrucion algorithm with regularisation and Huber data...")
RecFISTA_Huber_reg = RectoolsIR.FISTA(_data_, _algorithm_, _regularisation_)
Expand All @@ -196,7 +196,7 @@
plt.title('FISTA-HuberRing-TV reconstruction')
plt.show()

# calculate errors
# calculate errors
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_LS_reg[indicesROI])
RMSE_FISTA_LS_TV = Qtools.rmse()
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_Huber_reg[indicesROI])
Expand Down Expand Up @@ -256,7 +256,7 @@
plt.title('FISTA-SWLS-Huber-TV reconstruction')
plt.show()

# calculate errors
# calculate errors
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_SWLS[indicesROI])
RMSE_FISTA_SWLS_TV = Qtools.rmse()
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_SWLS_Huber[indicesROI])
Expand All @@ -279,7 +279,7 @@
'mask_diameter' : 0.9,
'lipschitz_const' : lc}

# Run FISTA reconstrucion algorithm with regularisation
# Run FISTA reconstrucion algorithm with regularisation
RecFISTA_LS_GH_reg = RectoolsIR.FISTA(_data_, _algorithm_, _regularisation_)

plt.figure()
Expand All @@ -288,7 +288,7 @@
plt.title('FISTA-OS-GH-TV reconstruction')
plt.show()

# calculate errors
# calculate errors
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_LS_GH_reg[indicesROI])
RMSE_FISTA_LS_GH_TV = Qtools.rmse()
print("RMSE for FISTA-PWLS-GH-TV reconstruction is {}".format(RMSE_FISTA_LS_GH_TV))
Expand All @@ -312,7 +312,7 @@
'iterations' : 120,
'device_regulariser': 'gpu'}

# Run FISTA reconstrucion algorithm with regularisation
# Run FISTA reconstrucion algorithm with regularisation
RecFISTA_LS_stud_reg = RectoolsIR.FISTA(_data_, _algorithm_, _regularisation_)

plt.figure()
Expand All @@ -321,8 +321,8 @@
plt.title('FISTA-OS-Stidentst-TV reconstruction')
plt.show()

# calculate errors
# calculate errors
Qtools = QualityTools(phantom_2D[indicesROI], RecFISTA_LS_stud_reg[indicesROI])
RMSE_FISTA_LS_studentst_TV = Qtools.rmse()
print("RMSE for FISTA-PWLS-Studentst-TV reconstruction is {}".format(RMSE_FISTA_LS_studentst_TV))
#%%
#%%
2 changes: 1 addition & 1 deletion Demos/Python/DemoFISTA_artifacts3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Projection geometry related parameters:
Horiz_det = int(N_size) # detector column count (horizontal)
Vert_det = N_size+20 # detector row count (vertical) (no reason for it to be > N)
Vert_det = N_size # detector row count (vertical) (no reason for it to be > N)
angles_num = int(0.5*np.pi*N_size); # angles number
angles = np.linspace(0.0,179.9,angles_num,dtype='float32') # in degrees
angles_rad = angles*(np.pi/180.0)
Expand Down
3 changes: 2 additions & 1 deletion conda-recipe/build.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
set -xe
set -xe
cp -rv "$RECIPE_DIR/../test" "$SRC_DIR/"

cd $SRC_DIR

Expand Down
11 changes: 9 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
package:
name: tomobar
version: {{ environ['VERSION'] }}

build:
preserve_egg_dir: False
number: 0
script_env:
- VERSION

test:
source_files:
- ./test/
commands:
- python -c "import os; print (os.getcwd())"
- python -m unittest discover -s test

requirements:
build:
- python
- numpy
- setuptools
- cython
- cmake

run:
- scipy
- python
- numpy
- libgcc-ng # [unix]
Expand Down
39 changes: 25 additions & 14 deletions src/Python/tomobar/methodsIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,28 @@

import numpy as np
from numpy import linalg as LA
import scipy.sparse.linalg

try:
from tomobar.supp.addmodules import RING_WEIGHTS
except:
print('____! RING_WEIGHTS C-module failed on import !____')
from ccpi.filters.regularisers import ROF_TV,FGP_TV,PD_TV,SB_TV,LLT_ROF,TGV,NDF,Diff4th,NLTV
except:
print('____! CCPi-regularisation package is missing, please install !____')

try:
from ccpi.filters.regularisers import ROF_TV,FGP_TV,PD_TV,SB_TV,LLT_ROF,TGV,NDF,Diff4th,NLTV
import astra
except:
print('____! Astra-toolbox package is missing, please install !____')

try:
import scipy.sparse.linalg
except:
print('____! Scipy toolbox package is missing, please install !____')

try:
from tomobar.supp.addmodules import RING_WEIGHTS
except:
print('____! CCPi regularisation package is missing, please install !____')
print('____! RING_WEIGHTS C-module failed on import !____')



def smooth(y, box_pts):
# a function to smooth 1D signal
Expand Down Expand Up @@ -264,7 +275,7 @@ class RecToolsIR:
--huber_threshold # threshold for Huber function to apply to data model (supress outliers)
--studentst_threshold # threshold for Students't function to apply to data model (supress outliers)
--ring_weights_threshold # threshold to produce additional weights to supress ring artifacts
--ring_huber_power # defines the strength of Huber penalty to supress artifacts 1 = Huber, > 1 more penalising
--ring_huber_power # defines the strength of Huber penalty to supress artifacts 1 = Huber, > 1 more penalising
--ring_tuple_halfsizes # a tuple for half window sizes as [detector, angles, num of projections]
--ringGH_lambda # a parameter for Group-Huber data model to supress full rings of the same intensity
--ringGH_accelerate # Group-Huber data model acceleration factor (use carefully to avoid divergence, 50 default)
Expand Down Expand Up @@ -347,7 +358,7 @@ def __init__(self,
from tomobar.supp.astraOP import AstraTools3D
self.Atools = AstraTools3D(self.DetectorsDimH, self.DetectorsDimV, self.AnglesVec, self.CenterRotOffset, self.ObjSize) # initiate 3D ASTRA class object
return None


def SIRT(self, _data_, _algorithm_):
######################################################################
Expand All @@ -372,7 +383,7 @@ def CGLS(self, _data_, _algorithm_):
if (self.geom == '3D'):
CGLS_rec = self.Atools.cgls3D(_data_['projection_norm_data'], _algorithm_['iterations'])
return CGLS_rec

def powermethod(self, _data_):
# power iteration algorithm to calculate the eigenvalue of the operator (projection matrix)
# projection_raw_data is required for PWLS fidelity (self.datafidelity = PWLS), otherwise will be ignored
Expand All @@ -389,8 +400,8 @@ def powermethod(self, _data_):
self.AtoolsOS = AstraToolsOS3D(self.DetectorsDimH, self.DetectorsDimV, self.AnglesVec, self.CenterRotOffset, self.ObjSize, _data_['OS_number']) # initiate 3D ASTRA class OS object
niter = 15 # number of power method iterations
s = 1.0
# classical approach

# classical approach
if (self.geom == '2D'):
x1 = np.float32(np.random.randn(self.Atools.ObjSize,self.Atools.ObjSize))
else:
Expand Down Expand Up @@ -489,9 +500,9 @@ def FISTA(self, _data_, _algorithm_, _regularisation_):
r[:,0] = r_x[:,0] - np.multiply(L_const_inv,vec)
else:
r = r_x - np.multiply(L_const_inv,vec)

if ((_data_['OS_number'] != 1) and (_data_['ring_weights_threshold'] is not None) and (iter > 0)):
# Ordered subset approach for a better ring model
# Ordered subset approach for a better ring model
res_full = self.Atools.forwproj(X_t) - _data_['projection_norm_data']
rings_weights = RING_WEIGHTS(res_full, _data_['ring_tuple_halfsizes'][0], _data_['ring_tuple_halfsizes'][1], _data_['ring_tuple_halfsizes'][2])
ring_function_weight = np.ones(np.shape(res_full))
Expand Down Expand Up @@ -703,7 +714,7 @@ def ADMM_Atb(b):
# update u variable
u = u + (x_hat - z)
if (_algorithm_['verbose'] == 'on'):
if (np.mod(iter,(round)(_algorithm_['iterations']/5)+1) == 0):
if (np.mod(iter,(round)(_algorithm_['iterations']/5)+1) == 0):
print('ADMM iteration (',iter+1,') using', _regularisation_['method'], 'regularisation for (',(int)(info_vec[0]),') iterations')
if (iter == _algorithm_['iterations']-1):
print('ADMM stopped at iteration (', iter+1, ')')
Expand Down
Loading

0 comments on commit 23805fc

Please sign in to comment.