-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlearnSparseBasis.jl
More file actions
138 lines (124 loc) · 3.58 KB
/
learnSparseBasis.jl
File metadata and controls
138 lines (124 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#=
Learn a Sparse Basis for Natural Images
Author: Shashwat Shukla
Date: 9th March 2018
=#
using MAT # Import library to read .mat files
using Colors # Import library to convert array to image
using ImageView # Import library to display images
using Gtk.ShortNames # Import library to draw a canvas
using StatsBase # To compute mean and variance
# ---- Hyperparameters ----
const l = 12              # side length (in pixels) of a square training patch
const d = l^2             # length of a flattened l×l patch
const n = 100             # number of basis vectors to learn
const iter = 400000       # number of training iterations
const λ = 0.2             # weight of the sparsity (L1) penalty
const η = 0.001           # learning rate of the basis update
const batch_size = 1      # patches averaged per gradient step
# Load the training images (whitened, per the original note) stored under the
# key "IMAGES". The do-block form of matopen closes the file even if `read`
# throws, so the handle can no longer leak.
x = matopen("IMAGES.mat") do file
    read(file, "IMAGES")
end
# Conjugate-gradient solver for a symmetric positive (semi-)definite system.
# (The function keeps its historical "Descent" name.)
# Returns θ such that A*θ ≈ b, starting from the initial guess θ.
#   A       : symmetric positive (semi-)definite matrix
#   θ       : initial guess; not modified (updates rebind a new array)
#   b       : right-hand side, same shape as θ
#   tol     : stop once the residual norm ‖b - Aθ‖ is ≤ tol (default 1e-5)
#   maxiter : hard cap on iterations, so an ill-conditioned or singular A
#             cannot keep the loop spinning forever
function conjugateDescent(A, θ, b; tol=1e-5, maxiter=10*length(b))
    r = b - A*θ
    p = r
    rr = dot(r, r)                      # ‖r‖², carried between iterations
    steps = 0
    while sqrt(rr) > tol && steps < maxiter
        Ap = A*p                        # single matrix-vector product per step
        pAp = dot(p, Ap)
        # p·Ap == 0 means p lies in A's null space (singular A): the step
        # length would be infinite and poison θ with Inf/NaN, so stop here.
        pAp > 0 || break
        α = rr/pAp
        θ = θ + α*p
        r = r - α*Ap
        rr_new = dot(r, r)
        p = r + (rr_new/rr)*p           # β = ‖r_new‖² / ‖r_old‖²
        rr = rr_new
        steps += 1
    end
    return θ
end
# Soft-thresholding operator S_λ(β) = sign(β)·max(|β| - λ, 0), the proximal map
# of λ|·| used by the coordinate-wise LASSO update.
#   λ : threshold (non-negative)
#   β : value to shrink towards zero
# Every branch returns the same type as β - λ; returning zero(β - λ) instead of
# the integer literal 0 keeps the function type-stable.
function soft_threshold(λ, β)
    if β > λ
        return β - λ
    elseif β < -λ
        return β + λ
    else
        return zero(β - λ)
    end
end
# Coordinate-descent solver for the LASSO problem
#     minimize ½‖y - Φθ‖² + λ‖θ‖₁
# θ is the starting point and is overwritten IN PLACE, one coefficient at a
# time; the same array is returned.
#   Φ       : dictionary matrix, one basis vector per column
#   y       : target signal, length(y) == size(Φ, 1)
#   θ       : initial coefficients, length(θ) == size(Φ, 2); mutated
#   λ       : sparsity weight
#   tol     : stop once a full sweep changes θ by at most tol (2-norm)
#   maxiter : cap on the number of sweeps
function coordinateDescent(Φ, y, θ, λ; tol=1e-5, maxiter=10000)
    ncols = size(Φ, 2)                  # follow Φ rather than a global constant
    # ‖Φ[:,k]‖² does not change between sweeps: compute it once
    colsq = [sum(abs2, view(Φ, :, k)) for k in 1:ncols]
    # Full residual y - Φθ, kept current incrementally so each coordinate
    # update costs O(d) instead of a full O(d·n) product
    r = vec(y) - Φ*vec(θ)
    for sweep = 1:maxiter
        θ_old = copy(θ)
        for k = 1:ncols
            φk = view(Φ, :, k)
            # Correlation of φk with the partial residual y - Σ_{j≠k} Φ[:,j]θ[j]
            ρ = dot(φk, r) + colsq[k]*θ[k]
            # An all-zero column can never explain y; pin its weight to 0
            # instead of dividing 0 by 0
            θk = colsq[k] > 0 ? soft_threshold(λ, ρ)/colsq[k] : zero(ρ)
            r .+= φk .* (θ[k] - θk)     # residual after changing θ[k] to θk
            θ[k] = θk
        end
        norm(θ - θ_old) > tol || break
    end
    return θ
end
# Automatic Relevance Determination (ARD) estimate of the coefficients θ for
# y ≈ Φθ, with an individual prior precision D[k] for each coefficient
# (all initialised to 1).
# Each round takes one Newton step on θ for the current precisions, then
# re-estimates them as D[k] = 1 / (W[k,k] + θ[k]²).
# NOTE(review): the author's note said this was the method finally used, yet
# the E step of the training loop below calls coordinateDescent (the ARD call
# is commented out) — confirm which is intended.
#   Φ     : dictionary matrix, one basis vector per column
#   y     : target signal, length size(Φ, 1)
#   θ     : initial coefficients, length size(Φ, 2); not modified
#   iters : number of Newton/precision rounds (default 10, as before)
# Returns (θ, W) with W = inv(Φ'Φ + Diagonal(D)) from the last round
# (the zero matrix when iters == 0).
function automaticRelevanceDetermination(Φ, y, θ; iters=10)
    ncols = size(Φ, 2)                  # follow Φ rather than a global constant
    G = Φ'*Φ                            # Gram matrix: loop-invariant
    D = ones(ncols)
    W = zeros(ncols, ncols)
    for it = 1:iters
        dθ = Φ'*(y - Φ*θ) - D.*θ        # gradient w.r.t. θ
        # The Hessian is -(G + Diagonal(D)); one inversion serves both the
        # Newton step and the returned W (the original factorized it twice).
        W = inv(G + Diagonal(D))
        θ = θ + W*dθ                    # θ - Hessian⁻¹ * gradient
        D = 1 ./ (diag(W) .+ vec(θ).^2) # == 1 ./ diag(W + θθ'), without the n×n outer product
    end
    return θ, W
end
# Seed the global RNG with the seconds field (0–59) of the current wall-clock time.
s = Dates.second(now())
srand(s)
# Random initial basis: entries uniform on [-0.5, 0.5), then every column is
# divided by its Euclidean length so each basis vector has unit norm.
Φ = rand(d, n) .- 0.5
norm_Φ = sqrt.(sum(abs2.(Φ), 1))
Φ = Φ ./ norm_Φ
# Learn the basis. Each of the `iter` steps does the following:
#   1. Draw `batch_size` random l×l patches from the image stack.
#   2. E step: infer coefficients θ for each patch.
#   3. M step: take a gradient step on ½‖y - Φθ‖² with respect to Φ.
#   4. Rescale every basis vector to unit length.
# The loop lives in a function because a top-level loop over non-const
# globals (Φ, x) cannot be type-specialised and runs far slower.
#   Φ      : initial basis, d×n, one basis vector per column; not mutated
#   images : image stack indexed as images[row, col, image]
# Returns the learnt d×n basis.
function learnBasis(Φ, images)
    # Largest top-left corner such that an l×l patch still fits, derived from
    # the data instead of hard-coding a 512×512×10 stack.
    maxrow = size(images, 1) - l + 1
    maxcol = size(images, 2) - l + 1
    nimages = size(images, 3)
    # Covariance term for the point estimates below; allocated once, not per sample
    Wzero = zeros(n, n)
    for k = 1:iter
        ∇ = zeros(d, n)                 # gradient accumulator (a matrix from the start)
        for g = 1:batch_size
            p = rand(1:maxrow)
            q = rand(1:maxcol)
            y = images[p:(p+l-1), q:(q+l-1), rand(1:nimages)] # Random l×l patch
            y = reshape(y, (d, 1))      # Flatten y
            y = zscore(y)               # Normalize y
            # E step
            θ = zeros(n, 1)
            W = Wzero
            # A = Φ'*Φ + λ*I; b = Φ'*y                          # For Gaussian Prior (ridge: λ on the diagonal)
            # θ = conjugateDescent(A, θ, b)                     # For Gaussian Prior
            θ = coordinateDescent(Φ, y, θ, λ)                   # For Lasso
            # θ, W = automaticRelevanceDetermination(Φ, y, θ)   # For ARD
            # M step: gradient of the expected reconstruction error w.r.t. Φ
            ∇ = ∇ - y*θ' + Φ*(W + θ*θ')
        end
        # Update the basis vectors with the batch-averaged gradient
        ∇ = ∇ / batch_size
        Φ = Φ - η*∇
        # Normalize the basis vectors to unit length
        Φ = Φ ./ sqrt.(sum(Φ.*Φ, 1))
    end
    return Φ
