-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsvd.py
More file actions
75 lines (61 loc) · 2.1 KB
/
svd.py
File metadata and controls
75 lines (61 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from sklearn.utils.extmath import randomized_svd
import numpy as np
from evd import EVD_decomposition
# Silence NumPy floating-point warnings for division by zero and invalid
# operations. NOTE: np.seterr changes NumPy's error state for the whole
# process, not just for this module.
np.seterr(invalid='ignore', divide='ignore')
def sklearn_svd_implementation(channel, k):
    """Approximate the leading ``k`` singular triplets of ``channel``.

    Delegates to scikit-learn's ``randomized_svd`` with five power
    iterations and no fixed seed (``random_state=None``), so results may
    differ slightly between calls.

    Returns ``(u, sigma, vt)``, where ``sigma`` is the diagonal matrix built
    from the returned singular values.
    """
    options = {"n_components": k, "n_iter": 5, "random_state": None}
    left, spectrum, right_t = randomized_svd(channel, **options)
    return left, np.diag(spectrum), right_t
def numpy_svd_implementation(channel, k):
    """Return the full SVD of ``channel`` as computed by ``np.linalg.svd``.

    ``k`` is accepted only so the signature matches the other
    ``*_svd_implementation`` functions; no truncation happens here.

    Returns ``(u, sigma, vt)``:
    - ``sigma`` is the square diagonal matrix of singular values, of size
      min(m, n).
    - ``u`` and ``vt`` are the full square factors, because ``full_matrices``
      is left at its default.

    For a non-square ``channel`` the caller must slice before multiplying.
    """
    left, spectrum, right_t = np.linalg.svd(channel)
    sigma = np.diag(spectrum)
    return left, sigma, right_t
def custom_svd_implementation(a, k):
    """Compute a full SVD of ``a`` from an eigendecomposition.

    The smaller Gram matrix is eigendecomposed with ``EVD_decomposition``:
    ``a.T @ a`` when ``a`` is tall, ``a @ a.T`` otherwise. The singular
    values are the square roots of its eigenvalues. The other set of singular
    vectors is derived from ``a``, then completed to a full orthonormal basis.

    Parameters
    ----------
    a : array of shape (m, n)
        Matrix to decompose.
    k : int
        If ``k <= 0`` the product ``u @ t @ vt`` is returned instead of the
        factors. Otherwise it is not used here; presumably the caller
        truncates to ``k`` components (verify against the caller).

    Returns
    -------
    tuple or ndarray
        Normally ``(u, t, vt)``:
        - ``u`` has shape (m, m).
        - ``t`` has shape (m, n), with the singular values on its leading
          diagonal.
        - ``vt`` has shape (n, n).
        When ``k <= 0``, the (m, n) reconstruction ``u @ t @ vt`` instead.

    Notes
    -----
    Assumes ``EVD_decomposition(x, True)`` returns
    ``(vectors, eigenvalues, vectors_transposed)``, with the eigenvalues as a
    2-D diagonal matrix and orthonormal eigenvector columns. TODO: confirm
    against evd.py, including whether eigenvalues come sorted descending.
    """
    def to_square(arr):
        # Leading min(rows, cols) x min(rows, cols) block of arr.
        size = min(arr.shape)
        return arr[:size, :size].copy()

    def pseudo_inv(arr):
        # Reciprocal of the non-zero entries only. For a diagonal matrix this
        # is its Moore-Penrose pseudo-inverse, and zeros stay zero, so there
        # is no division by zero.
        arr = arr.copy()
        indices = np.nonzero(arr)
        arr[indices] = 1 / arr[indices]
        return arr

    def complete_orthonormal(cols, keep):
        # Return a square orthogonal matrix. Columns flagged by `keep` equal
        # those of `cols`, which are already orthonormal up to rounding. The
        # remaining columns complete the orthonormal basis.
        # QR orthogonalises columns left to right, so the kept columns are
        # placed first and are not perturbed by the filler columns.
        order = np.concatenate([np.flatnonzero(keep), np.flatnonzero(~keep)])
        q, r = np.linalg.qr(cols[:, order])
        # QR may flip column signs. Undo that so kept columns keep their
        # original direction; diag(r) is +-1 for them.
        signs = np.where(np.diag(r) < 0, -1.0, 1.0)
        result = np.empty_like(q)
        result[:, order] = q * signs
        return result

    m, n = a.shape
    if m > n:
        # Tall matrix: eigendecompose the n x n Gram matrix a^T a.
        cnn = a.T @ a
        v, lv, vt = EVD_decomposition(cnn, True)
        # Clip tiny negative eigenvalues caused by floating-point error
        # before taking square roots.
        sigma = np.sqrt(np.diag(lv).clip(min=0))
        t = np.zeros((m, n))
        t[:n, :n] = np.diag(sigma)
        # u_i = a v_i / sigma_i for the first n columns (zero where sigma_i == 0).
        u = np.zeros((m, m))
        u[:, :n] = a @ v @ pseudo_inv(to_square(t))
        # The remaining columns start random and are orthogonalised below.
        u[:, n:] = np.random.rand(m, m - n)
        keep = np.zeros(m, dtype=bool)
        keep[:n] = sigma > 0
        # Previously `q @ r` was used here, which just rebuilds u (a no-op).
        u = complete_orthonormal(u, keep)
    else:
        # Wide or square matrix: eigendecompose the m x m Gram matrix a a^T.
        rmm = a @ a.T
        u, lu, _ = EVD_decomposition(rmm, True)
        sigma = np.sqrt(np.diag(lu).clip(min=0))
        t = np.zeros((m, n))
        t[:m, :m] = np.diag(sigma)
        # Rows of vt: v_i^T = u_i^T a / sigma_i (zero where sigma_i == 0).
        v = np.zeros((n, n))
        v[:m, :] = pseudo_inv(to_square(t)) @ u.T @ a
        # The remaining rows start random and are orthogonalised below.
        v[m:, :] = np.random.rand(n - m, n)
        keep = np.zeros(n, dtype=bool)
        keep[:m] = sigma > 0
        vt = complete_orthonormal(v.T, keep).T
    if k <= 0:
        return u @ t @ vt
    return u, t, vt