Fixed slow data parsing and cleaned up code
yazhinia committed Jan 2, 2025
1 parent 08c1fb2 commit c31b19d
Showing 1 changed file with 3 additions and 203 deletions.
206 changes: 3 additions & 203 deletions mcdevol/byol_model.py
@@ -19,11 +19,7 @@
from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler
import logging
from torch.cuda.amp import autocast, GradScaler
# from lightning import Fabric

# TODO: speedup the process
# autocast and GradScaler helps # worked
# deepspeed using fabric didn't speed up with single device
class WarmUpLR(_LRScheduler):
def __init__(self, optimizer, total_iters, last_epoch=-1):
self.total_iters = max(1, total_iters)
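Note: the rest of WarmUpLR is collapsed in this diff. A warmup scheduler of this shape typically scales each base learning rate linearly over total_iters steps; the get_lr below is a minimal sketch under that assumption (class name and body are illustrative, not the file's actual implementation):

from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLRSketch(_LRScheduler):
    """Linearly ramp the learning rate from ~0 up to the base LR over total_iters steps."""
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = max(1, total_iters)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # fraction of warmup completed, capped at 1.0 once warmup is done
        warmup_frac = min(1.0, (self.last_epoch + 1) / self.total_iters)
        return [base_lr * warmup_frac for base_lr in self.base_lrs]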
@@ -243,33 +239,6 @@ def data_augment(
""" augment read counts """
rcounts_sampled = rcounts.clone().detach()

# if not self.multi_split:
# condition_short = (contigs_length > 4000) & (contigs_length <= 8000) # used for gs pooled assembly
# # condition_short = (contigs_length > 3000) & (contigs_length <= 8000) # used for pooled assembly
# condition_medium1 = (contigs_length > 8000) & (contigs_length <= 16000)
# condition_medium2 = (contigs_length > 16000) & (contigs_length <= 30000)
# condition_long = contigs_length > 30000

# # for multisplit binning
# else:
# condition_short = (contigs_length > 50000) & (contigs_length <= 100000)
# condition_medium1 = (contigs_length > 100000) & (contigs_length <= 500000)
# condition_medium2 = (contigs_length > 500000) & (contigs_length <= 800000)
# condition_long = contigs_length > 800000

# if condition_short.any():
# # always samples short contigs with 0.9 fraction
# rcounts[condition_short] = drawsample_frombinomial(rcounts[condition_short], 0.9)
# if condition_medium1.any():
# # always samples medium contigs with 0.7 fraction
# rcounts[condition_medium1] = drawsample_frombinomial(rcounts[condition_medium1], 0.7)
# if condition_medium2.any():
# # always samples medium contigs with 0.6 fraction
# rcounts[condition_medium2] = drawsample_frombinomial(rcounts[condition_medium2], 0.6)
# if condition_long.any():
# # sample longer contigs based on fraction_pi passed to the function
# rcounts[condition_long] = drawsample_frombinomial(rcounts[condition_long], fraction_pi)

# Define contig length conditions
if not self.multi_split:
length_conditions = {
@@ -294,16 +263,6 @@
'long': fraction_pi
}

# # Adjust sampling fractions based on the value of fraction_pi
# if fraction_pi == 0.9:
# sampling_fractions = {key: 0.9 for key in sampling_fractions} # All fractions set to 0.9
# elif fraction_pi == 0.7:
# sampling_fractions.update({'medium1': 0.7, 'medium2': 0.7, 'long': 0.7}) # Set all except 'short' to 0.7
# elif fraction_pi == 0.6:
# sampling_fractions.update({'medium2': 0.6, 'long': 0.6}) # Set 'medium2' and 'long' to 0.6
# elif fraction_pi == 0.5:
# sampling_fractions['long'] = 0.5 # Only 'long' is set to 0.5

# Apply sampling augmentation based on conditions
for key, condition in length_conditions.items():
if condition.any():
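Note: the loop body is collapsed here; judging from the deleted per-length branches, it calls drawsample_frombinomial with the bucket's sampling fraction. That helper is defined elsewhere in the module and is not shown in this diff; a minimal sketch of binomial subsampling of integer read counts, under that assumption (names suffixed _sketch are illustrative):

import torch

def drawsample_frombinomial_sketch(counts: torch.Tensor, fraction: float) -> torch.Tensor:
    """Keep each read independently with probability `fraction` (counts assumed non-negative integers)."""
    binom = torch.distributions.Binomial(total_count=counts, probs=torch.tensor(fraction))
    return binom.sample()

# hypothetical use inside the loop, mirroring the deleted branches:
# rcounts_sampled[condition] = drawsample_frombinomial_sketch(rcounts_sampled[condition], sampling_fractions[key])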
@@ -457,128 +416,14 @@ def apply_augmentations(
kmeraug_target = kmeraug_target.cuda()
return augmented_online, augmented_target, kmeraug_online, kmeraug_target
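Note: downstream, the network consumes read counts and kmer embeddings concatenated along the feature axis, the torch.cat((reads, kmers), 1) pattern visible in the code below. A toy illustration with made-up shapes:

import torch

reads = torch.rand(8, 4)                   # toy: 8 contigs x 4 samples of (augmented) read counts
kmers = torch.rand(8, 136)                 # toy: 8 contigs x 136-dim kmer embedding
byol_view = torch.cat((reads, kmers), 1)   # one input view per contig, shape (8, 140)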

# def process_batches_withpairs(self,
# epoch: int,
# dataloader,
# training: bool,
# *args):
# """ process batches """

# epoch_losses = []
# epoch_loss = 0.0

# for in_data in dataloader:

# pairs = in_data
# pairindices_1 = pairs[:, 0].to(self.read_counts.device)
# pairindices_2 = pairs[:, 1].to(self.read_counts.device)

# if training:
# loss_array, latent_space, fraction_pi = args
# self.optimizer.zero_grad()

# else:
# loss_array, latent_space = args

# with autocast():
# if training:
# ### augmentation by fragmentation ###
# kmeraug_online, kmeraug_target = random.sample([self.kmeraug1, \
# self.kmeraug2, self.kmeraug3, self.kmeraug4, self.kmeraug5, self.kmeraug6],2)
# augmented_reads1 = self.data_augment(self.rawread_counts[pairindices_1], \
# self.contigs_length[pairindices_1], fraction_pi)
# augmented_reads2 = self.data_augment(self.rawread_counts[pairindices_2], \
# self.contigs_length[pairindices_2], fraction_pi)
# augmented_kmers1 = kmeraug_online[pairindices_1]
# augmented_kmers2 = kmeraug_target[pairindices_2]

# if self.usecuda:
# augmented_reads1 = augmented_reads1.cuda()
# augmented_reads2 = augmented_reads2.cuda()
# augmented_kmers1 = augmented_kmers1.cuda()
# augmented_kmers2 = augmented_kmers2.cuda()

# # rc_reads1 = torch.log(augmented_reads1.sum(axis=1))
# # rc_reads2 = torch.log(augmented_reads2.sum(axis=1))
# latent, loss = \
# self(torch.cat((augmented_reads1, augmented_kmers1), 1), \
# torch.cat((augmented_reads2, augmented_kmers2), 1), True)
# else:
# augmented_reads = self.read_counts[pairindices_1]
# augmented_kmers = self.kmer[pairindices_1]

# if self.usecuda:
# augmented_reads = augmented_reads.cuda()
# augmented_kmers = augmented_kmers.cuda()

# # rc_reads = torch.log(augmented_reads1.sum(axis=1))
# input1 = torch.cat((augmented_reads, augmented_kmers), 1)
# latent, loss = self(input1, input1, True)
# loss_array.append(loss.data.item())
# latent_space.append(latent.cpu().detach().numpy())

# if training:
# # loss.backward()
# # self.fabric.backward(loss)
# # optimizer.step()
# self.scaler.scale(loss).backward() # type: ignore
# self.scaler.step(self.optimizer) # type: ignore
# self.scaler.update() # type: ignore

# self.update_moving_average()

# epoch_loss += loss.detach().data.item()

# epoch_losses.extend([epoch_loss])
# self.logger.info(f'{epoch}: pair loss={epoch_loss}')
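Note: the deleted block above follows the standard torch.cuda.amp recipe (autocast for the forward pass, GradScaler for backward/step), which the retained training loop keeps. A minimal, self-contained sketch of that recipe with a toy model, assuming a CUDA device is available:

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(16, 8).cuda()                 # stand-in for the BYOL network
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

for _ in range(10):                             # stand-in for the batch loop
    x = torch.randn(32, 16, device='cuda')
    optimizer.zero_grad()
    with autocast():                            # forward pass in mixed precision
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()               # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                      # unscales gradients, then steps the optimizer
    scaler.update()                             # adapt the loss scale for the next step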

# def pair_train(
# self,
# epoch,
# dataloader_pairtrain,
# dataloader_pairval,
# fraction_pi,
# loss_train,
# loss_val,
# ):
# epoch_list = [120]
# if epoch in epoch_list:
# if self.usecuda:
# reads = self.read_counts[self.markercontigs].cuda()
# kmers = self.kmer[self.markercontigs].cuda()
# latent, loss = self(torch.cat((reads, kmers), 1), \
# torch.cat((reads, kmers), 1))
# print(loss, 'loss before', epoch)
# np.save(self.outdir+f'latent_before_ng{epoch}',latent.cpu().detach().numpy())
# for epoch_pair in range(50):
# self.train()
# latent_space_train = []

# # initialize target network
# self.initialize_target_network()
# self.process_batches_withpairs(epoch_pair, dataloader_pairtrain, \
# True, loss_train, latent_space_train, fraction_pi)
# self.eval()
# latent_space_val = []

# with torch.no_grad():
# self.process_batches_withpairs(epoch_pair, dataloader_pairval, \
# False, loss_val, latent_space_val)
# latent, loss = self(torch.cat((reads, kmers), 1), \
# torch.cat((reads, kmers), 1))
# print(loss, 'loss after', epoch)
# np.save(self.outdir+f'latent_after_ng{epoch}',latent.cpu().detach().numpy())
# # self.getlatent(name='after_scmg')



def trainepoch(
self,
nepochs: int,
dataloader_train,
dataloader_val,
batchsteps,
# dataloader_pairtrain,
# dataloader_pairval,
):
""" training epoch """

@@ -644,22 +489,7 @@ def trainmodel(
dataloader_val = DataLoader(dataset=self.dataset_val, \
batch_size=self.batch_size, drop_last=True, shuffle=True, \
num_workers=self.num_workers, pin_memory=self.cuda)

# # split read mapping dataloader
# dataloader_pairtrain = DataLoader(dataset=self.pairs_train, \
# batch_size= 4096, shuffle=True, drop_last=True, \
# num_workers=self.num_workers, pin_memory=self.cuda)
# dataloader_pairval = DataLoader(dataset=self.pairs_val, \
# batch_size= 4096, shuffle=True, drop_last=True, \
# num_workers=self.num_workers, pin_memory=self.cuda)


# dataloader_train, dataloader_val, \
# dataloader_pairtrain, dataloader_pairval = \
# self.fabric.setup_dataloaders(\
# dataloader_train, dataloader_val,\
# dataloader_pairtrain, dataloader_pairval)


warmup_epochs = 20
warmup_scheduler = WarmUpLR(self.optimizer, total_iters=warmup_epochs * len(dataloader_train))
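Note: how the warmup scheduler hands off to CosineAnnealingLR happens in the collapsed part of the training loop. A hypothetical per-step hand-off with toy numbers could look like this (assuming WarmUpLR as defined above):

import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR

model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
steps_per_epoch, warmup_epochs, nepochs = 100, 20, 300
warmup = WarmUpLR(optimizer, total_iters=warmup_epochs * steps_per_epoch)
cosine = CosineAnnealingLR(optimizer, T_max=(nepochs - warmup_epochs) * steps_per_epoch)

for epoch in range(nepochs):
    for _ in range(steps_per_epoch):
        optimizer.step()                                        # forward/backward elided
        (warmup if epoch < warmup_epochs else cosine).step()    # linear ramp, then cosine decay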

@@ -722,14 +552,13 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
level=logging.INFO, datefmt='%d-%b-%y %H:%M:%S',
filename=outdir + 'byol_training.log', filemode='w')
logger = logging.getLogger()
# logger.propagate = False
logger.info('BYOL Training started')
base_names = [''] + [f'{i}' for i in range(1, 7)]
logger.info(f'abundance matrix shape: {abundance_matrix.shape}')
logger.info(f'contig length shape: {contig_length.size}')
logger.info(f'contig names shape: {contig_names.size}')

# filter contigs with low abundances
nonzeroindices = np.nonzero(abundance_matrix.sum(axis=1)>1.5)[0]
print(len(np.nonzero(abundance_matrix.sum(axis=1)==0.0)[0]), 'contigs with zero total counts')
if len(nonzeroindices) < contig_length.size:
@@ -752,7 +581,6 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
kmerdata_tmp = np.load(arg_name, allow_pickle=True).astype(np.float32)
kmer_data[keyname+name] = kmerdata_tmp[nonzeroindices]

# np.save(os.path.join(outdir, 'abundance_matrix.npy'), abundance_matrix)
byol = BYOLmodel(abundance_matrix, kmer_data, contig_length, outdir, logger, multi_split, ncpus)
byol.trainmodel()
latent = byol.getlatent()
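Note: the filtering earlier in run() keeps contigs whose summed abundance exceeds 1.5; the body of the `if` is collapsed in this diff, but it presumably subsets the companion arrays so they stay aligned. A toy sketch under that assumption:

import numpy as np

abundance_matrix = np.random.rand(100, 4).astype(np.float32)   # toy stand-ins
contig_length = np.random.randint(2000, 50000, size=100)
contig_names = np.array([f'contig_{i}' for i in range(100)])

nonzeroindices = np.nonzero(abundance_matrix.sum(axis=1) > 1.5)[0]
if len(nonzeroindices) < contig_length.size:
    # drop low-abundance contigs and keep length/name arrays aligned with the matrix
    abundance_matrix = abundance_matrix[nonzeroindices]
    contig_length = contig_length[nonzeroindices]
    contig_names = contig_names[nonzeroindices]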
@@ -781,10 +609,6 @@ def main() -> None:
help="length of contigs in bp", required=True)
parser.add_argument("--names", type=str, \
help="ids of contigs", required=True)
# parser.add_argument("--pairlinks", type=str, \
# help="provide pair links array", required=True)
# parser.add_argument("--otuids", type=str, \
# help="otuids of contigs", required=True)
parser.add_argument("--kmer", type=str, \
help='kmer embedding', required=True)
parser.add_argument("--kmeraug1", type=str, \
@@ -799,8 +623,6 @@
help='kmer embedding augment 5', required=True)
parser.add_argument("--kmeraug6", type=str, \
help='kmer embedding augment 6', required=True)
# parser.add_argument("--marker", type=str, \
# help="marker genes hit", required=True)
parser.add_argument("--outdir", type=str, \
help="output directory", required=True)
parser.add_argument("--nlatent", type=int, \
@@ -813,7 +635,6 @@
args.reads = np.load(args.reads, allow_pickle=True)['arr_0']
args.length = np.load(args.length, allow_pickle=True)['arr_0']
args.names = np.load(args.names, allow_pickle=True)['arr_0']
# args.pairlinks = np.load(args.pairlinks, allow_pickle='True')
args.kmer = np.load(args.kmer, allow_pickle=True).astype(np.float32)
args.kmeraug1 = np.load(args.kmeraug1, allow_pickle=True).astype(np.float32)
args.kmeraug2 = np.load(args.kmeraug2, allow_pickle=True).astype(np.float32)
@@ -828,25 +649,6 @@
for name in base_names:
kmer_data[name] = getattr(args, name)
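Note: main() expects --reads, --length and --names as .npz archives read back via ['arr_0'] (i.e. written with np.savez), and the kmer embeddings as .npy float arrays. A hypothetical way to produce compatible inputs with toy data; file names, shapes and values are placeholders:

import numpy as np

read_counts = np.random.poisson(5.0, size=(100, 4)).astype(np.float32)   # toy data
contig_lengths = np.random.randint(2000, 50000, size=100)
contig_names = np.array([f'contig_{i}' for i in range(100)])
kmer_embedding = np.random.rand(100, 136).astype(np.float32)

np.savez('reads.npz', read_counts)     # loaded back as np.load('reads.npz')['arr_0']
np.savez('length.npz', contig_lengths)
np.savez('names.npz', contig_names)
np.save('kmer.npy', kmer_embedding)    # loaded back with np.load(..., allow_pickle=True)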

# scale kmer input
# first attempt * 512 1-3
# second attempt * 100 1-3
# third attempt * 10 1-3
# four attempt * 150 1-3
# five attempt * 200 1-3
# args.kmer = args.kmer * 200
# args.kmeraug1 = args.kmeraug1 * 200
# args.kmeraug2 = args.kmeraug2 * 200

# args.marker = pd.read_csv(args.marker, header=None, sep='\t')
# names_indices = {name: i for i, name in enumerate(args.names)}
# args.marker[2] = args.marker[0].map(names_indices)
# del names_indices
# # remove contigs having scmg but shorter than threshold length
# args.marker = args.marker.dropna(subset=[2])
# args.marker[2] = args.marker[2].astype('int')
# args.marker = dict(args.marker.groupby(1)[2].apply(list))

args.outdir = os.path.join(args.outdir, '')

try:
@@ -863,8 +665,6 @@

byol = BYOLmodel(args.reads, kmer_data, args.contig_length, args.outdir, args.logger, False)

# total_params = sum(p.numel() for p in byol.parameters() if p.requires_grad)

byol.trainmodel()
# byol.testmodel()
latent = byol.getlatent()
