
Commit

Simplify training code
huyvnphan committed Nov 21, 2019
1 parent 48ec280 commit b46f8d3
Showing 10 changed files with 50 additions and 138 deletions.
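
The headline change in train_model, visible in the diff hunks below, is the optimizer setup: SGD with Nesterov momentum plus a cosine-annealing schedule is swapped for AdamW at its PyTorch defaults, and the cosine scheduler (with its per-epoch scheduler.step() call) appears to be dropped along with it. Here is a minimal before/after sketch of just that change; torchvision's resnet18 stands in for the notebook's cifar10_models.resnet18, the hyperparameter values are the ones from the notebook's train_params dict, and updates_per_epoch is a placeholder for len(train_loader):

import torch.optim as optim
from torchvision.models import resnet18  # stand-in for cifar10_models.resnet18

model = resnet18(num_classes=10)

# Before this commit: SGD + Nesterov momentum, with learning rate, weight decay
# and schedule bounds taken from the notebook's train_params dict, plus a
# cosine-annealing schedule.
params = {'max_lr': 5e-2, 'min_lr': 1e-5, 'weight_decay': 1e-3, 'num_epochs': 300}
old_optimizer = optim.SGD(model.parameters(), lr=params['max_lr'],
                          weight_decay=params['weight_decay'],
                          momentum=0.9, nesterov=True)
updates_per_epoch = 196  # placeholder for len(params['train_loader'])
total_updates = params['num_epochs'] * updates_per_epoch
scheduler = optim.lr_scheduler.CosineAnnealingLR(old_optimizer,
                                                 T_max=total_updates,
                                                 eta_min=params['min_lr'])

# After this commit: AdamW at PyTorch defaults, no explicit schedule.
new_optimizer = optim.AdamW(model.parameters())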
94 changes: 25 additions & 69 deletions .ipynb_checkpoints/CIFAR10-checkpoint.ipynb
@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
"\n",
"from tqdm import tqdm as pbar\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from models import *"
"from cifar10_models import *"
]
},
{
@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -65,16 +65,16 @@
" transforms.ToTensor(),\n",
" transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
" \n",
"# transform_validation = transforms.Compose([transforms.ToTensor(),\n",
"# transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
" transform_validation = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])\n",
" \n",
" transform_validation = transforms.Compose([transforms.ToTensor()])\n",
" \n",
" trainset = torchvision.datasets.CIFAR10(root=params['path'], train=True, transform=transform_train)\n",
" testset = torchvision.datasets.CIFAR10(root=params['path'], train=False, transform=transform_validation)\n",
" \n",
" trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=params['num_workers'])\n",
" testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=params['num_workers'])\n",
" trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=4)\n",
" testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=4)\n",
" return trainloader, testloader"
]
},
@@ -87,25 +87,22 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, params):\n",
" \n",
" writer = SummaryWriter('runs/' + params['description'])\n",
" model = model.to(params['device'])\n",
" optimizer = optim.SGD(model.parameters(), lr=params['max_lr'], weight_decay=params['weight_decay'], momentum=0.9, nesterov=True)\n",
" optimizer = optim.AdamW(model.parameters())\n",
" total_updates = params['num_epochs']*len(params['train_loader'])\n",
" scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates, eta_min=params['min_lr'])\n",
" \n",
" criterion = nn.CrossEntropyLoss()\n",
" best_accuracy = test_model(model, params)\n",
" best_model = copy.deepcopy(model.state_dict())\n",
" \n",
" for epoch in pbar(range(params['num_epochs'])):\n",
" scheduler.step()\n",
" \n",
" # Each epoch has a training and validation phase\n",
" for phase in ['train', 'validation']:\n",
" \n",
@@ -153,7 +150,7 @@
" \n",
" # Write best weights to disk\n",
" if epoch % params['check_point'] == 0 or epoch == params['num_epochs']-1:\n",
" torch.save(best_model, params['state_dict_path'] + params['description'] + '.pt')\n",
" torch.save(best_model, params['description'] + '.pt')\n",
" \n",
" final_accuracy = test_model(model, params)\n",
" writer.add_text('Final_Accuracy', str(final_accuracy), 0)\n",
@@ -169,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -203,23 +200,11 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"model = resnet18()\n",
"model.load_state_dict(torch.load('/tmp/checkpoint_12000.pth'))"
"model = resnet18()"
]
},
{
@@ -231,74 +216,45 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda:2\n"
]
}
],
"outputs": [],
"source": [
"# Train on cuda if available\n",
"device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"print(\"Using\", device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_params = {'path': '/raid/data/pytorch_dataset/cifar10',\n",
" 'batch_size': 256, 'num_workers': 4}\n",
"data_params = {'path': '/raid/data/pytorch_dataset/cifar10', 'batch_size': 256}\n",
"\n",
"train_loader, validation_loader = make_dataloaders(data_params)\n",
"\n",
"train_params = {'description': 'ResNet18',\n",
"train_params = {'description': 'Test',\n",
" 'num_epochs': 300,\n",
" 'max_lr': 5e-2, 'min_lr': 1e-5, 'weight_decay': 1e-3,\n",
" 'check_point': 50, 'device': device,\n",
" 'state_dict_path': 'trained_models/',\n",
" 'train_loader': train_loader, 'validation_loader': validation_loader}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train_model(model, train_params)"
"train_model(model, train_params)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 40/40 [00:01<00:00, 23.87it/s]\n"
]
},
{
"data": {
"text/plain": [
"0.7538"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"test_model(model, train_params)"
]
@@ -327,7 +283,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.7.5"
}
},
"nbformat": 4,
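Putting the simplified pieces back together: after this commit, make_dataloaders drops the Normalize step from the validation transform and hard-codes num_workers=4 instead of reading it from the params dict, and the data_params cell shrinks to a path and a batch size. Below is a consolidated sketch of the post-commit data pipeline; the training-side augmentation is truncated in the diff, so only the normalization tail shown above is reproduced, and the dataset path is the one hard-coded in the notebook:

import torch
import torchvision
import torchvision.transforms as transforms

def make_dataloaders(params):
    # Training transform: only the ToTensor/Normalize tail appears in the diff;
    # any cropping/flipping augmentation above it is not shown and is omitted here.
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

    # Validation transform after the commit: ToTensor only, no normalization.
    transform_validation = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root=params['path'], train=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=params['path'], train=False,
                                           transform=transform_validation)

    # num_workers is now hard-coded to 4 rather than passed in via params.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'],
                                              shuffle=True, num_workers=4)
    testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'],
                                             shuffle=False, num_workers=4)
    return trainloader, testloader

# Usage mirroring the notebook cell (add download=True to the CIFAR10 calls on a first run):
data_params = {'path': '/raid/data/pytorch_dataset/cifar10', 'batch_size': 256}
train_loader, validation_loader = make_dataloaders(data_params)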
