Saturday, August 24, 2019

Iterative Hard Thresholding (Python Implementation)



I'm trying to implement the Iterative Hard Thresholding recovery algorithm for compressive sensing in python. It is a very simple algorithm, given $ \mathbf{y}( = \mathbf{A}\mathbf{x}), \mathbf{A}$, we start with $\mathbf{x}^{[0]}=0$ and update the estimate of $\mathbf{x}$ iteratively with,



$\begin{equation} \mathbf{x}^{[n+1]} = \mathbf{H}_{s}( \mathbf{x}^{[n]} + \mathbf{A}^{T}(\mathbf{y} - \mathbf{Ax}^{[n]})) \end{equation}$


where $\mathbf{H}_{s}(\mathbf{a})$ is the Hard thresholding operator which sets all but $s$ largest magnitude components of $\mathbf{a}$ to zero. $\mathbf{A}$ is the measurement matrix.


from pylab import *

def largestElement(x, n):
# returns the nth largest element of the vector x
N = x.shape[0]
if n > N:
n = N
elif n < 1:

n = 1
t = np.sort(x)[::-1]
return t[n-1] # python index starts at 0

# Soft thresholding function
def softThreshold(x, threshold):
j = np.abs(x) <= threshold
x[j] = 0
j = np.abs(x) > threshold
x[j] = x[j] - np.sign(x[j])*threshold

return x

# Hard thresholding function
def hardThreshold(x, threshold):
j = np.abs(x) < threshold
x[j] = 0
return x

def reconstructIHT(A, y, s, Its=500, tol=0.001, x=0, verbose=False):
# recovers a sparse vector x from y using Iterative Hard thresholding Algorithm

# xhat = reconstructIHT(A, t, T, tol, x, verbose)
# Arguments:
# A - measurement matrix
# y - measurements
# s - sparsity level require in reconstruction
# Its - max number of iterations (optional)
# tol - stopping criteria (optional)
# x - original vector used to print progress of MSE (optional)
# verbose - print progress (optional)


# Length of original signal
N = A.shape[1]

# Length of measurement vector
n = A.shape[0]

# Initial estimate
xhat = np.zeros(N)

# Initial residue

r = y

for t in xrange(Its):
# Pre-threshold value
gamma = xhat + np.dot(A.T, r)

# Find the s-th largest coefficient of gamma
threshold = largestElement(np.abs(gamma), s)

# Estimate the signal (by hard thresholding)

xhat = hardThreshold(gamma, threshold)

# Compute error, print and plot
if verbose:
err = np.mean((x-xhat)**2)
print "iter# = "+str(t) + " MSE = " + str(err)

# update the residual
r = y - np.dot(A, xhat)


# Stopping criteria
if np.linalg.norm(r)/np.linalg.norm(y) < tol:
break

return xhat


if __name__ == '__main__':
N = 2000 # signal length
n = 400 # Number of measurements

k = 50 # Number of non-zero elements

T = 200 # Number of iterations

tol = 0.0001 # Tolerance

# Generate problem instance
A = np.random.randn(n,N)
# normalize to columns to have unit norm
A = A/np.sqrt(np.sum(A**2, axis=0))


# Sparse signal x[i] in {+1, -1, 0}
x = np.sign(np.random.rand(k)-0.5)
x = np.append(x, np.zeros(N-k))
x = x[np.random.permutation(np.arange(N))]

# Generate measurements
y = np.dot(A, x)

# Reconstruct using IHT

xiht = reconstructIHT(A, y, k, T, tol, x=x, verbose=True)

print "All Close xiht and x :" + str(np.allclose(xiht, x))

erriht = np.mean((xiht-x)*(xiht-x))
print "Mean Squared Error AMP: " + str(erriht)

Even with this simple algorithm, the output MSE is diverging from the first iteration itself. I'm getting the following output while executing the above script


$ python iht.py 
iter# = 0 MSE = 0.0212064839901

