# generate_scrambled_motif_expectations.py
# Scans the human sequence of each enhancer FASTA with real JASPAR motifs and
# with scrambled-motif baselines to estimate expected counts / max scores.
# (Web-page scrape chrome and copied line-number gutter removed.)
import os
import numpy as np
from Bio import SeqIO
from evals import pairwise_edit_distance
from evals import JasparMotifScanner
from pyjaspar import jaspardb
from multiprocessing import Pool
def scan_fasta_file(enhancer_fn, data_dir, scanner, scrambled_scanners, motifs, outdir):
    """Scan the first (human) sequence of one enhancer FASTA with the real
    motif scanner and every scrambled-motif scanner, and print the mean/stdev
    of the scrambled counts and max scores.

    Parameters
    ----------
    enhancer_fn : str
        FASTA file name (relative to ``data_dir``); the part before the first
        "." names the per-enhancer output files.
    data_dir : str
        Directory containing the enhancer FASTA files.
    scanner : JasparMotifScanner
        Scanner built from the real JASPAR motifs.
    scrambled_scanners : list[JasparMotifScanner]
        Scanners built from scrambled versions of the same motifs.
    motifs : list
        JASPAR motif objects (currently only needed by the disabled TSV
        writer below).
    outdir : str
        Directory where result TSVs are (or would be) written.
    """
    base = enhancer_fn.split(".")[0]
    dists_path = os.path.join(outdir, base + "_real_dists.tsv")
    counts_path = os.path.join(outdir, base + "_real_counts.tsv")
    # Skip enhancers whose outputs already exist so interrupted runs resume.
    if os.path.exists(dists_path) and os.path.exists(counts_path):
        return

    print("Working on", enhancer_fn)
    with open(os.path.join(data_dir, enhancer_fn)) as handle:
        sequences = [str(record.seq).upper() for record in SeqIO.parse(handle, "fasta")]
    if not sequences:
        # Guard: an empty FASTA previously raised IndexError on sequences[0].
        print("No sequences found in", enhancer_fn)
        return

    # First record is assumed to be the human sequence -- TODO confirm with
    # the MSA-generation pipeline.
    human_sequence = [sequences[0]]
    real = scanner.scan_motifs(human_sequence, return_counts=True, return_max_score=True)

    # Collect per-scrambled-scanner counts and max scores for the same sequence.
    expected_counts = []
    expected_max_scores = []
    for scrambled_scanner in scrambled_scanners:
        expected = scrambled_scanner.scan_motifs(human_sequence, return_counts=True, return_max_score=True)
        expected_counts.append(expected['count'][0])
        expected_max_scores.append(expected['max_score'][0])
    expected_counts = np.array(expected_counts)
    expected_max_scores = np.array(expected_max_scores)
    print(np.mean(expected_counts, axis=0))
    print(np.std(expected_counts, axis=0))
    print(np.mean(expected_max_scores, axis=0))
    print(np.std(expected_max_scores, axis=0))

    # TODO(review): the TSV writer below is disabled. The original commented
    # code referenced undefined names (`counts`, `expectations`); when
    # re-enabled it should use `real` and the statistics computed above:
    # with open(counts_path, "w") as output:
    #     output.write("Motif\tCounts\tScrambled Counts (Mean)\tScrambled Counts (Stdev)\tMax Score\tScrambled Max Scores (Mean)\tScrambled Max Scores (Stdev)\n")
    #     count_mean = np.mean(expected_counts, axis=0)
    #     count_std = np.std(expected_counts, axis=0)
    #     score_mean = np.mean(expected_max_scores, axis=0)
    #     score_std = np.std(expected_max_scores, axis=0)
    #     for i, motif in enumerate(motifs):
    #         output.write("\t".join(str(v) for v in (
    #             motif.matrix_id,
    #             real['count'][0][i], count_mean[i], count_std[i],
    #             real['max_score'][0][i], score_mean[i], score_std[i])) + "\n")
if __name__ == "__main__":
    # Configuration. `nprocesses` and `n_scrambled_motifs` were previously
    # defined but never used (Pool() defaulted to cpu_count and the loop
    # hard-coded 100); they now drive the run as intended.
    nprocesses = 20          # worker processes; chunksize 21 below == nprocesses + 1
    n_scrambled_motifs = 100 # number of scrambled-motif baseline scanners
    base_data_dir = '/data1/human_enhancer_msas_fixed/test/'
    base_out_dir = '/data1/human_enhancer_motif_scan_baselines_scrambled_motifs/test/'
    scrambed_motif_weight_dir = '/data1/human_enhancer_scrambled_motif_weights/'
    # exist_ok avoids the check-then-create race of the original
    # `if not exists: makedirs` pattern.
    os.makedirs(scrambed_motif_weight_dir, exist_ok=True)
    os.makedirs(base_out_dir, exist_ok=True)

    # Fetch the real vertebrate CORE motifs and build the reference scanner.
    jaspar_release = "JASPAR2024"
    jdb_obj = jaspardb(release=jaspar_release)
    vertebrate_motifs = jdb_obj.fetch_motifs(collection="CORE", tax_group=['vertebrates'])
    vertebrate_scanner = JasparMotifScanner(vertebrate_motifs)

    # Build (or reload cached) scrambled-motif scanners. The "_reverse.pt"
    # file is used as the existence marker for a saved weight set.
    scrambled_scanners = []
    for i in range(n_scrambled_motifs):
        weight_prefix = scrambed_motif_weight_dir + "scrambled_motifs_" + str(i) + "_"
        scrambled = JasparMotifScanner(vertebrate_motifs, scrambled=True)
        if os.path.exists(weight_prefix + "reverse.pt"):
            # Reuse the cached scrambled weights so runs are reproducible.
            scrambled.load_weights(weight_prefix)
        else:
            # First run for this index: persist the freshly scrambled weights.
            scrambled.save_weights(weight_prefix)
        scrambled_scanners.append(scrambled)

    # One argument list per enhancer FASTA in the input directory.
    input_files = [[fasta_fn,
                    base_data_dir,
                    vertebrate_scanner,
                    scrambled_scanners,
                    vertebrate_motifs,
                    base_out_dir]
                   for fasta_fn in os.listdir(base_data_dir)]

    with Pool(nprocesses) as pool:
        pool.starmap(scan_fasta_file, input_files,
                     chunksize=int(np.ceil(len(input_files) / (nprocesses + 1))))