Commit

Changed files from cluster

prasadvagdargi committed May 6, 2022
1 parent d7973ad commit e7bdf59
Showing 117 changed files with 241 additions and 254,934 deletions.
Binary file removed __pycache__/_config.cpython-37.pyc
Binary file removed __pycache__/barlow.cpython-37.pyc
Binary file removed __pycache__/barlow.cpython-38.pyc
Binary file removed __pycache__/barlow.cpython-39.pyc
Binary file removed __pycache__/barlow_utils.cpython-37.pyc
Binary file removed __pycache__/barlow_utils.cpython-38.pyc
Binary file removed __pycache__/models.cpython-37.pyc
Binary file removed __pycache__/models.cpython-38.pyc
Binary file removed __pycache__/models.cpython-39.pyc
Binary file removed __pycache__/t_dataset.cpython-37.pyc
Binary file removed __pycache__/t_dataset.cpython-38.pyc
Binary file removed __pycache__/t_dataset.cpython-39.pyc
Binary file removed __pycache__/train_translation.cpython-37.pyc
Binary file removed __pycache__/train_translation.cpython-38.pyc
Binary file removed __pycache__/train_translation.cpython-39.pyc
Binary file removed __pycache__/translation_dataset.cpython-37.pyc
Binary file removed __pycache__/translation_dataset.cpython-38.pyc
Binary file removed __pycache__/translation_dataset.cpython-39.pyc
Binary file removed __pycache__/translation_utils.cpython-37.pyc
Binary file removed __pycache__/translation_utils.cpython-38.pyc
Binary file removed __pycache__/translation_utils.cpython-39.pyc
40 changes: 34 additions & 6 deletions t_dataset.py
@@ -1,3 +1,4 @@
# edits: padding=True
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
@@ -20,40 +21,67 @@ def __init__(self,
split = "train"
else:
split = "test"
print('getting dataset')
self.dataset = load_dataset('wmt14', "de-en", split=split)
self.de_list = []
self.en_list = []
# self.tokenizer = tokenizer
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')
en_list_2 = []
#for k in range(100):#len(self.dataset)):
# n,i = self.dataset[k]
for n, i in enumerate(self.dataset):
en_list_2.append(i['translation']['en'].lower())

a1 = list(self.tokenizer(en_list_2, padding=True, return_tensors='pt')['input_ids'])
#print(n)
if n==500:
break
#print(len(en_list_2))
# print(max(en_list_2))
token_res = self.tokenizer(en_list_2, padding=True,max_length=512, return_tensors='pt', truncation=True)['input_ids']
a1 = list(token_res)
self.en_vocab, self.en_vocab_size = vocab(a1)
self.bert2id_dict = translation_utils.bert2id(self.en_vocab)
self.id2bert_dict = translation_utils.id2bert(self.en_vocab)


for n, i in enumerate(self.dataset):
#if len(i['translation']['de'])> 400:
# print(len(i['translation']['de']))

#elif len(i['translation']['en'])> 400:
# print(len(i['translation']['en']))
# print(i['translation']['en'])

#else:
# print(len(i['translation']['de']))
self.de_list.append(self.tokenizer(i['translation']['de'].lower(), padding=True, return_tensors='pt',max_length=512, truncation=True)["input_ids"])
self.en_list.append(self.tokenizer(i['translation']['en'].lower(), padding=True, return_tensors='pt',max_length=512, truncation=True)["input_ids"])
if n==500:
break

for i in self.dataset:
self.de_list.append(self.tokenizer(i['translation']['de'].lower(),
padding=True, return_tensors='pt')["input_ids"])
self.en_list.append(self.tokenizer(i['translation']['en'].lower(),
self.en_list.append(self.tokenizer(i['translation']['en'].lower(),
padding=True, return_tensors='pt')["input_ids"])

# en_list_id = []
# for i in self.dataset:
# en_list_id.append(i['translation']['en'].lower())
de_list_1 = []
for n,i in enumerate(self.dataset):
de_list_1.append(i['translation']['de'].lower())
if n==500:
break

a = list(self.tokenizer(de_list_1, padding=True, return_tensors='pt')['input_ids'])
a = list(self.tokenizer(de_list_1, padding=True, return_tensors='pt',max_length=512, truncation=True)['input_ids'])

en_list_1 = []
for n,i in enumerate(self.dataset):
en_list_1.append(i['translation']['en'].lower())
if n==500:
break

b = list(self.tokenizer(de_list_1, padding=True, return_tensors='pt')['input_ids'])
b = list(self.tokenizer(de_list_1, padding=True, max_length=512, return_tensors='pt', truncation=True)['input_ids'])
# en_vocab, self.en_vocab_size = vocab(b)
self.de_vocab, self.de_vocab_size = vocab(a)

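Note on the change above: the new max_length=512 / truncation=True arguments are what keep WMT14 sentences inside mBERT's 512-position limit. A minimal standalone sketch of their effect, assuming the same bert-base-multilingual-uncased tokenizer (the sample sentences are illustrative, not from the dataset):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')

sentences = ["ein kurzer satz", "ein sehr " + "langer " * 600 + "satz"]

# padding=True pads every row to the longest sequence in the batch, while
# truncation=True with max_length=512 clips anything longer, so the second
# sentence no longer overflows BERT's positional embedding table.
batch = tokenizer(sentences, padding=True, max_length=512,
                  truncation=True, return_tensors='pt')
print(batch['input_ids'].shape)  # torch.Size([2, 512]); the long sentence was clipped at 512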
157 changes: 157 additions & 0 deletions t_dataset2.py
@@ -0,0 +1,157 @@

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
# from _config import Config as config
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

import translation_utils
from translation_utils import vocab
import os


os.environ['TRANSFORMERS_OFFLINE'] = 'yes'
class Translation_dataset_t(Dataset):

def __init__(self,
train: bool = True):

if train:
split = "train"
else:
split = "test"
print('getting dataset')
self.dataset = load_dataset('wmt14', "de-en", split=split)
self.de_list = []
self.en_list = []
# self.tokenizer = tokenizer
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')
en_list_2 = []
#for k in range(100):#len(self.dataset)):
# n,i = self.dataset[k]
for n, i in enumerate(self.dataset):
en_list_2.append(i['translation']['en'].lower())
#print(n)
if n==500:
break
print(len(en_list_2))
# print(max(en_list_2))
print('tokenizing English sentences')
token_res = self.tokenizer(en_list_2, padding='max_length',max_length=512, return_tensors='pt', truncation=True)['input_ids']
a1 = list(token_res)
print('building English vocab')
self.en_vocab, self.en_vocab_size = vocab(a1)
self.bert2id_dict = translation_utils.bert2id(self.en_vocab)
self.id2bert_dict = translation_utils.id2bert(self.en_vocab)
print('vocab mappings built')


for n, i in enumerate(self.dataset):
#if len(i['translation']['de'])> 400:
# print(len(i['translation']['de']))

#elif len(i['translation']['en'])> 400:
# print(len(i['translation']['en']))
# print(i['translation']['en'])

#else:
# print(len(i['translation']['de']))
if len(i['translation']['de'].lower()) > 500:
continue  # skip pairs whose German side is over-long
elif len(i['translation']['en'].lower())>500:
continue  # skip pairs whose English side is over-long

self.de_list.append(self.tokenizer(i['translation']['de'].lower(), padding='max_length', return_tensors='pt',max_length=512, truncation=True)["input_ids"])
self.en_list.append(self.tokenizer(i['translation']['en'].lower(), padding='max_length', return_tensors='pt',max_length=512, truncation=True)["input_ids"])
# if n==500:
# break
'''
for i in self.dataset:
self.de_list.append(self.tokenizer(i['translation']['de'].lower(),
padding=True, return_tensors='pt')["input_ids"])
self.en_list.append(self.tokenizer(i['translation']['en'].lower(),
padding=True, return_tensors='pt')["input_ids"])
'''
# en_list_id = []
# for i in self.dataset:
# en_list_id.append(i['translation']['en'].lower())

de_list_1 = []
for n,i in enumerate(self.dataset):

if len(i['translation']['de'].lower()) > 500:
continue  # skip pairs whose German side is over-long
elif len(i['translation']['en'].lower())>500:
continue  # skip pairs whose English side is over-long
de_list_1.append(i['translation']['de'].lower())
#if n==500:
#break

a = list(self.tokenizer(de_list_1, padding='max_length', return_tensors='pt',max_length=512, truncation=True)['input_ids'])

en_list_1 = []
for n,i in enumerate(self.dataset):
en_list_1.append(i['translation']['en'].lower())
if n==500:
break

b = list(self.tokenizer(en_list_1, padding='max_length', max_length=512, return_tensors='pt', truncation=True)['input_ids'])  # English token ids for the commented-out vocab(b) below
# en_vocab, self.en_vocab_size = vocab(b)
self.de_vocab, self.de_vocab_size = vocab(a)


#should return the length of the dataset
def __len__(self):
return len(self.de_list)

#should return a particular example
def __getitem__(self, index):
src = self.de_list[index]
trg = self.en_list[index]

return {'src':src, 'trg':trg}



class MyCollate:
def __init__(self,
tokenizer,
bert2id_dict: dict):
self.tokenizer = tokenizer
self.pad_idx = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
self.bert2id_dict = bert2id_dict

def __call__(self, batch):

source = []
for i in batch:
source.append(i['src'].T)
#print(source[0].shape, source[1].shape)
source = pad_sequence(source, batch_first=False, padding_value=self.pad_idx)

target = []
for i in batch:
target.append(i['trg'].T)
target = pad_sequence(target, batch_first=False, padding_value = self.pad_idx)

target_inp = target.squeeze(-1)[:-1, :]
target_out = torch.zeros(target.shape)

for i in range(len(target)):
for j in range(len(target[i])):
try:
target_out[i][j] = self.bert2id_dict[target[i][j].item()]
except KeyError:
target_out[i][j] = self.tokenizer.unk_token_id

target_out = target_out.squeeze(-1)[1:, :]

return source.squeeze(), target.squeeze().long(), target_inp.squeeze().long(), target_out.squeeze().long()


# dataset = Translation_dataset()
# loader = DataLoader(dataset=dataset,
# batch_size= 32,
# shuffle=False,
# collate_fn=MyCollate())
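The commented-out loader at the end of t_dataset2.py omits the arguments MyCollate now requires. A runnable sketch of how the pieces defined above fit together, assuming the class names from this file (batch size and shuffle flag are illustrative):

from torch.utils.data import DataLoader

dataset = Translation_dataset_t(train=True)
loader = DataLoader(dataset=dataset,
                    batch_size=32,
                    shuffle=False,
                    collate_fn=MyCollate(tokenizer=dataset.tokenizer,
                                         bert2id_dict=dataset.bert2id_dict))

# MyCollate returns four tensors per batch: padded source ids, padded target
# ids, the target shifted for teacher forcing, and the target remapped
# through bert2id_dict.
src, trg, trg_inp, trg_out = next(iter(loader))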
27 changes: 19 additions & 8 deletions train_translation.py
@@ -17,6 +17,7 @@
import t_dataset
from t_dataset import Translation_dataset_t
from t_dataset import MyCollate
import translation_dataset
import translation_utils
from translation_utils import TokenEmbedding, PositionalEncoding
from translation_utils import create_mask
@@ -149,10 +150,11 @@ def main_worker(gpu, args):
world_size=args.world_size, rank=args.rank)

if args.rank == 0:

'''
wandb.init(config=args, project='translation_test')#############################################
wandb.config.update(args)
config = wandb.config
'''

# exit()
args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
@@ -163,7 +165,11 @@
torch.cuda.set_device(gpu)
torch.backends.cudnn.benchmark = True

# print('loading barlow dataset')
# dataset = translation_dataset.Translation_dataset()
print('loading translation dataset')
dataset = Translation_dataset_t(train=args.train)
print('dataset loaded')
src_vocab_size = dataset.de_vocab_size
trg_vocab_size = dataset.en_vocab_size
tokenizer = dataset.tokenizer
@@ -236,10 +242,11 @@ def main_worker(gpu, args):
per_device_batch_size = args.batch_size // args.world_size
id2bert_dict = dataset.id2bert_dict
###############################
print('instantiating dataloader')
loader = torch.utils.data.DataLoader(
dataset, batch_size=per_device_batch_size, num_workers=args.workers,
pin_memory=True, sampler=sampler, collate_fn = MyCollate(tokenizer=tokenizer,bert2id_dict=dataset.bert2id_dict))

print('dataloader ready')
test_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, num_workers=args.workers,
pin_memory=True, sampler=sampler, collate_fn = MyCollate(tokenizer=tokenizer,bert2id_dict=dataset.bert2id_dict))
@@ -283,7 +290,7 @@ def main_worker(gpu, args):
print(json.dumps(stats), file=stats_file)
if args.rank == 0:

wandb.log({"epoch_loss":epoch_loss/t})
#wandb.log({"epoch_loss":epoch_loss/t})
# save checkpoint
state = dict(epoch=epoch + 1, model=model.module.state_dict(),
optimizer=optimizer.state_dict())
@@ -296,7 +303,7 @@ def main_worker(gpu, args):
if epoch%args.checkbleu ==0 :

bleu_score = checkbleu(model, tokenizer, test_loader, id2bert_dict, gpu)
wandb.log({'bleu_score': bleu_score})
#wandb.log({'bleu_score': bleu_score})
# print(bleu_score(predicted, target))
##############################################################
# if epoch%1 ==0 :
@@ -309,14 +316,14 @@ def main_worker(gpu, args):
# optimizer=optimizer.state_dict())
# torch.save(state, args.checkpoint_dir / f'translation_checkpoint.pth')
# print('saved translation model in', args.checkpoint_dir)
wandb.finish()
#wandb.finish()

else:

bleu_score = checkbleu(model,tokenizer, test_loader, id2bert_dict, gpu )
print('test_bleu_score', bleu_score)
if args.rank == 0:
wandb.log({'bleu_score': bleu_score})
# if args.rank == 0:
#wandb.log({'bleu_score': bleu_score})


def checkbleu(model, tokenizer, test_loader, id2bert_dict, gpu):
@@ -366,6 +373,10 @@ def greedy_decode(model, src, src_mask, max_len, start_symbol, eos_idx, gpu):
memory = memory
tgt_mask = (translation_utils.generate_square_subsequent_mask(ys.size(0))
.type(torch.bool)).cuda(gpu, non_blocking=True)

print('ys shape: ', ys.shape)
print('memory.shape', memory.shape)
print('tgt_mask.shape', tgt_mask.shape)
out = model.module.decode(ys, memory, tgt_mask)
out = out.transpose(0, 1)
prob = model.module.generator(out[:, -1])
@@ -400,4 +411,4 @@ def translate(model: torch.nn.Module,

if __name__ == '__main__':
main()
wandb.finish()
#wandb.finish()
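The new prints in greedy_decode report the shapes of ys, memory and tgt_mask before each decode step. The repository's translation_utils.generate_square_subsequent_mask is not shown in this diff; the sketch below is a minimal equivalent under the usual torch.nn.Transformer convention (boolean True means the position is masked), just to illustrate what that tgt_mask looks like. The token ids are illustrative.

import torch

def square_subsequent_mask(size: int) -> torch.Tensor:
    # True above the diagonal: step i may not attend to steps > i.
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)

ys = torch.tensor([[101], [1262], [10113]])  # (tgt_len, 1), a partial hypothesis
tgt_mask = square_subsequent_mask(ys.size(0))
print('ys shape: ', ys.shape)            # torch.Size([3, 1])
print('tgt_mask.shape', tgt_mask.shape)  # torch.Size([3, 3])
print(tgt_mask)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])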
12 changes: 10 additions & 2 deletions translation_dataset.py
@@ -16,8 +16,16 @@ def __init__(self):
self.en_list = []

for i in self.dataset:
self.de_list.append(tokenizer(i['translation']['de'].lower(), padding=True, return_tensors='pt')["input_ids"])
self.en_list.append(tokenizer(i['translation']['en'].lower(), padding=True, return_tensors='pt')["input_ids"])
if len(i['translation']['de'])> 400:
#print(len(i['translation']['de']))
pass
elif len(i['translation']['en'])> 400:
#print(len(i['translation']['en']))
pass
else:
# print(len(i['translation']['de']))
self.de_list.append(tokenizer(i['translation']['de'].lower(), padding=True, return_tensors='pt')["input_ids"])
self.en_list.append(tokenizer(i['translation']['en'].lower(), padding=True, return_tensors='pt')["input_ids"])



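The edit above adds a character-length filter so that pairs longer than 400 characters on either side are skipped instead of tokenized. A small self-contained sketch of the same filter; the helper name and sample pairs are illustrative, not part of the commit:

def keep_pair(pair: dict, max_chars: int = 400) -> bool:
    # Drop a pair if either the German or the English side is too long,
    # mirroring the if/elif/else added above.
    return len(pair['de']) <= max_chars and len(pair['en']) <= max_chars

examples = [
    {'de': 'ein kurzer satz', 'en': 'a short sentence'},
    {'de': 'x' * 500, 'en': 'too long on the German side'},
]
kept = [p for p in examples if keep_pair(p)]
print(len(kept))  # 1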
25 changes: 21 additions & 4 deletions translation_utils.py
@@ -88,14 +88,31 @@ def __init__(self, emb_size, mbert):
super(TokenEmbedding, self).__init__()
# self.embedding = nn.Embedding(vocab_size, emb_size)
self.embedding = mbert
# for param in self.embedding.parameters():
# param.requires_grad = False
# for param in self.embedding.pooler.parameters():
# param.requires_grad = True
for param in self.embedding.parameters():
param.requires_grad = False
for param in self.embedding.pooler.parameters():
param.requires_grad = True
self.emb_size = emb_size

def forward(self, tokens: torch.tensor):
# print(tokens.shape)
if len(tokens.shape) ==1:
tokens = tokens.unsqueeze(-1)

try:
# run mBERT once and reuse the result instead of calling it again in the return
hidden = self.embedding(tokens.long().T)['last_hidden_state']
except RuntimeError:
print('embedding lookup failed for tokens of shape', tokens.shape)
raise

return hidden.permute(1, 0, 2) * math.sqrt(self.emb_size)

# try:


'''
except RuntimeError:
print('errored')
b = torch.zeros(tokens.shape[0], 1, 768)
pass
'''
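The change above enables the previously commented-out partial freeze: every mBERT parameter has requires_grad set to False except the pooler, so only the pooler weights receive gradients inside TokenEmbedding. A standalone sketch of that freeze plus a quick parameter count, assuming the same bert-base-multilingual-uncased checkpoint (the counting code is illustrative):

from transformers import AutoModel

mbert = AutoModel.from_pretrained('bert-base-multilingual-uncased')

# Freeze the whole encoder, then re-enable gradients for the pooler only,
# as TokenEmbedding.__init__ now does.
for param in mbert.parameters():
    param.requires_grad = False
for param in mbert.pooler.parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in mbert.parameters() if p.requires_grad)
total = sum(p.numel() for p in mbert.parameters())
print(f'trainable parameters: {trainable:,} of {total:,}')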
1 change: 0 additions & 1 deletion wandb/debug-internal.log

This file was deleted.

1 change: 0 additions & 1 deletion wandb/debug.log

This file was deleted.

1 change: 0 additions & 1 deletion wandb/latest-run

This file was deleted.
