diff --git a/mcdevol/byol_model.py b/mcdevol/byol_model.py
index 53d59d4..6a0d387 100644
--- a/mcdevol/byol_model.py
+++ b/mcdevol/byol_model.py
@@ -19,11 +19,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler
 import logging
 from torch.cuda.amp import autocast, GradScaler
 
-# from lightning import Fabric
-# TODO: speedup the process
-# autocast and GradScaler helps # worked
-# deepseed using fabric didn't speed up with single device
 
 class WarmUpLR(_LRScheduler):
     def __init__(self, optimizer, total_iters, last_epoch=-1):
         self.total_iters = max(1, total_iters)
@@ -243,33 +239,6 @@ def data_augment(
         """ augment read counts """
         rcounts_sampled = rcounts.clone().detach()
 
-        # if not self.multi_split:
-        #     condition_short = (contigs_length > 4000) & (contigs_length <= 8000) # used for gs pooled assembly
-        #     # condition_short = (contigs_length > 3000) & (contigs_length <= 8000) # used for pooled assembly
-        #     condition_medium1 = (contigs_length > 8000) & (contigs_length <= 16000)
-        #     condition_medium2 = (contigs_length > 16000) & (contigs_length <= 30000)
-        #     condition_long = contigs_length > 30000
-
-        # # for multisplit binning
-        # else:
-        #     condition_short = (contigs_length > 50000) & (contigs_length <= 100000)
-        #     condition_medium1 = (contigs_length > 100000) & (contigs_length <= 500000)
-        #     condition_medium2 = (contigs_length > 500000) & (contigs_length <= 800000)
-        #     condition_long = contigs_length > 800000
-
-        # if condition_short.any():
-        #     # always samples short contigs with 0.9 fraction
-        #     rcounts[condition_short] = drawsample_frombinomial(rcounts[condition_short], 0.9)
-        # if condition_medium1.any():
-        #     # always samples medium contigs with 0.7 fraction
-        #     rcounts[condition_medium1] = drawsample_frombinomial(rcounts[condition_medium1], 0.7)
-        # if condition_medium2.any():
-        #     # always samples medium contigs with 0.6 fraction
-        #     rcounts[condition_medium2] = drawsample_frombinomial(rcounts[condition_medium2], 0.6)
-        # if condition_long.any():
-        #     # sample longer contigs based on fraction_pi passed to the function
-        #     rcounts[condition_long] = drawsample_frombinomial(rcounts[condition_long], fraction_pi)
-
         # Define contig length conditions
         if not self.multi_split:
             length_conditions = {
@@ -294,16 +263,6 @@ def data_augment(
             'long': fraction_pi
         }
 
-        # # Adjust sampling fractions based on the value of fraction_pi
-        # if fraction_pi == 0.9:
-        #     sampling_fractions = {key: 0.9 for key in sampling_fractions} # All fractions set to 0.9
-        # elif fraction_pi == 0.7:
-        #     sampling_fractions.update({'medium1': 0.7, 'medium2': 0.7, 'long': 0.7}) # Set all except 'short' to 0.7
-        # elif fraction_pi == 0.6:
-        #     sampling_fractions.update({'medium2': 0.6, 'long': 0.6}) # Set 'medium2' and 'long' to 0.6
-        # elif fraction_pi == 0.5:
-        #     sampling_fractions['long'] = 0.5 # Only 'long' is set to 0.5
-
         # Apply sampling augmentation based on conditions
         for key, condition in length_conditions.items():
             if condition.any():
@@ -457,119 +416,7 @@ def apply_augmentations(
             kmeraug_target = kmeraug_target.cuda()
 
         return augmented_online, augmented_target, kmeraug_online, kmeraug_target
-    # def process_batches_withpairs(self,
-    #     epoch: int,
-    #     dataloader,
-    #     training: bool,
-    #     *args):
-    #     """ process batches """
-
-    #     epoch_losses = []
-    #     epoch_loss = 0.0
-
-    #     for in_data in dataloader:
-
-    #         pairs = in_data
-    #         pairindices_1 = pairs[:, 0].to(self.read_counts.device)
-    #         pairindices_2 = pairs[:, 1].to(self.read_counts.device)
-
-    #         if training:
-    #             loss_array, latent_space, fraction_pi = args
-    #             self.optimizer.zero_grad()
-
-    #         else:
-    #             loss_array, latent_space = args
-
-    #         with autocast():
-    #             if training:
-    #                 ### augmentation by fragmentation ###
-    #                 kmeraug_online, kmeraug_target = random.sample([self.kmeraug1, \
-    #                     self.kmeraug2, self.kmeraug3, self.kmeraug4, self.kmeraug5, self.kmeraug6],2)
-    #                 augmented_reads1 = self.data_augment(self.rawread_counts[pairindices_1], \
-    #                     self.contigs_length[pairindices_1], fraction_pi)
-    #                 augmented_reads2 = self.data_augment(self.rawread_counts[pairindices_2], \
-    #                     self.contigs_length[pairindices_2], fraction_pi)
-    #                 augmented_kmers1 = kmeraug_online[pairindices_1]
-    #                 augmented_kmers2 = kmeraug_target[pairindices_2]
-
-    #                 if self.usecuda:
-    #                     augmented_reads1 = augmented_reads1.cuda()
-    #                     augmented_reads2 = augmented_reads2.cuda()
-    #                     augmented_kmers1 = augmented_kmers1.cuda()
-    #                     augmented_kmers2 = augmented_kmers2.cuda()
-
-    #                 # rc_reads1 = torch.log(augmented_reads1.sum(axis=1))
-    #                 # rc_reads2 = torch.log(augmented_reads2.sum(axis=1))
-    #                 latent, loss = \
-    #                     self(torch.cat((augmented_reads1, augmented_kmers1), 1), \
-    #                     torch.cat((augmented_reads2, augmented_kmers2), 1), True)
-    #             else:
-    #                 augmented_reads = self.read_counts[pairindices_1]
-    #                 augmented_kmers = self.kmer[pairindices_1]
-
-    #                 if self.usecuda:
-    #                     augmented_reads = augmented_reads.cuda()
-    #                     augmented_kmers = augmented_kmers.cuda()
-
-    #                 # rc_reads = torch.log(augmented_reads1.sum(axis=1))
-    #                 input1 = torch.cat((augmented_reads, augmented_kmers), 1)
-    #                 latent, loss = self(input1, input1, True)
-    #             loss_array.append(loss.data.item())
-    #             latent_space.append(latent.cpu().detach().numpy())
-
-    #         if training:
-    #             # loss.backward()
-    #             # self.fabric.backward(loss)
-    #             # optimizer.step()
-    #             self.scaler.scale(loss).backward() # type: ignore
-    #             self.scaler.step(self.optimizer) # type: ignore
-    #             self.scaler.update() # type: ignore
-
-    #             self.update_moving_average()
-
-    #         epoch_loss += loss.detach().data.item()
-
-    #     epoch_losses.extend([epoch_loss])
-    #     self.logger.info(f'{epoch}: pair loss={epoch_loss}')
-
-    # def pair_train(
-    #     self,
-    #     epoch,
-    #     dataloader_pairtrain,
-    #     dataloader_pairval,
-    #     fraction_pi,
-    #     loss_train,
-    #     loss_val,
-    #     ):
-    #     epoch_list = [120]
-    #     if epoch in epoch_list:
-    #         if self.usecuda:
-    #             reads = self.read_counts[self.markercontigs].cuda()
-    #             kmers = self.kmer[self.markercontigs].cuda()
-    #         latent, loss = self(torch.cat((reads, kmers), 1), \
-    #             torch.cat((reads, kmers), 1))
-    #         print(loss, 'loss before', epoch)
-    #         np.save(self.outdir+f'latent_before_ng{epoch}',latent.cpu().detach().numpy())
-    #         for epoch_pair in range(50):
-    #             self.train()
-    #             latent_space_train = []
-
-    #             # initialize target network
-    #             self.initialize_target_network()
-    #             self.process_batches_withpairs(epoch_pair, dataloader_pairtrain, \
-    #                 True, loss_train, latent_space_train, fraction_pi)
-    #             self.eval()
-    #             latent_space_val = []
-
-    #             with torch.no_grad():
-    #                 self.process_batches_withpairs(epoch_pair, dataloader_pairval, \
-    #                     False, loss_val, latent_space_val)
-    #         latent, loss = self(torch.cat((reads, kmers), 1), \
-    #             torch.cat((reads, kmers), 1))
-    #         print(loss, 'loss after', epoch)
-    #         np.save(self.outdir+f'latent_after_ng{epoch}',latent.cpu().detach().numpy())
-    #         # self.getlatent(name='after_scmg')
-
+
 
     def trainepoch(
         self,
@@ -577,8 +424,6 @@ def trainepoch(
         nepochs,
         dataloader_train,
         dataloader_val,
         batchsteps,
-        # dataloader_pairtrain,
-        # dataloader_pairval,
         ):
         """ training epoch """
@@ -644,22 +489,7 @@ def trainmodel(
         dataloader_val = DataLoader(dataset=self.dataset_val, \
             batch_size=self.batch_size, drop_last=True, shuffle=True, \
             num_workers=self.num_workers, pin_memory=self.cuda)
-
-        # # split read mapping dataloader
-        # dataloader_pairtrain = DataLoader(dataset=self.pairs_train, \
-        #     batch_size= 4096, shuffle=True, drop_last=True, \
-        #     num_workers=self.num_workers, pin_memory=self.cuda)
-        # dataloader_pairval = DataLoader(dataset=self.pairs_val, \
-        #     batch_size= 4096, shuffle=True, drop_last=True, \
-        #     num_workers=self.num_workers, pin_memory=self.cuda)
-
-
-        # dataloader_train, dataloader_val, \
-        #     dataloader_pairtrain, dataloader_pairval = \
-        #     self.fabric.setup_dataloaders(\
-        #     dataloader_train, dataloader_val,\
-        #     dataloader_pairtrain, dataloader_pairval)
-
+
         warmup_epochs = 20
         warmup_scheduler = WarmUpLR(self.optimizer, total_iters=warmup_epochs * len(dataloader_train))
 
@@ -722,14 +552,13 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
         level=logging.INFO, datefmt='%d-%b-%y %H:%M:%S',
         filename=outdir + 'byol_training.log', filemode='w')
     logger = logging.getLogger()
-    # logger.propagate = False
     logger.info('BYOL Training started')
 
     base_names = [''] + [f'{i}' for i in range(1, 7)]
    logger.info(f'abundance matrix shape: {abundance_matrix.shape}')
     logger.info(f'contig length shape: {contig_length.size}')
     logger.info(f'contig names shape: {contig_names.size}')
-    # filter contigs with low abundances 
+    # filter contigs with low abundances
     nonzeroindices = np.nonzero(abundance_matrix.sum(axis=1)>1.5)[0]
     print(len(np.nonzero(abundance_matrix.sum(axis=1)==0.0)[0]), 'contigs with zero total counts')
     if len(nonzeroindices) < contig_length.size:
@@ -752,7 +581,6 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
         kmerdata_tmp = np.load(arg_name, allow_pickle=True).astype(np.float32)
         kmer_data[keyname+name] = kmerdata_tmp[nonzeroindices]
 
-    # np.save(os.path.join(outdir, 'abundance_matrix.npy'), abundance_matrix)
     byol = BYOLmodel(abundance_matrix, kmer_data, contig_length, outdir, logger, multi_split, ncpus)
     byol.trainmodel()
     latent = byol.getlatent()
@@ -781,10 +609,6 @@ def main() -> None:
         help="length of contigs in bp", required=True)
     parser.add_argument("--names", type=str, \
         help="ids of contigs", required=True)
-    # parser.add_argument("--pairlinks", type=str, \
-    #     help="provide pair links array", required=True)
-    # parser.add_argument("--otuids", type=str, \
-    #     help="otuids of contigs", required=True)
     parser.add_argument("--kmer", type=str, \
         help='kmer embedding', required=True)
     parser.add_argument("--kmeraug1", type=str, \
@@ -799,8 +623,6 @@ def main() -> None:
         help='kmer embedding augment 5', required=True)
     parser.add_argument("--kmeraug6", type=str, \
         help='kmer embedding augment 6', required=True)
-    # parser.add_argument("--marker", type=str, \
-    #     help="marker genes hit", required=True)
     parser.add_argument("--outdir", type=str, \
         help="output directory", required=True)
     parser.add_argument("--nlatent", type=int, \
@@ -813,7 +635,6 @@ def main() -> None:
     args.reads = np.load(args.reads, allow_pickle=True)['arr_0']
     args.length = np.load(args.length, allow_pickle=True)['arr_0']
     args.names = np.load(args.names, allow_pickle=True)['arr_0']
-    # args.pairlinks = np.load(args.pairlinks, allow_pickle='True')
     args.kmer = np.load(args.kmer, allow_pickle=True).astype(np.float32)
     args.kmeraug1 = np.load(args.kmeraug1, allow_pickle=True).astype(np.float32)
     args.kmeraug2 = np.load(args.kmeraug2, allow_pickle=True).astype(np.float32)
@@ -828,25 +649,6 @@ def main() -> None:
     for name in base_names:
         kmer_data[name] = getattr(args, name)
 
-    # scale kmer input
-    # first attempt * 512 1-3
-    # second attempt * 100 1-3
-    # third attempt * 10 1-3
-    # four attempt * 150 1-3
-    # five attempt * 200 1-3
-    # args.kmer = args.kmer * 200
-    # args.kmeraug1 = args.kmeraug1 * 200
-    # args.kmeraug2 = args.kmeraug2 * 200
-
-    # args.marker = pd.read_csv(args.marker, header=None, sep='\t')
-    # names_indices = {name: i for i, name in enumerate(args.names)}
-    # args.marker[2] = args.marker[0].map(names_indices)
-    # del names_indices
-    # # remove contigs having scmg but shorter than threshold length
-    # args.marker = args.marker.dropna(subset=[2])
-    # args.marker[2] = args.marker[2].astype('int')
-    # args.marker = dict(args.marker.groupby(1)[2].apply(list))
-
     args.outdir = os.path.join(args.outdir, '')
 
     try:
@@ -863,8 +665,6 @@ def main() -> None:
 
     byol = BYOLmodel(args.reads, kmer_data, args.contig_length, args.outdir, args.logger, False)
 
-    # total_params = sum(p.numel() for p in byol.parameters() if p.requires_grad)
-
     byol.trainmodel()
     # byol.testmodel()
     latent = byol.getlatent()
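
After this cleanup, the only augmentation left in data_augment() is the table-driven binomial subsampling over contig-length bins. A minimal self-contained sketch of that path follows; the thresholds and the 0.9/0.7/0.6 fractions are taken from the commented-out branch this patch deletes (the kept dict bodies fall between hunks and are not shown in the diff), and drawsample_frombinomial is assumed to draw Binomial(count, fraction) per entry on integer read counts, so treat both as assumptions rather than the module's exact behaviour.

import torch

def drawsample_frombinomial(counts: torch.Tensor, fraction: float) -> torch.Tensor:
    # Resample each (integer-valued) read count as Binomial(count, fraction).
    return torch.distributions.Binomial(total_count=counts, probs=torch.tensor(fraction)).sample()

def augment_counts(rcounts: torch.Tensor, contigs_length: torch.Tensor, fraction_pi: float) -> torch.Tensor:
    rcounts_sampled = rcounts.clone().detach()
    # Non-multi-split bins; thresholds mirror the deleted comments.
    length_conditions = {
        'short':   (contigs_length > 4000)  & (contigs_length <= 8000),
        'medium1': (contigs_length > 8000)  & (contigs_length <= 16000),
        'medium2': (contigs_length > 16000) & (contigs_length <= 30000),
        'long':    contigs_length > 30000,
    }
    # Shorter contigs keep a larger fraction of reads; only 'long' follows fraction_pi.
    sampling_fractions = {'short': 0.9, 'medium1': 0.7, 'medium2': 0.6, 'long': fraction_pi}
    for key, condition in length_conditions.items():
        if condition.any():
            rcounts_sampled[condition] = drawsample_frombinomial(rcounts_sampled[condition], sampling_fractions[key])
    return rcounts_sampled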
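
The per-batch update the model still performs (and that the deleted process_batches_withpairs() followed as well) is a standard mixed-precision BYOL step: forward both views under autocast, backpropagate the scaled loss through GradScaler, then refresh the EMA target network. A sketch under those assumptions; the (latent, loss) return signature and update_moving_average() are taken from the deleted block, not re-verified against the current class.

import torch
from torch.cuda.amp import autocast, GradScaler

def training_step(model, optimizer, scaler: GradScaler, view_online, view_target):
    optimizer.zero_grad()
    with autocast():
        # model is assumed to return (latent, loss) for two augmented views.
        latent, loss = model(view_online, view_target, True)
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales gradients, then optimizer.step()
    scaler.update()                 # adapt the loss scale for the next step
    model.update_moving_average()   # EMA update of the target network (BYOL)
    return latent.detach(), loss.detach()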
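
trainmodel() keeps the 20-epoch warm-up via the WarmUpLR scheduler defined at the top of the file, next to the CosineAnnealingLR import. A usage sketch follows; only WarmUpLR.__init__ is visible in the diff, so the linear get_lr() body and the hand-off to cosine annealing are assumptions, and steps_per_epoch/nepochs are placeholder values.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler

class WarmUpLR(_LRScheduler):
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = max(1, total_iters)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Assumed behaviour: ramp each base lr up linearly over total_iters.
        return [base_lr * min(1.0, (self.last_epoch + 1) / self.total_iters)
                for base_lr in self.base_lrs]

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
warmup_epochs, steps_per_epoch, nepochs = 20, 100, 300   # placeholder sizes
warmup = WarmUpLR(optimizer, total_iters=warmup_epochs * steps_per_epoch)
cosine = None
for step in range(nepochs * steps_per_epoch):
    optimizer.step()                 # forward/backward elided in this sketch
    if step < warmup.total_iters:
        warmup.step()
    else:
        if cosine is None:           # hand over to cosine annealing after warm-up
            cosine = CosineAnnealingLR(optimizer,
                                       T_max=nepochs * steps_per_epoch - warmup.total_iters)
        cosine.step()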
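
Finally, run() keeps the low-abundance filter ahead of training: contigs whose abundance summed across samples is at or below 1.5 are dropped, and every per-contig array is reindexed with the same mask (the kmer arrays in the diff are subset with the same nonzeroindices). A small sketch of that filter:

import numpy as np

def filter_low_abundance(abundance_matrix, contig_length, contig_names, threshold=1.5):
    # Keep contigs whose total abundance across samples exceeds the threshold.
    keep = np.nonzero(abundance_matrix.sum(axis=1) > threshold)[0]
    return abundance_matrix[keep], contig_length[keep], contig_names[keep], keep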