diff --git a/misc/training/README.md b/misc/training/README.md new file mode 100644 index 0000000..10efa7c --- /dev/null +++ b/misc/training/README.md @@ -0,0 +1,5 @@ +Use generate_SNP_pileups.py to generate training features for SNP calling, and then run model_run.py to training a model. +Use generate_indel_pileups.py or generate_indel_pileups_hifi.py (for PacBio HiFi) to generate training features for indel calling, and then run model_run_indels.py to training a model. + +You can print help for each script using `--help`, e.g. `python generate_SNP_pileups.py --help` +This code uses tensorflow 1.13 for training a model. diff --git a/misc/training/generate_SNP_pileups.py b/misc/training/generate_SNP_pileups.py new file mode 100644 index 0000000..28e2b6a --- /dev/null +++ b/misc/training/generate_SNP_pileups.py @@ -0,0 +1,480 @@ +import sys, pysam, time, os, copy, argparse, random, datetime +from collections import Counter +import numpy as np +import multiprocessing as mp +from pysam import VariantFile +from intervaltree import Interval, IntervalTree + +base_to_num_map={'*':4,'A':0,'G':1,'T':2,'C':3,'N':4} + +def in_bed(tree,pos): + return tree.overlaps(pos) + +def get_cnd_pos(v_pos,cnd_pos, seq='ont'): + if seq=='ont': + ls=cnd_pos[abs(cnd_pos-v_pos)<50000] + + ls1_0= [p for p in ls if (p>=v_pos-2000) & (p=v_pos-5000) & (p=v_pos-10000) & (p=v_pos-20000) & (pv_pos) & (p<=v_pos+2000)][-2:] + ls2_1= [p for p in ls if (p>v_pos+2000) & (p<=v_pos+5000)][:3] + ls2_2= [p for p in ls if (p>v_pos+5000) & (p<=v_pos+10000)][:4] + ls2_3= [p for p in ls if (p>v_pos+10000) & (p<=v_pos+20000)][:5] + ls2_4= [p for p in ls if (p>v_pos+20000)][:6] + + ls_total_1=sorted(ls1_0+ls1_1+ls1_2+ls1_3+ls1_4) + ls_total_2=sorted(ls2_0+ls2_1+ls2_2+ls2_3+ls2_4) + + elif seq=='ul_ont': + ls=cnd_pos[abs(cnd_pos-v_pos)<100000] + + ls1_0= [p for p in ls if (p>=v_pos-2000) & (p=v_pos-5000) & (p=v_pos-10000) & (p=v_pos-20000) & (p=v_pos-40000) & (p=v_pos-50000) & (pv_pos) & (p<=v_pos+2000)][-2:] + ls2_1= [p for p in ls if (p>v_pos+2000) & (p<=v_pos+5000)][:2] + ls2_2= [p for p in ls if (p>v_pos+5000) & (p<=v_pos+10000)][:3] + ls2_3= [p for p in ls if (p>v_pos+10000) & (p<=v_pos+20000)][:3] + ls2_4= [p for p in ls if (p>v_pos+20000) & (p<=v_pos+40000)][:4] + ls2_5= [p for p in ls if (p>v_pos+40000) & (p<=v_pos+50000)][:3] + ls2_6= [p for p in ls if (p>v_pos+50000)][:3] + + ls_total_1=sorted(ls1_0+ls1_1+ls1_2+ls1_3+ls1_4+ls1_5+ls1_6) + ls_total_2=sorted(ls2_0+ls2_1+ls2_2+ls2_3+ls2_4+ls2_5+ls2_6) + + elif seq=='ul_ont_extreme': + ls=cnd_pos[abs(cnd_pos-v_pos)<300000] + + ls1_0= [p for p in ls if (p>=v_pos-10000) & (p=v_pos-20000) & (p=v_pos-50000) & (p=v_pos-75000) & (p=v_pos-100000) & (p=v_pos-200000) & (pv_pos) & (p<=v_pos+10000)][-2:] + ls2_1= [p for p in ls if (p>v_pos+10000) & (p<=v_pos+20000)][:2] + ls2_2= [p for p in ls if (p>v_pos+20000) & (p<=v_pos+50000)][:3] + ls2_3= [p for p in ls if (p>v_pos+50000) & (p<=v_pos+75000)][:3] + ls2_4= [p for p in ls if (p>v_pos+75000) & (p<=v_pos+100000)][:4] + ls2_5= [p for p in ls if (p>v_pos+100000) & (p<=v_pos+200000)][:4] + ls2_6= [p for p in ls if (p>v_pos+200000)][:2] + + ls_total_1=sorted(ls1_0+ls1_1+ls1_2+ls1_3+ls1_4+ls1_5+ls1_6) + ls_total_2=sorted(ls2_0+ls2_1+ls2_2+ls2_3+ls2_4+ls2_5+ls2_6) + + elif seq=='new_pcb': + ls=cnd_pos[abs(cnd_pos-v_pos)<20000] + + ls1_0= [p for p in ls if (p>=v_pos-2000) & (p=v_pos-5000) & (p=v_pos-10000) & (p=v_pos-20000) & (pv_pos) & (p<=v_pos+2000)][-4:] + ls2_1= [p for p in ls if (p>v_pos+2000) & (p<=v_pos+5000)][:5] + ls2_2= [p for p in ls if 
(p>v_pos+5000) & (p<=v_pos+10000)][:5] + ls2_3= [p for p in ls if (p>v_pos+10000) & (p<=v_pos+20000)][:6] + + ls_total_1=sorted(ls1_0+ls1_1+ls1_2+ls1_3) + ls_total_2=sorted(ls2_0+ls2_1+ls2_2+ls2_3) + + elif seq=='old_pcb': + ls=cnd_pos[abs(cnd_pos-v_pos)<20000] + + ls_total_1= [p for p in ls if (p>=v_pos-20000) & (pv_pos) & (p<=v_pos+20000)][:20] + + return ls_total_1, ls_total_2 + +def get_nbr(dct, nbr_type='freq'): + chrom=dct['chrom'] + start=max(dct['start']-50000,1) + end=dct['end']+50000 + + sam_path=dct['sam_path'] + fasta_path=dct['fasta_path'] + samfile = pysam.Samfile(sam_path, "rb") + fastafile=pysam.FastaFile(fasta_path) + + tbx = pysam.TabixFile(dct['exclude_bed']) + exclude_intervals=IntervalTree(Interval(int(row[1]), int(row[2]), "%s" % (row[1])) for row in tbx.fetch(chrom, parser=pysam.asBed())) + + + output_seq={} + + flag=0x4|0x100|0x200|0x400|0x800 + + if nbr_type=='freq': + rlist=[s for s in fastafile.fetch(chrom,start-1,end-1)] + + for pcol in samfile.pileup(chrom,start-1,end-1,min_base_quality=0, flag_filter=flag,truncate=True): + + r=rlist[pcol.pos+1-start] + if r in 'AGTC' and not in_bed(exclude_intervals,pcol.pos+1): + n=pcol.get_num_aligned() + seq=''.join([x[0] for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False,add_indels=True)]).upper() + alt_freq=max([x[1] for x in Counter(seq).items() if (x[0]!=r and x[0] in 'AGTC')]+[0])/n + + if dct['threshold'][0]<=alt_freq and alt_freq=dct['mincov']: + name=pcol.get_query_names() + output_seq[pcol.pos+1]={n:base_to_num_map[s] for (n,s) in zip(name,seq)} + output_seq[pcol.pos+1]['ref']=base_to_num_map[r] + + elif nbr_type=='gtruth': + gt_map={(0,0):0, (1,1):0, (2,2):0, (1,2):1, (2,1):1, (0,1):1, (1,0):1, (0,2):1, (2,0):1, (1,None):0,(None,1):0} + + bcf_in = VariantFile(dct['vcf_path']) + + ground_truth={} + for rec in bcf_in.fetch(chrom,start,end+1): + gt=rec.samples.items()[0][1].get('GT') + if gt_map[gt]: + if base_to_num_map[rec.alleles[gt[0]]]<4 and base_to_num_map[rec.alleles[gt[1]]]<4: + ground_truth[rec.pos]=rec.ref + + for pcol in samfile.pileup(chrom,start-1,end-1,min_base_quality=0, flag_filter=flag,truncate=True): + + if pcol.pos+1 in ground_truth: + n=pcol.get_num_aligned() + r=ground_truth[pcol.pos+1] + + + if n>=dct['mincov']: + name=pcol.get_query_names() + seq=''.join([x[0] for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False,add_indels=True)]).upper() + output_seq[pcol.pos+1]={n:base_to_num_map[s] for (n,s) in zip(name,seq)} + output_seq[pcol.pos+1]['ref']=base_to_num_map[r] + + return output_seq + +def get_snp_training_pileups(dct): + + chrom=dct['chrom'] + start=dct['start'] + end=dct['end'] + + include_intervals = None + + tbx = pysam.TabixFile(dct['include_bed']) + include_intervals=IntervalTree(Interval(int(row[1]), int(row[2]), "%s" % (row[1])) for row in tbx.fetch(chrom, parser=pysam.asBed())) + + + sam_path=dct['sam_path'] + fasta_path=dct['fasta_path'] + vcf_path=dct['vcf_path'] + + bcf_in = VariantFile(vcf_path) + + gt_map={(0,0):0, (1,1):0, (2,2):0, (1,2):1, (2,1):1, (0,1):1, (1,0):1, (0,2):1,(2,0):1} + tr_pos={} + for rec in bcf_in.fetch(chrom,start,end+1): + gt=rec.samples.items()[0][1].get('GT') + if gt in gt_map: + if base_to_num_map[rec.alleles[gt[0]]]<4 and base_to_num_map[rec.alleles[gt[1]]]<4: + tr_pos[rec.pos]=(gt_map[gt],base_to_num_map[rec.alleles[gt[0]]],base_to_num_map[rec.alleles[gt[1]]]) + + + nbr_size=20 + + cnd_seq_freq=get_nbr(dct, nbr_type='freq') + cnd_pos_freq=np.array(list(cnd_seq_freq.keys())) + + cnd_seq_gtruth=get_nbr(dct, 
nbr_type='gtruth') + cnd_pos_gtruth=np.array(list(cnd_seq_gtruth.keys())) + + samfile = pysam.Samfile(sam_path, "rb") + fastafile=pysam.FastaFile(fasta_path) + + ref_dict={j:s.upper() if s in 'AGTC' else '*' for j,s in zip(range(max(1,start-40),end+40+1),fastafile.fetch(chrom,max(1,start-40)-1,end+40)) } + + pileup_dict={} + + output={'pos':[],0:[],5:[],10:[],15:[],20:[],25:[]} + + flag=0x4|0x100|0x200|0x400|0x800 + + for pcol in samfile.pileup(chrom,max(0,start-1),end,min_base_quality=0,\ + flag_filter=flag,truncate=True): + + r=ref_dict[pcol.pos+1] + if in_bed(include_intervals, pcol.pos+1) and r in 'AGTC': + n=pcol.get_num_aligned() + + if n<=dct['maxcov'] and n>=dct['mincov'] and pcol.pos+1>=start and pcol.pos+1<=end: + + + if pcol.pos+1 in tr_pos: + seq=''.join([x[0] for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False,add_indels=True)]).upper() + name=pcol.get_query_names() + pileup_dict[pcol.pos+1]={n:base_to_num_map[s] for (n,s) in zip(name,seq)} + output['pos'].append(pcol.pos+1) + + else: + seq=''.join([x[0] for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False,add_indels=True)]).upper() + alt_freq=max([x[1] for x in Counter(seq).items() if (x[0]!=r and x[0] in 'AGTC')]+[0])/n + + if alt_freq>=0.10: + name=pcol.get_query_names() + pileup_dict[pcol.pos+1]={n:base_to_num_map[s] for (n,s) in zip(name,seq)} + + if 0.10<=alt_freq<0.15: + output[10].append(pcol.pos+1) + elif 0.15<=alt_freq<0.20: + output[15].append(pcol.pos+1) + elif 0.20<=alt_freq<0.25: + output[20].append(pcol.pos+1) + elif 0.25<=alt_freq: + output[25].append(pcol.pos+1) + + elif np.random.randint(2): + seq=''.join([x[0] for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False,add_indels=True)]).upper() + name=pcol.get_query_names() + + pileup_dict[pcol.pos+1]={n:base_to_num_map[s] for (n,s) in zip(name,seq)} + + if alt_freq<0.05: + output[0].append(pcol.pos+1) + + elif 0.05<=alt_freq<0.10: + output[5].append(pcol.pos+1) + + + pileup_list={'pos':[],'neg':[]} + + if output['pos']: + tr_len=len(output['pos']) + else: + tr_len=1e16 + + sizes={0:tr_len, 5:tr_len//3,10:tr_len//3,15:tr_len//3, 20:tr_len, 25:tr_len} + + + for instance in ['pos',0,5,10,15,20,25]: + pos_list=output[instance] + + if pos_list: + if instance!='pos': + if sizes[instance] maxcov: + sample=random.sample(sample,min(len(sample),maxcov)) + + sample=sorted(sample) + + fa_tmp_file=''.join(['>%s_SEQ\n%s\n'%(read_name,seq_list[read_name]) for read_name in sample]) + + + fa_tmp_file+='>ref_SEQ\n%s' %ref + + gap_penalty=1.0 + msa_process =Popen(['muscle', '-quiet','-gapopen','%.1f' %gap_penalty,'-maxiters', '1' ,'-diags1'], stdout=PIPE, stdin=PIPE, stderr=PIPE) + hap_file=msa_process.communicate(input=fa_tmp_file.encode('utf-8')) + + if len(hap_file)==0: + print('hapfile length 0') + + + tmp=hap_file[0].decode('utf-8')[1:].replace('\n','').split('>') + + zz_0=[] + for seq in tmp: + p1,p2=seq.split('_SEQ') + if p1!='ref': + zz_0.append(p2[:128]) + else: + ref_real_0=p2 + + if len(zz_0)=dct['mincov'] and len(read_names_1)>=dct['mincov']: + output['pos'].append(pcol.pos+1) + + if in_bed(include_intervals, v_pos) and not ex_bed(exclude_intervals, v_pos): + read_names=pcol.get_query_names() + read_names_0=set(read_names) & hap_reads_0 + read_names_1=set(read_names) & hap_reads_1 + len_seq_0=len(read_names_0) + len_seq_1=len(read_names_1) + + if len_seq_0>=dct['mincov'] and len_seq_1>=dct['mincov']: + seq=[x for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False, add_indels=True)] + + 
tmp_seq_0=''.join([s for n,s in zip(read_names,seq) if n in read_names_0]) + tmp_seq_1=''.join([s for n,s in zip(read_names,seq) if n in read_names_1]) + + del_freq_0=(tmp_seq_0.count('-'))/len_seq_0 if len_seq_0>0 else 0 + ins_freq_0=tmp_seq_0.count('+')/len_seq_0 if len_seq_0>0 else 0 + + del_freq_1=(tmp_seq_1.count('-'))/len_seq_1 if len_seq_1>0 else 0 + ins_freq_1=tmp_seq_1.count('+')/len_seq_1 if len_seq_1>0 else 0 + + if 0.3<=del_freq_0 or 0.3<=del_freq_1 or 0.3<=ins_freq_0 or 0.3<=ins_freq_1: + output['high'].append(pcol.pos+1) + + elif del_freq_0<=0.2 and del_freq_1<=0.2 and ins_freq_0<=0.1 and ins_freq_1<=0.1 and np.random.randint(100)==0: + output['low'].append(pcol.pos+1) + + + if output['pos']: + tr_len=len(output['pos']) + else: + tr_len=20 + + sizes={'high':tr_len, 'low':tr_len} + + output['pos']=set(output['pos']) + + + for i in ['high','low']: + if sizes[i] maxcov: + sample=random.sample(sample,min(len(sample),maxcov)) + + sample=sorted(sample) + + fa_tmp_file=''.join(['>%s_SEQ\n%s\n'%(read_name,seq_list[read_name]) for read_name in sample]) + + + fa_tmp_file+='>ref_SEQ\n%s' %ref + + gap_penalty=1.0 + msa_process =Popen(['muscle', '-quiet','-gapopen','%.1f' %gap_penalty,'-maxiters', '1' ,'-diags1'], stdout=PIPE, stdin=PIPE, stderr=PIPE) + hap_file=msa_process.communicate(input=fa_tmp_file.encode('utf-8')) + + if len(hap_file)==0: + print('hapfile length 0') + + + tmp=hap_file[0].decode('utf-8')[1:].replace('\n','').split('>') + + zz_0=[] + for seq in tmp: + p1,p2=seq.split('_SEQ') + if p1!='ref': + zz_0.append(p2[:128]) + else: + ref_real_0=p2 + + if len(zz_0)=dct['mincov'] and len(read_names_1)>=dct['mincov']: + output['pos'].append(pcol.pos+1) + + if in_bed(include_intervals, v_pos) and not ex_bed(exclude_intervals, v_pos): + read_names=pcol.get_query_names() + read_names_0=set(read_names) & hap_reads_0 + read_names_1=set(read_names) & hap_reads_1 + len_seq_0=len(read_names_0) + len_seq_1=len(read_names_1) + + if len_seq_0>=dct['mincov'] and len_seq_1>=dct['mincov']: + seq=[x for x in pcol.get_query_sequences( mark_matches=False, mark_ends=False, add_indels=True)] + + tmp_seq_0=''.join([s for n,s in zip(read_names,seq) if n in read_names_0]) + tmp_seq_1=''.join([s for n,s in zip(read_names,seq) if n in read_names_1]) + + del_freq_0=(tmp_seq_0.count('-'))/len_seq_0 if len_seq_0>0 else 0 + ins_freq_0=tmp_seq_0.count('+')/len_seq_0 if len_seq_0>0 else 0 + + del_freq_1=(tmp_seq_1.count('-'))/len_seq_1 if len_seq_1>0 else 0 + ins_freq_1=tmp_seq_1.count('+')/len_seq_1 if len_seq_1>0 else 0 + + if 0.3<=del_freq_0 or 0.3<=del_freq_1 or 0.3<=ins_freq_0 or 0.3<=ins_freq_1: + output['high'].append(pcol.pos+1) + + elif del_freq_0<=0.2 and del_freq_1<=0.2 and ins_freq_0<=0.1 and ins_freq_1<=0.1 and np.random.randint(100)==0: + output['low'].append(pcol.pos+1) + + + if output['pos']: + tr_len=len(output['pos']) + else: + tr_len=20 + + sizes={'high':tr_len, 'low':tr_len} + + output['pos']=set(output['pos']) + + + for i in ['high','low']: + if sizes[i] %s' %(str(datetime.datetime.now()),params['rt_path'])) + saver.restore(sess, params['rt_path']) + + stats,v_stats=[],[] + + count=0 + + save_num=1 + + for k in range(training_iters): + t=time.time() + print('\n'+100*'-'+'\n',flush=True) + print('%s: Starting epoch #: %d\n' %(str(datetime.datetime.now()),k),flush=True) + + for genome in genomes_list: + print('%s: Training on genome %s \n' %(str(datetime.datetime.now()), genome.name), flush=True) + + chunk_list=genome.chunk_list + print('Progress\n\nTotal: %s\nDone : ' 
%('.'*len(chunk_list)),end='', flush=True) + + for chunk in chunk_list: + training_loss, total_train_data, train_acc = 0, 0, 0 + + x_train_sparse,y_train,train_allele,train_ref=genome.training_data[chunk] + + if sparse: + x_train_flat=np.array(x_train_sparse.todense()) + x_train=x_train_flat.reshape(x_train_flat.shape[1]//1025,5,41,5) + else: + x_train=x_train_sparse + + for batch in range(len(x_train)//batch_size): + batch_x = x_train[batch*batch_size:min((batch+1)*batch_size, len(x_train))] + + batch_y = y_train[batch*batch_size:min((batch+1)*batch_size, len(y_train))] + + batch_ref = train_ref[batch*batch_size :min((batch+1)*batch_size, len(train_ref))] + + batch_allele = train_allele[batch*batch_size :min((batch+1)*batch_size, len(train_allele))] + + opt,loss,batch_acc = sess.run([optimizer,cost,accuracy], feed_dict={x: batch_x, GT_label:batch_y, A_label:np.eye(2)[batch_allele[:,0]], G_label:np.eye(2)[batch_allele[:,1]], T_label:np.eye(2)[batch_allele[:,2]], C_label:np.eye(2)[batch_allele[:,3]] , A_ref:batch_ref[:,0][:,np.newaxis], G_ref:batch_ref[:,1][:,np.newaxis], T_ref:batch_ref[:,2][:,np.newaxis], C_ref:batch_ref[:,3][:,np.newaxis], keep:0.5}) + + training_loss+=loss*len(batch_x) + total_train_data+=len(batch_x) + train_acc+=batch_acc + + + + print('.',end='',flush=True) + + training_loss=training_loss/total_train_data + train_acc=train_acc/total_train_data + + print('\n\nGenome: %s Training Loss: %.4f Training Accuracy: %.4f\n' %(genome.name, training_loss, train_acc), flush=True) + + if params['val']: + print(50*'*'+'\n', flush=True) + print('%s: Performing Validation\n' %str(datetime.datetime.now()), flush=True) + + for val_genome in genomes_list: + vx_test, vy_test, vtest_allele, vtest_ref = val_genome.testing_data + + v_loss,v_acc,A_acc,G_acc,T_acc,C_acc,GT_acc,v_loss,v_acc=0,0,0,0,0,0,0,0,0 + + for batch in range(int(np.ceil(len(vx_test)/(batch_size)))): + vbatch_x = vx_test[batch*batch_size:min((batch+1)*batch_size,len(vx_test))] + vbatch_y = vy_test[batch*batch_size:min((batch+1)*batch_size,len(vx_test))] + vbatch_allele = vtest_allele[batch*batch_size:min((batch+1)*batch_size,len(vx_test))] + vbatch_ref = vtest_ref[batch*batch_size:min((batch+1)*batch_size,len(vx_test))] + + batch_loss,batch_acc, batch_GT_acc, batch_A_acc, batch_G_acc, batch_T_acc, batch_C_acc,batch_prediction_GT, batch_prediction_A, batch_prediction_G, batch_prediction_T, batch_prediction_C,bGT_score,bA_score,bG_score,bT_score,bC_score = sess.run([cost, accuracy, accuracy_GT, accuracy_A, accuracy_G, accuracy_T, accuracy_C,prediction_GT, prediction_A, prediction_G, prediction_T, prediction_C,GT_score,A_score,G_score,T_score,C_score], feed_dict={x: vbatch_x,GT_label:vbatch_y, A_label:np.eye(2)[vbatch_allele[:,0]], G_label:np.eye(2)[vbatch_allele[:,1]], T_label:np.eye(2)[vbatch_allele[:,2]], C_label:np.eye(2)[vbatch_allele[:,3]], A_ref:vbatch_ref[:,0][:,np.newaxis], G_ref:vbatch_ref[:,1][:,np.newaxis], T_ref:vbatch_ref[:,2][:,np.newaxis], C_ref:vbatch_ref[:,3][:,np.newaxis], keep:1.0}) + + v_loss+=batch_loss*len(vbatch_x) + v_acc+=batch_acc + A_acc+=batch_A_acc + G_acc+=batch_G_acc + T_acc+=batch_T_acc + C_acc+=batch_C_acc + GT_acc+=batch_GT_acc + + print('Genome: %s Validation Loss: %.4f Validation Accuracy: %.4f' %(val_genome.name, v_loss/len(vx_test), v_acc/len(vx_test)), flush=True) + + print('Genome: %s GT_acc=%.4f A_acc=%.4f G_acc=%.4f T_acc=%.4f C_acc=%.4f\n' %(val_genome.name, GT_acc/len(vx_test), A_acc/len(vx_test), G_acc/len(vx_test), T_acc/len(vx_test), C_acc/len(vx_test)), flush=True) + + 
print(50*'*'+'\n', flush=True) + saver.save(sess, save_path=os.path.join(params['model'],'model'),global_step=save_num) + elapsed=time.time()-t + save_num+=1 + + print ('%s: Time Taken for Iteration %d: %.2f seconds\n' %(str(datetime.datetime.now()),k,elapsed), flush=True) + + + + +def get_data(fname,pool,a=None, b=None,dims=(5,41,5)): + t=time.time() + l=os.stat(fname).st_size + + rec_size=15+dims[0]*dims[1]*dims[2]*6 + if a!=None and b!=None: + my_array=[(fname,x,dims) for x in range(a,b,1000*rec_size)] + else: + my_array=[(fname,x,dims) for x in range(0,l,1000*rec_size)] + + results = pool.map(read_pileups_from_file, my_array) + + pos=np.vstack([res[0][:,np.newaxis] for res in results]) + mat=np.vstack([res[1] for res in results]) + + + ref=np.vstack([res[2] for res in results]) + allele, gt=None, None + + allele=np.vstack([res[3] for res in results]) + + gt=np.vstack([res[4] for res in results]) + + + return mat,gt,allele,ref + +def read_pileups_from_file(options): + + fname,n,dims=options + file= open(fname,'r') + file.seek(n) + mat=[] + pos=[] + ref=[] + allele1,allele2=[],[] + gt=[] + + dp=[] + freq=[] + + i=0 + while i<1000: + i+=1 + c=file.read(15) + if not c: + break + pos.append(int(c[:11])) + gt.append(int(c[11])) + allele1.append(int(c[12])) + allele2.append(int(c[13])) + ref.append(int(c[14])) + + + + m=file.read(dims[0]*dims[1]*dims[2]*6) + p_mat=np.array([int(m[6*i:6*i+6]) for i in range(dims[0]*dims[1]*dims[2])]).reshape((dims[0],dims[1],dims[2])) + + mat.append(p_mat) + + mat=np.array(mat) + pos=np.array(pos) + ref=np.eye(4)[np.array(ref)].astype(np.int8) + allele1=np.eye(4)[np.array(allele1)].astype(np.int8) + allele2=np.eye(4)[np.array(allele2)].astype(np.int8) + + allele=allele1+allele2 + + allele[allele==2]=1 + + gt=np.eye(2)[np.array(gt)].astype(np.int8) + + return (pos,mat.astype(np.int16),ref,allele,gt) + +if __name__ == '__main__': + t=time.time() + + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--rate", help="Learning rate",type=float) + parser.add_argument("-i", "--iterations", help="Training iterations",type=int) + parser.add_argument("-s", "--size", help="Batch size",type=int) + parser.add_argument("-train", "--train", help="Train path") + parser.add_argument("-model", "--model", help="Model output path") + parser.add_argument("-cpu", "--cpu", help="CPUs",type=int) + parser.add_argument("-rt_path", "--retrain_path", help="Retrain saved model",type=str) + parser.add_argument('-sparse','--sparse', help='Stores features as sparse matrices', default=False, action='store_true') + parser.add_argument('-val','--validation', help='Perform validation', default=True, action='store_false') + + args = parser.parse_args() + + + os.makedirs(args.model, exist_ok=True) + + in_dict={'cpu':args.cpu,'rate':args.rate, 'iters':args.iterations, 'size':args.size,\ + 'train_path':args.train, 'model':args.model, 'val':args.validation,'rt_path':args.retrain_path,\ + 'sparse':args.sparse} + + train_SNP_model(in_dict) + + elapsed=time.time()-t + print ('Total Time Elapsed: %.2f seconds' %elapsed) \ No newline at end of file diff --git a/misc/training/model_run_indels.py b/misc/training/model_run_indels.py new file mode 100644 index 0000000..a7f2233 --- /dev/null +++ b/misc/training/model_run_indels.py @@ -0,0 +1,260 @@ +from warnings import simplefilter +simplefilter(action='ignore', category=FutureWarning) + +import time, os, copy, argparse, subprocess, glob, re, datetime +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import numpy as np +import multiprocessing as mp 
+import tensorflow as tf +from model_architect_indel import * + +if type(tf.contrib) != type(tf): tf.contrib._warning = None + + +config = tf.ConfigProto(device_count={"CPU": 32}) +#config = tf.ConfigProto() +config.gpu_options.allow_growth = True + + +rev_gt_map={0:'hom-ref', 1:'hom-alt', 2:'het-ref', 3:'het-alt'} +rev_base_map={0:'A',1:'G',2:'T',3:'C',4:'-'} + +class Genome: + def __init__(self, name, path): + self.name = name + self.path = path + self.training_data = {} + self.testing_data = None + self.chunk_list=[] + + def add_data(self, pool, val): + print('%s: Starting reading genome %s.' %(str(datetime.datetime.now()), self.name), flush=True) + train_list=glob.glob(os.path.join(self.path,'train*')) + chunk_list=[re.findall('pileups.(\d+)', x)[0] for x in train_list] + + print('\nProgress\n\nTotal: %s\nDone : ' %('.'*len(chunk_list)),end='', flush=True) + + for chunk, path in zip(chunk_list, train_list): + self.chunk_list.append(chunk) + x_train, train_gt= get_data(path,pool) + + self.training_data[chunk]=(x_train, train_gt) + + print('.',end='',flush=True) + + if val: + self.testing_data = get_data(os.path.join(self.path,'test.pileups'),pool) + + print('\n\n%s: Finished reading genome %s.\n' %(str(datetime.datetime.now()), self.name), flush=True) + +def train_indel_model(params): + + + tf.reset_default_graph() + cpu=params['cpu'] + params['val']=True + dims=[5,41,5] + n_input=dims + + pool = mp.Pool(processes=cpu) + + print('%s: Starting reading pileups.' %str(datetime.datetime.now()),flush=True) + + genomes_list=[] + with open(params['train_path'],'r') as file: + for line in file: + x=line.rstrip('\n').split(',') + + current_genome=Genome(*x) + current_genome.add_data(pool, params['val']) + + genomes_list.append(current_genome) + + pool.close() + pool.join() + + print('\n%s: Finished reading pileups.' 
%str(datetime.datetime.now()),flush=True) + + training_iters, learning_rate, batch_size= params['iters'],\ + params['rate'], params['size'] + + training_iters, learning_rate, batch_size= params['iters'],\ + params['rate'], params['size'] + + weights,biases,tensors=get_tensors([5,128,2],learning_rate) + (x0, x1,x2,gt, accuracy, cost, optimizer, prob, rate)=tensors + + + init = tf.global_variables_initializer() + saver = tf.train.Saver(max_to_keep=10000) + + + with tf.Session(config=config) as sess: + sess.run(init) + sess.run(tf.local_variables_initializer()) + if params['rt_path']: + print('%s: Retraining model:> %s' %(str(datetime.datetime.now()),params['rt_path'])) + saver.restore(sess, params['rt_path']) + + stats,v_stats=[],[] + + count=0 + + save_num=1 + + for k in range(training_iters): + t=time.time() + print('\n'+100*'-'+'\n',flush=True) + print('%s: Starting epoch #: %d\n' %(str(datetime.datetime.now()),k),flush=True) + + for genome in genomes_list: + print('%s: Training on genome %s \n' %(str(datetime.datetime.now()), genome.name), flush=True) + + chunk_list=genome.chunk_list + print('Progress\n\nTotal: %s\nDone : ' %('.'*len(chunk_list)),end='', flush=True) + + for chunk in chunk_list: + training_loss, total_train_data, train_acc = 0, 0, 0 + + x_train, train_gt=genome.training_data[chunk] + + for batch in range(len(x_train)//batch_size): + batch_x = x_train[batch*batch_size:min((batch+1)*batch_size, len(x_train))] + batch_gt = train_gt[batch*batch_size:min((batch+1)*batch_size, len(train_gt))] + + opt, loss, batch_acc = sess.run([optimizer,cost,accuracy], feed_dict={x2: batch_x, gt:batch_gt, rate:0.5}) + + training_loss+=loss*len(batch_x) + total_train_data+=len(batch_x) + train_acc+=batch_acc + + + + print('.',end='',flush=True) + + training_loss=training_loss/total_train_data + train_acc=train_acc/total_train_data + + print('\n\nGenome: %s Training Loss: %.4f Training Accuracy: %.4f\n' %(genome.name, training_loss, train_acc), flush=True) + + if params['val']: + print(50*'*'+'\n', flush=True) + print('%s: Performing Validation\n' %str(datetime.datetime.now()), flush=True) + + for val_genome in genomes_list: + vx_test, vtest_gt= val_genome.testing_data + + v_loss,v_acc,total_test_data=0,0,0 + + for batch in range(int(np.ceil(len(vx_test)/(batch_size)))): + vbatch_x = vx_test[batch*batch_size:min((batch+1)*batch_size,len(vx_test))] + vbatch_gt = vtest_gt[batch*batch_size:min((batch+1)*batch_size,len(vtest_gt))] + + batch_loss,batch_acc= sess.run([cost, accuracy], feed_dict={x2: vbatch_x, gt:vbatch_gt, rate:0.0}) + v_loss+=batch_loss*len(vbatch_x) + v_acc+=batch_acc + total_test_data+=len(vbatch_x) + + print('Genome: %s Validation Loss: %.4f Validation Accuracy: %.4f' %(val_genome.name, v_loss/total_test_data, v_acc/total_test_data), flush=True) + + print(50*'*'+'\n', flush=True) + saver.save(sess, save_path=os.path.join(params['model'],'model'),global_step=save_num) + elapsed=time.time()-t + save_num+=1 + + print ('%s: Time Taken for Iteration %d: %.2f seconds\n' %(str(datetime.datetime.now()),k,elapsed), flush=True) + + + + +def get_data(fname,pool): + t=time.time() + l=os.stat(fname).st_size + dims=(5,128,2) + + rec_size=12+dims[0]*dims[1]*dims[2]*4*3 + my_array=[(fname,x,'train',dims) for x in range(0,l,1000*rec_size)] + + results = pool.map(read_pileups_from_file, my_array) + + mat=np.vstack([res[0] for res in results]) + gt=np.vstack([res[1] for res in results]) + + return mat,gt + +def read_pileups_from_file(options): + fname,n,mode,dims=options + file= open(fname,'r') + 
file.seek(n) + + mat_0,mat_1,mat_2=[],[],[] + pos=[] + gt=[] + + i=0 + + while i<1000: + i+=1 + c=file.read(12) + if not c: + break + + gt.append(int(c[11])) + + m=file.read(dims[0]*dims[1]*dims[2]*4) + p_mat_0=np.array([int(float(m[4*i:4*i+4])) for i in range(dims[0]*dims[1]*dims[2])]).reshape((dims[0],dims[1],dims[2])) + mat_0.append(p_mat_0) + + m=file.read(dims[0]*dims[1]*dims[2]*4) + p_mat_1=np.array([int(float(m[4*i:4*i+4])) for i in range(dims[0]*dims[1]*dims[2])]).reshape((dims[0],dims[1],dims[2])) + mat_1.append(p_mat_1) + + m=file.read(dims[0]*dims[1]*dims[2]*4) + p_mat_2=np.array([int(float(m[4*i:4*i+4])) for i in range(dims[0]*dims[1]*dims[2])]).reshape((dims[0],dims[1],dims[2])) + mat_2.append(p_mat_2) + + mat_0=norm(np.array(mat_0)) + mat_1=norm(np.array(mat_1)) + mat_2=norm(np.array(mat_2)) + + mat=np.hstack([mat_0, mat_1, mat_2]) + + gt=np.array(gt) + gt=np.eye(4)[gt].astype(np.int8) + + return (mat,gt) + + +def norm(batch_x0): + batch_x0=batch_x0.astype(np.float32) + batch_x0[:,:,:,0]=batch_x0[:,:,:,0]/(np.sum(batch_x0[:,:,:,0],axis=1)[:,np.newaxis,:])-batch_x0[:,:,:,1] + return batch_x0 + +if __name__ == '__main__': + t=time.time() + + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--rate", help="Learning rate",type=float) + parser.add_argument("-i", "--iterations", help="Training iterations",type=int) + parser.add_argument("-s", "--size", help="Batch size",type=int) + parser.add_argument("-train", "--train", help="Train path") + parser.add_argument("-model", "--model", help="Model output path") + parser.add_argument("-cpu", "--cpu", help="CPUs",type=int) + parser.add_argument("-rt_path", "--retrain_path", help="Retrain saved model",type=str) + parser.add_argument('-sparse','--sparse', help='Stores features as sparse matrices', default=False, action='store_true') + parser.add_argument('-val','--validation', help='Perform validation', default=True, action='store_false') + + args = parser.parse_args() + + print('aasas',flush=True) + os.makedirs(args.model, exist_ok=True) + + in_dict={'cpu':args.cpu,'rate':args.rate, 'iters':args.iterations, 'size':args.size,\ + 'train_path':args.train, 'model':args.model, 'val':args.validation,'rt_path':args.retrain_path,\ + 'sparse':args.sparse} + + train_indel_model(in_dict) + + elapsed=time.time()-t + print ('Total Time Elapsed: %.2f seconds' %elapsed) \ No newline at end of file
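
Note on the fixed-width training-record layout: the field widths below are read off `read_pileups_from_file` in model_run.py (11-char position, one char each for gt/allele1/allele2/ref, then 5*41*5 pileup counts at 6 ASCII characters apiece); `parse_snp_record` and the file name in the usage comment are hypothetical illustrations, not helpers provided by these scripts, so treat this as a minimal sketch of the assumed format rather than part of the training code.

```python
import numpy as np

# Assumed record layout, mirroring read_pileups_from_file in model_run.py:
#   15-char header -> 11-char position, 1-char gt, 1-char allele1, 1-char allele2, 1-char ref
#   body           -> 5*41*5 pileup values, each serialized as 6 ASCII characters
DIMS = (5, 41, 5)
HEADER = 15
BODY = DIMS[0] * DIMS[1] * DIMS[2] * 6
REC_SIZE = HEADER + BODY

def parse_snp_record(rec):
    # rec is one REC_SIZE-character slice of a train/test pileup file
    pos = int(rec[:11])
    gt, allele1, allele2, ref = (int(rec[i]) for i in (11, 12, 13, 14))
    body = rec[HEADER:]
    mat = np.array([int(body[6 * i:6 * i + 6]) for i in range(DIMS[0] * DIMS[1] * DIMS[2])],
                   dtype=np.int16).reshape(DIMS)
    return pos, gt, allele1, allele2, ref, mat

# Example usage (hypothetical file name):
# with open('train.pileups.1') as f:
#     pos, gt, a1, a2, ref, mat = parse_snp_record(f.read(REC_SIZE))
```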