
Commit

cleaned
Shunsuke-1994 committed Nov 22, 2023
1 parent 7ad8456 commit b3717ff
Showing 9 changed files with 42 additions and 66 deletions.
4 changes: 2 additions & 2 deletions scripts/run_cmalign.py
@@ -9,8 +9,8 @@


parser = argparse.ArgumentParser()
parser.add_argument('--fasta', default = "", help='path to gzipped tracebackfile') # add as a required argument
parser.add_argument('--cmfile', default = "/Users/sumishunsuke/Desktop/RNA/genzyme/datasets/legacy/RF00234/RF00234.cm", help='path to cm file') # add as a required argument
parser.add_argument('--fasta', default = "", help='path to gzipped tracebackfile')
parser.add_argument('--cmfile', default = "/Users/sumishunsuke/Desktop/RNA/genzyme/datasets/legacy/RF00234/RF00234.cm", help='path to cm file')
parser.add_argument('--cpu', default=4, type = int)
args = parser.parse_args()

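For reference, the visible part of run_cmalign.py only sets up --fasta, --cmfile, and --cpu. A minimal sketch of how such a wrapper might hand these arguments to Infernal's cmalign is shown below; it is an assumption about the hidden body of the script, and the -o output flag and file naming are illustrative rather than taken from the repository.

import subprocess

def run_cmalign(fasta, cmfile, cpu=4, out_sto="aligned.sto"):
    # cmalign expects <cmfile> <seqfile>; --cpu and -o are standard Infernal options,
    # but their use here is an assumption about what the hidden part of the script does
    cmd = ["cmalign", "--cpu", str(cpu), "-o", out_sto, cmfile, fasta]
    subprocess.run(cmd, check=True)
    return out_sto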
4 changes: 0 additions & 4 deletions scripts/train.py
@@ -21,7 +21,6 @@ def modifiedCELoss(pred, soft_targets, gamma = 0, summarize = True):
sotfmax = nn.Softmax(dim = -1)
ce = - soft_targets/scale * ((1-sotfmax(pred)).pow(gamma)) * logsoftmax(pred)
ce_colwise = torch.sum(ce, dim = -1)
# return torch.mean(ce_colwise)
if summarize:
return torch.sum(ce)
else:
@@ -31,7 +30,6 @@ def save_model(model, dir_name, pt_file):
if not os.path.isdir(dir_name):
os.mkdir(dir_name)
torch.save(model.state_dict(), os.path.join(dir_name , pt_file))
# print(f"Saved the model at {os.path.join(dir_name , pt_file)}")

def write_csv(d, dir_name, fname):
if not os.path.isdir(dir_name):
@@ -54,7 +52,6 @@ def write_csv(d, dir_name, fname):
from util import Timer, AnnealKL
from torch.utils.data import DataLoader

# assert torch.cuda.is_available(), "CUDA is not available."

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type = str, required=True, help = "directory containing X_train etc.")
@@ -99,7 +96,6 @@ def write_csv(d, dir_name, fname):
parser.add_argument('--log_dir', type = str, help = "directory for log output")
parser.add_argument("--print_every", default = 20, type = int, help = "iteration num to print log of learning (default: 20)")
args = parser.parse_args()
# pprint(vars(args))

# training
torch.manual_seed(args.random_seed)
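The modifiedCELoss hunk above computes a focal-style soft-target cross entropy: -(soft_targets/scale) * (1 - softmax(pred))**gamma * log_softmax(pred), summed over the class axis. Below is a self-contained sketch of that formula; scale is not visible in the diff and is treated here as an assumed normalization constant, the non-summarized branch is assumed to return the per-position losses, and with gamma = 0 the loss reduces to ordinary soft-target cross entropy.

import torch
import torch.nn as nn

def modified_ce_loss(pred, soft_targets, gamma=0.0, scale=1.0, summarize=True):
    # (1 - p)**gamma is the focal weight: it down-weights already-confident classes
    log_p = nn.LogSoftmax(dim=-1)(pred)
    p = nn.Softmax(dim=-1)(pred)
    ce = -(soft_targets / scale) * (1 - p).pow(gamma) * log_p
    ce_colwise = torch.sum(ce, dim=-1)          # loss per position
    return torch.sum(ce) if summarize else ce_colwise

pred = torch.randn(2, 5, 4)                     # [batch, positions, classes]
targets = torch.softmax(torch.randn(2, 5, 4), dim=-1)
print(modified_ce_loss(pred, targets, gamma=0.0))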
18 changes: 8 additions & 10 deletions src/decoder.py
@@ -24,7 +24,7 @@ def forward(self, x):

class Decoder(nn.Module):
"""
Convolutional decoder for C/G-VAE(split type).
Convolutional decoder for C/G-VAE.
"""
def __init__(self,
seq_len,
@@ -147,13 +147,12 @@ def __init__(self,
def get_padding_param(tr_len, s_len, p_len):

def outpads(leng):
# computation on the encoder side
# calc on encoder side
conv1_in = leng
conv2_in = int(((conv1_in + 2*(self.ker1//2) - self.ker1)/self.stride)+1)
conv3_in = int(((conv2_in + 2*0 - self.ker2)/1)+1)
conv3_out= int(((conv3_in + 2*0 - self.ker3)/1)+1)

# this is the length ConvTranspose produces when output_padding = 0
deconv3_out = int((conv3_out - 1) * 1 - 2*0 + (self.ker3 - 1) + 1)
deconv2_out = int((deconv3_out - 1) * 1 - 2*0 + (self.ker2 - 1) + 1)
deconv1_out = int((deconv2_out - 1) * self.stride - 2*(self.ker1//2) + (self.ker1 - 1) + 1)
@@ -228,6 +227,8 @@ def forward(self, z):
return self.tr_decode(h_tr), self.s_decode(h_s), self.p_decode(h_p)



# test
if __name__ == '__main__':
import h5py
from torch.distributions import Normal
@@ -251,28 +252,25 @@ def forward(self, z):
# Pass through some data
x = torch.from_numpy(data[:BATCH_SIZE]).transpose(-2, -1).float() # shape [batch, LEN_GRAMMAR, MAX_LEN]
print("input shape:", x.shape)
_, y = x.max(1) # take the index of the max for each sequence, hence -1
_, y = x.max(1)
print("x: ", x)

print("y: ", y)
# print(x.shape)
mu, logvar = encoder(x)

decoder = Decoder(z_dim = Z_DIM, hidden_size = HIDDEN_SIZE, len_grammar = LEN_GRAMMAR, max_len = MAX_LEN, decode_type="conv_mini", n_fc = 0)
print(decoder)
sigma = (0.5*logvar).exp()
normal = Normal(torch.zeros(mu.shape), torch.ones(sigma.shape))
eps = normal.sample()
z = mu + eps*torch.sqrt(sigma) # sigma
# print("z: ", z)
z = mu + eps*torch.sqrt(sigma)

criterion = torch.nn.CrossEntropyLoss() # classification has to be done over maxlen * batchsize entries.
criterion = torch.nn.CrossEntropyLoss()

logits = decoder(z)
print("output shape:", logits.shape) # shape [batch, LEN_GRAMMAR, MAX_LEN]
logits = logits.transpose(1,2)
logits = logits.reshape(-1, logits.size(-1)) # batch x MAX_LEN, LEN_GRAMMAR
y = y.view(-1) # batch x seqlen = 75
y = y.view(-1)
print(y.shape, y)
print("logits: ", logits.shape)
loss = criterion(logits, y)
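The outpads helper in the decoder hunk above mirrors the standard PyTorch length arithmetic: Conv1d gives L_out = floor((L_in + 2*pad - kernel)/stride) + 1, and ConvTranspose1d gives L_out = (L_in - 1)*stride - 2*pad + (kernel - 1) + output_padding + 1, so output_padding is chosen to make the decoder's output length match the encoder's input length. A small numeric check of those two formulas; the kernel/stride values are illustrative, not the model's actual configuration.

import torch
import torch.nn as nn

L_in, ker, stride, pad = 76, 5, 2, 2      # pad = ker // 2, as in the hunk

conv = nn.Conv1d(1, 1, kernel_size=ker, stride=stride, padding=pad)
L_conv = (L_in + 2 * pad - ker) // stride + 1
assert conv(torch.zeros(1, 1, L_in)).shape[-1] == L_conv         # 38

deconv = nn.ConvTranspose1d(1, 1, kernel_size=ker, stride=stride, padding=pad)
L_deconv = (L_conv - 1) * stride - 2 * pad + (ker - 1) + 1       # output_padding = 0
assert deconv(torch.zeros(1, 1, L_conv)).shape[-1] == L_deconv   # 75

# the leftover difference is exactly what output_padding has to supply
print("output_padding =", L_in - L_deconv)                       # 1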
6 changes: 3 additions & 3 deletions src/encoder.py
@@ -25,8 +25,7 @@ def forward(self, x):

class Encoder(nn.Module):
"""
version: 2021/12/17.
Convolutional encoder for C/G-VAE(split type).
Convolutional encoder for C/G-VAE.
"""
def __init__(self,
seq_len,
@@ -96,7 +95,7 @@ def forward(self, x):

class CovarianceModelEncoder(nn.Module):
"""
Convolutional encoder for CM-VAE(split type).
Convolutional encoder for CM-VAE.
Applies a series of one-dimensional convolutions to a batch
of tr/s/p encodings of the sequence of rules that generate
an arithmetic expression.
@@ -194,6 +193,7 @@ def forward(self, x):
return self.mu(h), self.logvar(h)


# test
if __name__ == '__main__':
# Load data
import grammar
2 changes: 1 addition & 1 deletion src/grammar.py
@@ -230,7 +230,7 @@ def make_tree_from_node(node):
rule = derivation.popleft()
return Tree(node, [make_tree_from_node(next_node) if isinstance(next_node, Nonterminal) else next_node for next_node in rule.rhs()])
except:
# if reconstruction fails, return the rules collected so far
# for debug
if print_error:
print(rules)
return make_tree_from_node(start_node)
63 changes: 23 additions & 40 deletions src/infernal_tools.py
@@ -23,22 +23,18 @@ def cmemit(self, n=1, sample = False):
for _ in range(n):
seq_ss = "S_0"
current_state = "S_0"
# "E-> "となるのは
# When "E-> ",
# 1: next node == END ---> next_state = None and break
# 2: next node == BEGL ---> pop next_state from bif_stack
self._bif_stack = deque()
ins_loop_counter = 0
break_ins = False
while current_state != None:
# print("break_ins", break_ins)
rhs, next_state, probs_trans = self._fetch_rule_and_next(current_state, sample, break_ins = break_ins)
break_ins = False
# print(current_state, next_state, probs_trans)
if current_state == next_state:
ins_loop_counter += 1
# print("ins_loop_counter", ins_loop_counter)
# print("ins thresh", ins_loop_counter)
if ins_loop_counter == int(1/(1-probs_trans)): # call 4 times if 0.8, then break
if ins_loop_counter == int(1/(1-probs_trans)):
break_ins = True
ins_loop_counter = 0
seq_ss = seq_ss.replace(current_state, rhs, 1)
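In the cmemit loop above, the insert-state self-loop is cut off once ins_loop_counter reaches int(1/(1 - probs_trans)). With self-transition probability p, the number of visits to the state is geometrically distributed with mean 1/(1 - p), so the cap is roughly the expected length of the self-loop. A quick simulation of that expectation, with an illustrative p (not a value from a real CM):

import random

p = 0.8
cap = int(1 / (1 - p))                  # 5: the expected number of visits

def visits(p):
    n = 1
    while random.random() < p:          # stay in the insert state with probability p
        n += 1
    return n

mean_visits = sum(visits(p) for _ in range(100_000)) / 100_000
print(cap, round(mean_visits, 2))       # both close to 5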
@@ -73,7 +69,7 @@ def _fetch_rule_and_next(self, current_state, sample, break_ins = False):
if break_ins: # escape from ins loop
prob_2nd = sorted(probs_trans, reverse=True)[1] #2nd max
next_state = child_states[probs_trans.index(prob_2nd)] #
# print(child_states, probs_trans.index(prob_2nd))

if current_state == next_state: # handle the case where several probabilities tie for the top
next_state = child_states[probs_trans.index(prob_2nd)+1]
else:
Expand Down Expand Up @@ -242,7 +238,7 @@ def _read_state_line(self, line):
prob = float(0)
prob_trans.update({num_to_state[lowest_child_idx + i]:prob})

# if BIF, the column to the left of lowest_child_idx means BIF_R.
# if BIF, left col of lowest_child_idx = BIF_R
# prob of both split states is 1
else:
prob_trans.update({num_to_state[lowest_child_idx]:float(1), num_to_state[int(state_line[6])]:float(1)})
@@ -461,28 +457,27 @@ def _make_tbdict_from_tbdf(self,tbdf):

def make_aligned_tbdict_from_tbdf(self,tbdf):
"""
fill zeros in missing st
Fill missing values in tbdict with 0.
Edit using cm_dict as a template.
Fill missing values in tbdict with zeros.
Use cm_dict as a template.
"""
aligned_tbdict = copy.deepcopy(self.cm_deriv_dict)
tbdict = self._make_tbdict_from_tbdf(tbdf)
# check whether every key in node_dict is present
# if present, assign the count
# otherwise pad with zero.
# Are all keys present in node_dict?
# if present, assign the count
# else pad with zero.
for node, states_in_nodes in aligned_tbdict.items():
for parent_state, trans_emit in states_in_nodes.items():
# if the parent state is missing, set everything to 0
# no parent state, assign zero
if not parent_state in tbdict:
for child_state, prob in trans_emit["trans"].items():
aligned_tbdict[node][parent_state]["trans"][child_state] = 0
for nuc, prob in trans_emit["emit"].items():
aligned_tbdict[node][parent_state]["emit"][nuc] = 0
else:
# if the parent state exists, look deeper.
# if any parent state exists, search deeper
for child_state, prob in trans_emit["trans"].items():
# divide by the total number of transitions to get the ML-estimated probability.
# BIF is treated as a special case.
# divide by num of transitions.
# BIF is a special case
sum_count_from_parent = 1 if "B" in parent_state else sum(tbdict[parent_state]["trans"].values())
if child_state in tbdict[parent_state]["trans"]:
val = tbdict[parent_state]["trans"][child_state]/sum_count_from_parent
@@ -502,28 +497,23 @@ def make_aligned_tbdict_from_tbdf(self,tbdf):

def make_aligned_tbdict_from_tbdf_ELinitCM(self,tbdf):
"""
Fill missing values in tbdict with 0.
Edit using cm_dict as a template.
When an EL state appears, take all the values from cm_deriv_dict.
last modified: 2021-04-05-2312
Fills missing values in tbdict with 0.
Uses cm_dict as a template.
When an EL state appears, all values are taken from cm_deriv_dict.
"""
aligned_tbdict = copy.deepcopy(self.cm_deriv_dict)
tbdict = self._make_tbdict_from_tbdf(tbdf)
# check whether every key in node_dict is present
# if present, assign the count
# otherwise pad with zero.
# Are all keys present in node_dict?
# if present, assign the count
# else pad with zero.
modeEL = False
for node, states_in_nodes in aligned_tbdict.items():
for parent_state, trans_emit in states_in_nodes.items():
# when in ELmode, the appearance of an S state ends ELmode
# ELmode ends when S state appears during ELmode
if modeEL and ("S" in parent_state):
modeEL = False

# when the parent state exists
if parent_state in tbdict.keys():
# if the parent state exists, look deeper.
# when entering an EL state, trans is taken from the cm, but emission is still recorded.
# therefore, emit is taken under modeEL = False and trans under modeEL = True.
for nuc, prob in trans_emit["emit"].items():
sum_count_from_parent = sum(tbdict[parent_state]["emit"].values())
if nuc in tbdict[parent_state]["emit"]:
Expand All @@ -534,15 +524,13 @@ def make_aligned_tbdict_from_tbdf_ELinitCM(self,tbdf):
count = aligned_tbdict[node][parent_state]["emit"][nuc]
aligned_tbdict[node][parent_state]["emit"][nuc] = count

# if a child in tbdict is an EL state, activate ELmode.
# If tbdict's child has an EL state, switch to ELmode.
for tbchild in tbdict[parent_state]["trans"]:
if "EL_" in tbchild:
modeEL = True
break

for child_state, prob in trans_emit["trans"].items():
# divide by the total number of transitions to compute the ML estimate.
# BIF is treated as a special case.
sum_count_from_parent = 1 if "B" in parent_state else sum(tbdict[parent_state]["trans"].values())
if child_state in tbdict[parent_state]["trans"]:
val = tbdict[parent_state]["trans"][child_state]/sum_count_from_parent
@@ -583,12 +571,10 @@ def make_deriv_dict_from_trsp(cm_deriv_dict, trsp):
if nuc in {'A', 'C', 'G', 'U'}:
rule = Production(n, [nuc])
rule_i = all_rules.index(rule) -56
# rule_i = cmreader.cfg.productions().index(rule) -56
rule_val = s[rule_i, s_i]
else: # double emission
rule = Production(nl_nr, [nuc])
rule_i = all_rules.index(rule) -60
# rule_i = cmreader.cfg.productions().index(rule) -60
rule_val = p[rule_i, p_i]
dirty_deriv_dict[node][parent_state]["emit"][nuc] = rule_val

@@ -614,7 +600,6 @@ def make_deriv_dict_from_trsp(cm_deriv_dict, trsp):
raise Exception(f"Unidentified parent_state_type: {parent_state_type}")

rule_i = all_rules.index(rule)
# rule_i = cmreader.cfg.productions().index(rule)
rule_val = tr[rule_i, node_i]
dirty_deriv_dict[node][parent_state]["trans"][child_state] = rule_val

@@ -635,14 +620,12 @@ def cleanup_deriv_dict(dirty_deriv_dict):
# conversion of derivdict to tr/s/p
def make_trsp_from_deriv_dict(path_to_cmfile, deriv_dict):
"""
last modified: 2021-04-05
function to make trsp(onehot) from dictionary of CM.
"""
cmreader = CMReader(path_to_cmfile)
trans_map, single_map, pair_map = [], [], []

for node, states in deriv_dict.items():
# if re.match(r"(BIF|END|BEG)", node) == None:
# there is something called an EL state, and some states get skipped, hence +1
trans_col = [np.nan]*56
for current_state_name, trans_emit in states.items():

@@ -672,7 +655,7 @@ def make_trsp_from_deriv_dict(path_to_cmfile, deriv_dict):
torch.from_numpy(np.vstack(single_map)),\
torch.from_numpy(np.vstack(pair_map))


# test
if __name__ == '__main__':
# reader = TracebackFileReader(
# "/Users/sumishunsuke/Desktop/RNA/genzyme/datasets/RF00163/RF00163.cm",
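The make_aligned_tbdict_* methods above convert raw traceback counts into maximum-likelihood transition probabilities: each count is divided by the total count leaving the same parent state, BIF parents are fixed to probability 1, and states missing from the traceback are padded with 0. A standalone sketch of that normalization on made-up counts (the state names are hypothetical):

def normalize_trans(parent_state, trans_counts):
    # BIF splits deterministically into both children, so both get probability 1
    if "B" in parent_state:
        return {child: 1.0 for child in trans_counts}
    total = sum(trans_counts.values())
    if total == 0:                        # state never visited: pad with zeros
        return {child: 0.0 for child in trans_counts}
    return {child: c / total for child, c in trans_counts.items()}

print(normalize_trans("ML_12", {"ML_13": 6, "IL_13": 2, "D_14": 0}))
# {'ML_13': 0.75, 'IL_13': 0.25, 'D_14': 0.0}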
1 change: 0 additions & 1 deletion src/metric_helper/alignment_mi_metric.py
@@ -21,7 +21,6 @@ def _pairwise_MI(align_i_j):
# entropy
col_i = align[:, i]
nuc_count = Counter(col_i)
# mi_diag = [(c/aln_size)*(np.log2(c) - np.log2(aln_size) + 2) for nuc, c in nuc_count.items() if nuc != "-"]
mi_diag = [(c/aln_size)*(np.log2(c) - np.log2(aln_size) + 2) for nuc, c in nuc_count.items()]

return i, i, sum(mi_diag)
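The diagonal term in the hunk above, summing (c/aln_size)*(log2(c) - log2(aln_size) + 2) over a column's residue counts, is algebraically 2 - H(column): the (c/N)*(log2 c - log2 N) parts add up to minus the column entropy, and the +2 contributes the 2-bit maximum entropy of a 4-letter alphabet. A short check of that identity on a made-up column (gaps included, matching the new line of the diff):

import numpy as np
from collections import Counter

col = list("AAACGGUU--")                # one alignment column, gaps included
N = len(col)
counts = Counter(col)

diag = sum((c / N) * (np.log2(c) - np.log2(N) + 2) for c in counts.values())
entropy = -sum((c / N) * np.log2(c / N) for c in counts.values())
print(round(diag, 6), round(2 - entropy, 6))   # the two numbers coincide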
8 changes: 4 additions & 4 deletions src/models/CMVAE.py
@@ -37,7 +37,7 @@ def __getitem__(self, index):

class CovarianceModelVAE(nn.Module):
"""
CM-VAE(split type)
CM-VAE. Encodes and decodes a CM, or an alignment on a CM.
"""
def __init__(self,
hidden_encoder_size,
@@ -155,7 +155,7 @@ def build_from_config(path):
return model


if __name__ == "__main__":
cmvae = CovarianceModelVAE.build_CMVAE_from_config("./outputs/EXP06/EXP06-31/config.yaml")
print(cmvae)
# if __name__ == "__main__":
# cmvae = CovarianceModelVAE.build_CMVAE_from_config("./outputs/EXP06/EXP06-31/config.yaml")
# print(cmvae)

2 changes: 1 addition & 1 deletion src/preprocess.py
@@ -89,7 +89,7 @@ def refold(aln_file, n_turn = 2):
"""
refold sequences from aligned sequences.
1. run RNAalifold -f S --SS_cons
2. run refold.pl −−turn 2
2. run refold.pl --turn 2
3. RNAfold -C --enforceConstraint
ref: https://www.tbi.univie.ac.at/RNA/refold.1.html
"""
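The refold docstring above names a three-step constrained refolding pipeline from the ViennaRNA suite. The sketch below only chains the three commands as the docstring spells them; how refold.pl consumes the RNAalifold output and how the constrained sequences reach RNAfold are assumptions, so this is a skeleton under those assumptions, not a verified recipe.

import subprocess

def refold_sketch(aln_file, n_turn=2):
    # 1. consensus structure for the alignment (flags as given in the docstring)
    subprocess.run(["RNAalifold", "-f", "S", "--SS_cons", aln_file], check=True)
    # 2. map the consensus structure back onto each sequence as a constraint;
    #    "alifold.out" is a placeholder for RNAalifold's output (assumed input convention)
    subprocess.run(["refold.pl", "--turn", str(n_turn), aln_file, "alifold.out"], check=True)
    # 3. fold each sequence with its constraint enforced
    subprocess.run(["RNAfold", "-C", "--enforceConstraint"], check=True)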
