In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
In [2]:
train_data = datasets.MNIST(
    root="data",
    download=True,
    train=True,
    transform=ToTensor()
)
test_data = datasets.MNIST(
    root="data",
    download=True,
    train=False,
    transform=ToTensor()
)
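As a quick sanity check, the standard MNIST splits ship with 60,000 training and 10,000 test images:

# Sanity check: sizes of the two splits.
print(len(train_data), len(test_data))  # 60000 10000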
In [3]:
batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in train_dataloader:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([64, 1, 28, 28]) torch.float32 torch.Size([64]) torch.int64
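One detail worth flagging: the loaders above iterate the datasets in a fixed order. For training, shuffling each epoch usually helps SGD; a minimal variant (note the results later in this notebook were produced with the unshuffled loader above):

# Sketch: reshuffle the training set each epoch; evaluation order does not
# affect the test metrics, so the test loader is left as-is.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)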
In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
cpu
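If you are on Apple Silicon with a recent PyTorch, the "mps" backend may also be available. A sketch, assuming PyTorch >= 1.12:

# Sketch: prefer CUDA, then Apple's MPS backend, then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"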
In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 9216 = 64 channels * 12 * 12 after the convs and pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)  # log-probabilities over the 10 digit classes
        return output
model = Net().to(device)
print(model)
Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
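To see where fc1's 9216 input features come from, trace the shapes through the conv stack: a 28x28 input becomes 26x26 after conv1 (3x3, stride 1), 24x24 after conv2, then 12x12 after the 2x2 max-pool, with 64 channels: 64 * 12 * 12 = 9216. A quick sketch using the layers defined above:

# Sketch: trace shapes through the conv stack with a dummy input.
with torch.no_grad():
    x = torch.zeros(1, 1, 28, 28, device=device)
    x = F.relu(model.conv1(x))   # -> (1, 32, 26, 26)
    x = F.relu(model.conv2(x))   # -> (1, 64, 24, 24)
    x = F.max_pool2d(x, 2)       # -> (1, 64, 12, 12)
    print(x.shape, x.flatten(1).shape)  # (1, 64, 12, 12), (1, 9216)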
In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
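A side note on this pairing: Net.forward already applies log_softmax, while nn.CrossEntropyLoss expects raw logits and applies log_softmax internally. The two still agree here because log_softmax is idempotent (taking softmax of log-probabilities recovers the same probabilities), though returning raw logits, or keeping log_softmax and using nn.NLLLoss, would be the more conventional setup. A quick sketch of the equivalence on made-up tensors:

# Sketch: CrossEntropyLoss on log-probabilities matches NLLLoss on the same
# log-probabilities, because log_softmax(log_softmax(x)) == log_softmax(x).
logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
log_probs = F.log_softmax(logits, dim=1)
print(torch.allclose(nn.CrossEntropyLoss()(log_probs, targets),
                     nn.NLLLoss()(log_probs, targets)))  # True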
In [7]:
def train(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss_val, current = loss.item(), batch * len(X)
            print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def learn(train_dataloader, test_dataloader, model, loss_fn, optimizer, device, epochs):
    for t in range(epochs):
        print(f"Epoch {t+1}")
        train(train_dataloader, model, loss_fn, optimizer, device)
        test(test_dataloader, model, loss_fn, device)
    print("done!")
In [8]:
epochs = 6
learn(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=epochs
)
Epoch 1
loss: 2.316458 [    0/60000]
loss: 2.289086 [ 6400/60000]
loss: 2.286704 [12800/60000]
loss: 2.264874 [19200/60000]
loss: 2.240640 [25600/60000]
loss: 2.203022 [32000/60000]
loss: 2.183667 [38400/60000]
loss: 2.159261 [44800/60000]
loss: 2.110952 [51200/60000]
loss: 2.006685 [57600/60000]
Test Error:
 Accuracy: 72.2%, Avg loss: 1.948846

Epoch 2
loss: 1.953865 [    0/60000]
loss: 1.818807 [ 6400/60000]
loss: 1.763511 [12800/60000]
loss: 1.509245 [19200/60000]
loss: 1.405758 [25600/60000]
loss: 1.282818 [32000/60000]
loss: 1.060655 [38400/60000]
loss: 1.183832 [44800/60000]
loss: 1.009735 [51200/60000]
loss: 0.810097 [57600/60000]
Test Error:
 Accuracy: 84.1%, Avg loss: 0.688012

Epoch 3
loss: 0.938321 [    0/60000]
loss: 0.737556 [ 6400/60000]
loss: 0.738570 [12800/60000]
loss: 0.759897 [19200/60000]
loss: 0.677490 [25600/60000]
loss: 0.753909 [32000/60000]
loss: 0.631555 [38400/60000]
loss: 0.826381 [44800/60000]
loss: 0.659942 [51200/60000]
loss: 0.624089 [57600/60000]
Test Error:
 Accuracy: 88.4%, Avg loss: 0.441919

Epoch 4
loss: 0.547082 [    0/60000]
loss: 0.542435 [ 6400/60000]
loss: 0.432525 [12800/60000]
loss: 0.604901 [19200/60000]
loss: 0.521915 [25600/60000]
loss: 0.575968 [32000/60000]
loss: 0.444457 [38400/60000]
loss: 0.654924 [44800/60000]
loss: 0.520847 [51200/60000]
loss: 0.521495 [57600/60000]
Test Error:
 Accuracy: 89.8%, Avg loss: 0.366778

Epoch 5
loss: 0.510356 [    0/60000]
loss: 0.488137 [ 6400/60000]
loss: 0.430981 [12800/60000]
loss: 0.551029 [19200/60000]
loss: 0.470385 [25600/60000]
loss: 0.496819 [32000/60000]
loss: 0.372566 [38400/60000]
loss: 0.642954 [44800/60000]
loss: 0.460717 [51200/60000]
loss: 0.567079 [57600/60000]
Test Error:
 Accuracy: 90.7%, Avg loss: 0.324936

Epoch 6
loss: 0.384501 [    0/60000]
loss: 0.358661 [ 6400/60000]
loss: 0.367948 [12800/60000]
loss: 0.530707 [19200/60000]
loss: 0.376030 [25600/60000]
loss: 0.481405 [32000/60000]
loss: 0.387268 [38400/60000]
loss: 0.488402 [44800/60000]
loss: 0.481843 [51200/60000]
loss: 0.466399 [57600/60000]
Test Error:
 Accuracy: 91.5%, Avg loss: 0.295597

done!
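After training, you would typically persist the weights and sanity-check a single prediction. A minimal sketch (the file name "mnist_cnn.pth" is illustrative):

# Sketch: save the trained weights and run one prediction on a test image.
torch.save(model.state_dict(), "mnist_cnn.pth")  # illustrative path
model.eval()
x, y = test_data[0]  # ToTensor gives a (1, 28, 28) float tensor and an int label
with torch.no_grad():
    pred = model(x.unsqueeze(0).to(device))  # add a batch dimension
print(f"predicted: {pred.argmax(1).item()}, actual: {y}")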