-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate_enhancers.py
More file actions
91 lines (74 loc) · 3.44 KB
/
generate_enhancers.py
File metadata and controls
91 lines (74 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json
import pathlib
import numpy as np
import torch
from torch.utils.data import DataLoader
from enhancer_dataset import EnhancerDataset
from sequence_models.esm import MSATransformer
from sequence_models.constants import ENHANCER_ALPHABET, GAP, MSA_PAD, MASK
from sequence_models.collaters import MSAAbsorbingCollater
# User's home directory as a string. NOTE(review): `home` appears unused in
# this script — confirm nothing else imports it before removing.
home = str(pathlib.Path.home())
def main(config_fpath, valid_dir, outname, checkpointname):
    """Infill the masked last row of each validation MSA with a trained
    MSATransformer and write generated vs. ground-truth sequences as TSV.

    Args:
        config_fpath: path to a JSON config providing ``d_embed``,
            ``d_hidden``, ``n_layers``, ``n_heads``, ``n_sequences`` and
            ``max_seq_len``.
        valid_dir: directory holding the validation enhancer MSAs.
        outname: path of the output TSV with columns
            file / idx / generated / groundtruth.
        checkpointname: path to the ``.tar`` training checkpoint whose
            ``model_state_dict`` was saved under a DataParallel ``module.``
            prefix.
    """
    # Seed both RNGs so generation is reproducible run-to-run.
    _ = torch.manual_seed(0)
    np.random.seed(0)
    # NOTE(review): device is hard-coded to GPU 3 — confirm this matches the
    # target machine or parameterize it.
    device = torch.device('cuda:3')

    with open(config_fpath, 'r') as f:
        config = json.load(f)
    d_embed = config['d_embed']
    d_hidden = config['d_hidden']
    n_layers = config['n_layers']
    n_heads = config['n_heads']
    data_dir = valid_dir
    n_sequences = config['n_sequences']
    max_seq_len = config['max_seq_len']

    collater = MSAAbsorbingCollater(alphabet=ENHANCER_ALPHABET, bert=False)
    tokenizer = collater.tokenizer
    padding_idx = tokenizer.alphabet.index(MSA_PAD)
    masking_idx = tokenizer.alphabet.index(MASK)
    gap_idx = tokenizer.alphabet.index(GAP)
    print(tokenizer.alphabet)
    print('Using {} as padding index'.format(padding_idx))
    print('Using {} as masking index'.format(masking_idx))
    print('Using {} as gap index'.format(gap_idx))

    model = MSATransformer(d_embed, d_hidden, n_layers, n_heads, use_ckpt=True, n_tokens=len(ENHANCER_ALPHABET),
                           padding_idx=padding_idx, mask_idx=masking_idx, tie_weights=False).to(device)
    sd = torch.load(checkpointname, map_location=torch.device('cpu'))
    msd = sd['model_state_dict']
    # Checkpoint was saved from a DataParallel wrapper; strip the 'module.' prefix.
    msd = {k.split('module.')[1]: v for k, v in msd.items()}
    model.load_state_dict(msd)
    model.eval()

    ds_valid = EnhancerDataset(n_sequences, max_seq_len, data_dir=data_dir)
    dl_valid = DataLoader(dataset=ds_valid, batch_size=1, collate_fn=collater, num_workers=1)

    # Context manager guarantees the TSV is flushed and closed on any exit
    # path (the original left the handle open).
    with open(outname, "w") as outwrite:
        outwrite.write("file\tidx\tgenerated\tgroundtruth\n")
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            for i, batch in enumerate(dl_valid):
                src, tgt, mask = batch
                # Mask out the entire last sequence of the MSA; the model must
                # infill it. Use the computed masking_idx rather than the
                # magic constant 7 the original hard-coded here.
                new_tgt = tgt.clone()
                new_tgt[:, -1, :] = masking_idx
                # BUG FIX: feed the MASKED MSA to the model. The original fed
                # the unmasked `tgt`, so `new_tgt` had no effect on the model
                # and "generation" reduced to copying the input.
                outputs = model(new_tgt.to(device)).cpu().numpy()
                filename = ds_valid.filenames[i]
                for j in range(len(outputs[0])):
                    output = str(tokenizer.untokenize(np.argmax(outputs[0][j], axis=1).astype(int)))
                    # Masked view: '#' marks positions the model must fill.
                    masked = str(tokenizer.untokenize(new_tgt.cpu().numpy()[0][j].astype(int)))
                    # BUG FIX: the ground-truth column now comes from the
                    # original `tgt` (the original wrote the masked string, so
                    # the generated row's ground truth was all '#').
                    target = str(tokenizer.untokenize(tgt.cpu().numpy()[0][j].astype(int)))
                    # Keep ground-truth characters at unmasked positions and
                    # take model predictions only where the input was masked.
                    fixed_output = [pred if msk == "#" else msk
                                    for pred, msk in zip(output, masked)]
                    output = ''.join(fixed_output)
                    outwrite.write(filename + '\t')
                    outwrite.write(str(j) + "\t")
                    outwrite.write(output + "\t")
                    outwrite.write(target + '\n')
                print(filename)
if __name__ == '__main__':
    config_fpath = "./configMSA.json"
    outname = "./generated_enhancers/gdiff3_generated_enhancers_new.tsv"
    valid_dir = './toy_dataset/valid/'
    checkpointname = '/home/lualex/genomic_diffusion/enhancer_diffusion_model/checkpoints/ediff3/scale-noclip-4e-4/checkpoint47606.tar'
    # BUG FIX: the original call was main(config_fpath, outname, checkpointname)
    # — it dropped valid_dir, shifting outname into the valid_dir slot and
    # checkpointname into the outname slot, and raised TypeError for the
    # missing fourth positional argument. Pass all four in signature order.
    main(config_fpath, valid_dir, outname, checkpointname)