)
# Commented-out alternative optimizer:
# optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

# Multiply every parameter group's learning rate by 0.9 once per 5 calls to
# scheduler.step() (the step() call is not in this fragment; presumably once
# per epoch — confirm).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

# TensorBoard event files are written under ./logCNN.
# makedirs(exist_ok=True) replaces the racy "exists? then mkdir" pair.
os.makedirs("logCNN", exist_ok=True)
writer = tensorboardX.SummaryWriter("logCNN")
# Training loop. For each epoch, make one pass over trainDataLoader, updating
# the network and accumulating the summed batch loss and the
# confusion-matrix counts (TP/TN/FP/FN).
# NOTE(review): the visible fragment ends mid-loop. Accumulation of
# train_sum_correct (from `acc`), logging to `writer` and scheduler.step()
# are presumably further down — confirm.
for epoch in range(epoch_num):
    train_sum_loss = 0
    train_sum_correct = 0
    train_sum_fp = 0
    train_sum_fn = 0
    train_sum_tp = 0
    train_sum_tn = 0

    # Enter training mode once per epoch rather than on every batch.
    # If later code switches to eval mode for validation, this restores
    # training mode at the start of the next epoch.
    net.train()
    for i, data in enumerate(trainDataLoader):
        inputs, labels = data
        # Insert a singleton axis at dim 1 (presumably the channel axis the
        # CNN expects — confirm), then cast and move in a single .to() call.
        inputs = inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
        # Class-index targets as int64 on the target device. This avoids the
        # intermediate CPU tensor that .type(torch.LongTensor) creates.
        labels = labels.to(device=device, dtype=torch.long)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class = index of the highest score along dim 1 (first
        # index on ties, same as torch.max). detach() is the recommended
        # replacement for .data.
        pred = outputs.detach().argmax(dim=1)
        # Number of correct predictions in this batch, as a 0-d CPU tensor.
        acc = pred.eq(labels).cpu().sum()

        # Confusion-matrix counts for this batch. Assumes binary labels in
        # {0, 1} with 1 as the positive class — TODO confirm. Comparing
        # against Python scalars avoids allocating ones/zeros tensors per
        # batch.
        tn = ((labels == 0) & (pred == 0)).sum()
        tp = ((labels == 1) & (pred == 1)).sum()
        fp = ((labels == 0) & (pred == 1)).sum()
        fn = ((labels == 1) & (pred == 0)).sum()
        train_sum_fn += fn.item()
        train_sum_fp += fp.item()
        train_sum_tn += tn.item()
        train_sum_tp += tp.item()
        train_sum_loss += loss.item()
-->>