iter# = 1 MSE = 0.0343619834319
iter# = 2 MSE = 0.0699223575979
iter# = 3 MSE = 0.142079696301
iter# = 4 MSE = 0.412376685514
iter# = 5 MSE = 1.15243885539
iter# = 6 MSE = 4.01798792918
iter# = 7 MSE = 14.2079757137
iter# = 8 MSE = 53.368135365
iter# = 9 MSE = 203.594951636
iter# = 10 MSE = 802.066963904

iter# = 11 MSE = 3191.3320412
iter# = 12 MSE = 12906.9426382
iter# = 13 MSE = 53269.0355341
iter# = 14 MSE = 224943.846986
iter# = 15 MSE = 973665.201472
iter# = 16 MSE = 4257382.86503
iter# = 17 MSE = 18692796.3866
iter# = 18 MSE = 82371037.4764
iter# = 19 MSE = 364100534.649
iter# = 20 MSE = 1613521353.98

iter# = 21 MSE = 7168116403.28
iter# = 22 MSE = 31928645804.5
iter# = 23 MSE = 142615155018.0
iter# = 24 MSE = 637897124776.0
iter# = 25 MSE = 2.85545575893e+12
iter# = 26 MSE = 1.27824446331e+13
iter# = 27 MSE = 5.72207134294e+13
iter# = 28 MSE = 2.56149013357e+14
iter# = 29 MSE = 1.14665325897e+15
iter# = 30 MSE = 5.13300318342e+15

iter# = 31 MSE = 2.29779329081e+16
iter# = 32 MSE = 1.02860914392e+17
iter# = 33 MSE = 4.60457768293e+17
iter# = 34 MSE = 2.06124316181e+18
iter# = 35 MSE = 9.22717274943e+18
iter# = 36 MSE = 4.13055182075e+19
iter# = 37 MSE = 1.84904507664e+20
iter# = 38 MSE = 8.27726619554e+20
iter# = 39 MSE = 3.70532533453e+21
iter# = 40 MSE = 1.65869207421e+22

iter# = 41 MSE = 7.42514934227e+22
iter# = 42 MSE = 3.32387449197e+23
iter# = 43 MSE = 1.48793527633e+24
iter# = 44 MSE = 6.66075506727e+24
iter# = 45 MSE = 2.98169273705e+25
iter# = 46 MSE = 1.33475731931e+26
iter# = 47 MSE = 5.97505262467e+26
iter# = 48 MSE = 2.67473744861e+27
iter# = 49 MSE = 1.19734852032e+28
iter# = 50 MSE = 5.3599409537e+28

iter# = 51 MSE = 2.39938217983e+29
iter# = 52 MSE = 1.07408549733e+30
iter# = 53 MSE = 4.80815297065e+30
iter# = 54 MSE = 2.15237381444e+31
iter# = 55 MSE = 9.63511990024e+31
iter# = 56 MSE = 4.31316971379e+32
iter# = 57 MSE = 1.9307941336e+33
iter# = 58 MSE = 8.64321655241e+33
iter# = 59 MSE = 3.86914332666e+34
iter# = 60 MSE = 1.73202533935e+35

iter# = 61 MSE = 7.75342633463e+35
iter# = 62 MSE = 3.47082797006e+36
iter# = 63 MSE = 1.55371912724e+37
iter# = 64 MSE = 6.95523704191e+37
iter# = 65 MSE = 3.11351784638e+38
iter# = 66 MSE = 1.39376894293e+39
iter# = 67 MSE = 6.23921866557e+39
iter# = 68 MSE = 2.79299160412e+40
iter# = 69 MSE = 1.25028509479e+41
iter# = 70 MSE = 5.59691198477e+41

