finetune.py
78 lines (63 loc) · 2.81 KB
import argparse
from logging import getLogger

import torch
from recbole.config import Config
from recbole.data import data_preparation
from recbole.utils import init_seed, init_logger, get_trainer, set_color

from unisrec import UniSRec
from data.dataset import UniSRecDataset
def finetune(dataset, pretrained_file, fix_enc=True, **kwargs):
    # configurations initialization
    props = ['props/UniSRec.yaml', 'props/finetune.yaml']
    print(props)

    config = Config(model=UniSRec, dataset=dataset, config_file_list=props, config_dict=kwargs)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # dataset filtering
    dataset = UniSRecDataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = UniSRec(config, train_data.dataset).to(config['device'])

    # load pre-trained model; strict=False ignores state_dict keys that are
    # missing from or absent in the checkpoint
    if pretrained_file != '':
        checkpoint = torch.load(pretrained_file)
        logger.info(f'Loading from {pretrained_file}')
        logger.info(f'Transfer [{checkpoint["config"]["dataset"]}] -> [{dataset}]')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        if fix_enc:
            # freeze the position embedding and Transformer encoder so only the
            # remaining parameters are updated during fine-tuning
            logger.info('Fix encoder parameters.')
            for param in model.position_embedding.parameters():
                param.requires_grad = False
            for param in model.trm_encoder.parameters():
                param.requires_grad = False
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=True, show_progress=config['show_progress']
    )

    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])

    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    return config['model'], config['dataset'], {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=str, default='Scientific', help='dataset name')
    parser.add_argument('-p', type=str, default='', help='pre-trained model path')
    # note: argparse's type=bool treats any non-empty string as True, so passing '-f False' still yields True
    parser.add_argument('-f', type=bool, default=True, help='whether to fix (freeze) the encoder parameters')
    args, unparsed = parser.parse_known_args()
    print(args)

    finetune(args.d, pretrained_file=args.p, fix_enc=args.f)
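
A typical fine-tuning run passes the downstream dataset name and a pre-trained checkpoint through the -d and -p flags defined above; the checkpoint path here is illustrative and depends on where your own pre-training run saved its model:

python finetune.py -d Scientific -p saved/UniSRec-pretrained.pth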