-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOneColInv.py
67 lines (55 loc) · 1.68 KB
/
OneColInv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import List
import torch
def OneColInv(B: torch.Tensor, X0: torch.Tensor, col: List, posX0: int, add=True) -> torch.Tensor:
    """Update an inverse Gram matrix after adding or removing one column.

    ``B`` is expected to be ``inv(X.T @ X)`` with ``X = X0[:, sorted(col)]``
    (this is what the update formulas below assume; it is not checked).
    Instead of re-inverting from scratch, the result is obtained from ``B``
    with the block-matrix (Schur complement) inverse identities.

    Args:
        B: Square 2-D tensor, the current inverse, with rows/columns ordered
            like the sorted entries of ``col``.
        X0: 2-D tensor holding all candidate columns.
        col: Column indices of ``X0`` currently in use. NOTE: this list is
            modified in place -- it is always sorted, and when ``add`` is
            truthy ``posX0`` is appended to it. When removing, ``posX0`` is
            *not* deleted from ``col``; the caller has to do that.
        posX0: Index (into ``X0``) of the column to add or remove. When
            adding it should not already be in ``col``; when removing it
            must be in ``col``.
        add: Truthy to add column ``posX0``, falsy to remove it.

    Returns:
        The updated inverse. When adding, it is one row/column larger than
        ``B`` and ordered like the (sorted) updated ``col``. When removing,
        it is one row/column smaller, with ``posX0``'s row/column dropped.

    Raises:
        ValueError: When removing and ``posX0`` is not in ``col``.
    """
    col.sort()

    if add:
        X = X0[:, col]
        v = X0[:, posX0]

        col.append(posX0)
        col.sort()
        pos = col.index(posX0)

        u1 = torch.matmul(X.T, v)   # X^T v
        u2 = torch.matmul(B, u1)    # B X^T v
        # Inverse of the scalar Schur complement  v^T v - v^T X B X^T v.
        schur_inv = 1 / (torch.matmul(v, v) - torch.matmul(u1, u2))

        # The top-left block needs the rank-one OUTER product u2 u2^T.
        # (matmul of two 1-D tensors is a dot product and would add a single
        # scalar to every entry of B.)
        top_left = B + schur_inv * torch.outer(u2, u2)
        border = -schur_inv * u2

        # Assemble the inverse with the new column placed last ...
        top = torch.cat([top_left, border.unsqueeze(1)], dim=1)
        bottom = torch.cat([border, schur_inv.reshape(1)]).unsqueeze(0)
        grown = torch.cat([top, bottom], dim=0)

        # ... then move that last row/column to its sorted position `pos`.
        last = grown.shape[0] - 1
        order = list(range(pos)) + [last] + list(range(pos, last))
        return grown[order][:, order]

    pos = col.index(posX0)

    # Move the row/column being removed to the last position.
    size = B.shape[0]
    order = list(range(pos)) + list(range(pos + 1, size)) + [pos]
    permuted = B[order][:, order]

    kept = permuted[:-1, :-1]
    b = permuted[:-1, -1]
    d = permuted[-1, -1]

    # Reverse of the update above: subtract the rank-one term b b^T / d.
    return kept - torch.outer(b, b) / d