-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
39 lines (38 loc) · 1.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from mydataset import MyDataset
from mymodel import MyModel
from torch.utils.tensorboard import SummaryWriter
if __name__ == '__main__':
    # Train MyModel on the multi-label dataset under ./datasets/train/,
    # logging the loss to TensorBoard and saving the trained model.
    train_dataset = MyDataset("./datasets/train/")
    # shuffle=True randomizes sample order each epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Fall back to CPU when CUDA is unavailable instead of crashing;
    # identical behavior to the original on a CUDA machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mymodel = MyModel().to(device)
    loss_fn = nn.MultiLabelSoftMarginLoss().to(device)  # multi-label cross-entropy loss
    # Adam generally wants a relatively small learning rate.
    optim = Adam(mymodel.parameters(), lr=0.001)

    writer = SummaryWriter("logs")
    total_step = 0
    # Training mode is loop-invariant: set it once, not once per batch.
    mymodel.train()
    # Train for 10 epochs.
    for epoch in range(10):
        for i, (image, label) in enumerate(train_dataloader):
            image = image.to(device)
            label = label.to(device)
            output = mymodel(image)
            loss = loss_fn(output, label)
            optim.zero_grad()  # clear gradients accumulated by the last step
            loss.backward()    # backpropagate
            optim.step()
            total_step += 1
            # Report every 10 batches. .item() extracts the Python scalar so
            # logging does not print a tensor repr or retain the autograd graph.
            if i % 10 == 0:
                print("epoch {}, step {}, loss {}".format(epoch, i, loss.item()))
                writer.add_scalar("loss", loss.item(), total_step)
    writer.close()
    # NOTE(review): saving the whole model object is kept for compatibility with
    # existing torch.load("model.pth") callers; for new code prefer saving
    # mymodel.state_dict() instead.
    torch.save(mymodel, "model.pth")