iter# = 71 MSE = 2.50546246579e+42
iter# = 72 MSE = 1.12157242861e+43
iter# = 73 MSE = 5.02072862715e+43
iter# = 74 MSE = 2.24753348998e+44
iter# = 75 MSE = 1.00611030066e+45
iter# = 76 MSE = 4.50386141788e+45
iter# = 77 MSE = 2.01615743902e+46
iter# = 78 MSE = 9.0253461236e+46
iter# = 79 MSE = 4.040203958e+47
iter# = 80 MSE = 1.80860077815e+48

iter# = 81 MSE = 8.09621694534e+48
iter# = 82 MSE = 3.6242784819e+49
iter# = 83 MSE = 1.6224113809e+50
iter# = 84 MSE = 7.26273850647e+50
iter# = 85 MSE = 3.25117114157e+51
iter# = 86 MSE = 1.45538955896e+52
iter# = 87 MSE = 6.51506388346e+52
iter# = 88 MSE = 2.91647395326e+53
iter# = 89 MSE = 1.30556207463e+54
iter# = 90 MSE = 5.84435986066e+54

iter# = 91 MSE = 2.61623272035e+55
iter# = 92 MSE = 1.17115882838e+56
iter# = 93 MSE = 5.24270257235e+56
iter# = 94 MSE = 2.34690031754e+57
iter# = 95 MSE = 1.05059194652e+58
iter# = 96 MSE = 4.7029838883e+58
iter# = 97 MSE = 2.10529478424e+59
iter# = 98 MSE = 9.42437021644e+59
iter# = 99 MSE = 4.2188274365e+60
iter# = 100 MSE = 1.888561732e+61

iter# = 101 MSE = 8.45416284324e+61
iter# = 102 MSE = 3.78451327108e+62
iter# = 103 MSE = 1.69414062215e+63
iter# = 104 MSE = 7.58383507211e+63
iter# = 105 MSE = 3.39491029547e+64
iter# = 106 MSE = 1.51973451488e+65
iter# = 107 MSE = 6.80310463229e+65
iter# = 108 MSE = 3.04541564231e+66
iter# = 109 MSE = 1.36328293268e+67
iter# = 110 MSE = 6.10274777838e+67

iter# = 111 MSE = 2.73190029404e+68
iter# = 112 MSE = 1.22293751727e+69
iter# = 113 MSE = 5.47449031869e+69
iter# = 114 MSE = 2.45066030163e+70
iter# = 115 MSE = 1.09704019267e+71
iter# = 116 MSE = 4.91090986188e+71
iter# = 117 MSE = 2.19837302522e+72
iter# = 118 MSE = 9.84103576307e+72
iter# = 119 MSE = 4.40534812697e+73
iter# = 120 MSE = 1.97205787958e+74

iter# = 121 MSE = 8.8279340663e+74
iter# = 122 MSE = 3.95183228069e+75
iter# = 123 MSE = 1.76904112077e+76
iter# = 124 MSE = 7.91912779867e+76
iter# = 125 MSE = 3.54500437301e+77
iter# = 126 MSE = 1.58692425784e+78
iter# = 127 MSE = 7.10388009477e+78
iter# = 128 MSE = 3.18005803689e+79
iter# = 129 MSE = 1.42355571646e+80
iter# = 130 MSE = 6.3725594136e+80

iter# = 131 MSE = 2.85268170471e+81
iter# = 132 MSE = 1.27700541968e+82
iter# = 133 MSE = 5.71652574905e+82
iter# = 134 MSE = 2.55900767029e+83
iter# = 135 MSE = 1.14554198548e+84
iter# = 136 MSE = 5.12802855471e+84
iter# = 137 MSE = 2.29556639488e+85
iter# = 138 MSE = 1.02761227186e+86
iter# = 139 MSE = 4.600115177e+86
iter# = 140 MSE = 2.05924551711e+87