end
Φ = learnBasis(Φ, x)
# Save the learnt basis (one basis vector per column) to basis.mat under the
# key "basis". The do-block form closes the file even if the write throws.
# (Optionally standardise first with: Φ = zscore(Φ))
matopen("basis.mat", "w") do file
    write(file, "basis", Φ)
end
# Display the learnt basis vectors as a grid of l×l images, filled row by row.
# The grid shape follows n (10 tiles per row) instead of assuming n == 100.
# Each patch is linearly stretched to [0, 1] before conversion to Gray. The
# unit-norm columns of Φ have small entries of both signs, while Gray
# intensities are nominally in [0, 1], so without rescaling most of the grey
# range goes unused.
gridcols = 10
gridrows = cld(n, gridcols)
grid, frames, canvases = canvasgrid((gridrows, gridcols))
for idx = 1:n
    row, col = divrem(idx - 1, gridcols)
    patch = reshape(Φ[:, idx], (l, l))
    lo, hi = extrema(patch)
    # A constant patch has no contrast to stretch; show it as mid-grey
    scaled = hi > lo ? (patch .- lo) ./ (hi - lo) : fill(0.5, (l, l))
    ImageView.imshow(canvases[row+1, col+1], Gray.(scaled))
end
win = Window(grid)
showall(win)
println("Done")