iter# = 141 MSE = 9.21823027591e+87
iter# = 142 MSE = 4.12654871475e+88
iter# = 143 MSE = 1.84725308281e+89
iter# = 144 MSE = 8.2692443197e+89
iter# = 145 MSE = 3.70173433491e+90
iter# = 146 MSE = 1.65708456014e+91
iter# = 147 MSE = 7.4179532917e+91
iter# = 148 MSE = 3.32065317374e+92
iter# = 149 MSE = 1.48649325045e+93
iter# = 150 MSE = 6.65429982602e+93

iter# = 151 MSE = 2.97880304275e+94
iter# = 152 MSE = 1.33346374517e+95
iter# = 153 MSE = 5.96926192891e+95
iter# = 154 MSE = 2.67214523867e+96
iter# = 155 MSE = 1.19618811531e+97
iter# = 156 MSE = 5.35474639068e+97
iter# = 157 MSE = 2.39705682923e+98
iter# = 158 MSE = 1.07304455213e+99
iter# = 159 MSE = 4.80349317052e+99
iter# = 160 MSE = 2.1502878509e+100

iter# = 161 MSE = 9.62578206651e+100
iter# = 162 MSE = 4.30898962448e+101
iter# = 163 MSE = 1.92892291303e+102
iter# = 164 MSE = 8.63484001741e+102
iter# = 165 MSE = 3.86539356356e+103
iter# = 166 MSE = 1.73034675467e+104
iter# = 167 MSE = 7.74591213587e+104
iter# = 168 MSE = 3.46746423252e+105
iter# = 169 MSE = 1.55221334723e+106
iter# = 170 MSE = 6.94849640477e+106

iter# = 171 MSE = 3.11050039436e+107
iter# = 172 MSE = 1.39241817794e+108
iter# = 173 MSE = 6.23317195446e+108
iter# = 174 MSE = 2.79028478869e+109
iter# = 175 MSE = 1.24907338653e+110
iter# = 176 MSE = 5.59148776233e+110
iter# = 177 MSE = 2.50303430795e+111
iter# = 178 MSE = 1.12048546167e+112
iter# = 179 MSE = 5.01586281026e+112
iter# = 180 MSE = 2.24535530286e+113

iter# = 181 MSE = 1.00513523332e+114
iter# = 182 MSE = 4.49949652055e+114
iter# = 183 MSE = 2.01420348897e+115
iter# = 184 MSE = 9.01659924941e+115
iter# = 185 MSE = 4.03628841225e+116
iter# = 186 MSE = 1.80684798073e+117
iter# = 187 MSE = 8.08837053258e+117
iter# = 188 MSE = 3.62076602847e+118
iter# = 189 MSE = 1.6208390281e+119
iter# = 190 MSE = 7.25569985564e+119

iter# = 191 MSE = 3.24802028347e+120
iter# = 192 MSE = 1.45397907463e+121
iter# = 193 MSE = 6.50874983818e+121
iter# = 194 MSE = 2.91364746545e+122
iter# = 195 MSE = 1.30429679493e+123
iter# = 196 MSE = 5.83869582521e+123
iter# = 197 MSE = 2.61369721e+124
iter# = 198 MSE = 1.17002380499e+125
iter# = 199 MSE = 5.23762163039e+125
All Close xiht and x :False

Mean Squared Error AMP: 5.23762163039e+125

Any advice will be highly appreciated.



Answer



I'm by no means an expert in this, but I find the subject of compressed sensing very interesting, so I thought it'd be fun to play around with this.


I believe your error is in the generation of your sampling matrix, $\Phi$. According to the paper you reference "The convergence of this algorithm was proven in [1] under the condition that $\|\Phi\|_2 < 1$ ."


If you take the norm of your matrix A, you'll find that it's above 1, and that if you scale it such that the norm goes below 1 then your program will converge (actually, I found that as long as the norm of A was below about 2.8 it would converge, which I've yet to understand).


Hope that's helpful!


No comments:

Post a Comment

digital communications - Understanding the Matched Filter

I have a question about matched filtering. Does the matched filter maximise the SNR at the moment of decision only? As far as I understand, ...