Things to discuss in this lecture:
- Construction of multi-layer networks
- Alternate activation functions to improve capabilities
- Creating multi-layer perceptrons in PyTorch
Limits of Logistic Regression¶
We have demonstrated how logistic regression paired with the softmax function enables us to perform multi-class classification. Through toy datasets in the previous lectures and the MNIST dataset in the homework, we have shown that logistic regression is highly effective in many problem settings. Recall that we can define a multi-class logistic regression model as performing a matrix-vector multiplication with an input vector $x\in\mathbb{R}^N$ to produce class scores $z\in\mathbb{R}^M$.
$$ \begin{align} z &= Wx+b\\ &= \begin{bmatrix} \rule[.6ex]{4ex}{0.75pt} & w_1^\top & \rule[.6ex]{4ex}{0.75pt}\\ \rule[.6ex]{4ex}{0.75pt} & w_2^\top & \rule[.6ex]{4ex}{0.75pt}\\ & \vdots & \\ \rule[.6ex]{4ex}{0.75pt} & w_M^\top & \rule[.6ex]{4ex}{0.75pt}\\ \end{bmatrix}\begin{bmatrix} \rule[-1ex]{0.5pt}{4ex}\\ x\\ \rule[1ex]{0.5pt}{4ex}\\ \end{bmatrix} +\begin{bmatrix} b_1\\ b_2\\ \vdots\\ b_M \end{bmatrix}\\ &= \begin{bmatrix} z_1\\ z_2\\ \vdots\\ z_M \end{bmatrix} \end{align} $$In PyTorch, we can efficiently implement the multi-class logistic regression model using the nn.Linear
class which implements parameter matrices including bias terms. The resulting class scores are then converted to class probabilities using the softmax function.
Let's now consider another toy dataset known as the two moons dataset and see how logistic regression performs on this data.
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
class TwoMoonsDataset(Dataset):
def __init__(self, sigma, N):
self.sigma = sigma
self.N = N
self.radius = 1
self.angle_offset = -np.pi/8
self.positive_center = torch.tensor([0, 0])
self.negative_center = torch.tensor([1, 0])
# generate angles of each moon
positive_angles = torch.rand(N)*(np.pi-2*self.angle_offset)+self.angle_offset
negative_angles = torch.rand(N)*(np.pi-2*self.angle_offset)+self.angle_offset
# generate each moon
self.positive_data = self.positive_center + torch.stack((self.radius*torch.cos(positive_angles),
self.radius*torch.sin(positive_angles)), dim=-1)
self.positive_data = self.positive_data + torch.randn(N, 2)*sigma
self.negative_data = self.negative_center + torch.stack((self.radius*torch.cos(-negative_angles),
self.radius*torch.sin(-negative_angles)), dim=-1)
self.negative_data = self.negative_data + torch.randn(N, 2)*sigma
# wrap up all data and labels
self.data = torch.cat((self.positive_data, self.negative_data), dim=0)
self.labels = torch.cat((torch.ones(N), torch.zeros(N))).long()
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
class MulticlassLogisticRegression(nn.Module):
def __init__(self, N, M):
super().__init__()
self.N = N # input dimension
self.M = M # number of classes
self.weight_matrix = nn.Linear(N, M, bias=True) # N input dimensions, M output dimensions
def forward(self, x):
return self.weight_matrix(x)
def plot_model_probs(model, plus_class, negative_class):
x = torch.linspace(-2, 3, 100)
y = torch.linspace(-2, 2, 100)
X, Y = torch.meshgrid(x, y, indexing='ij')
meshgrid_inputs = torch.stack((X.flatten(), Y.flatten()), dim=1)
with torch.no_grad():
meshgrid_outputs = torch.softmax(model(meshgrid_inputs), dim=1)[:, 1]
plt.figure(figsize=(8, 6))
plt.contourf(X.numpy(), Y.numpy(), meshgrid_outputs.reshape(100, 100).numpy(), cmap='RdBu_r', levels=100)
plt.colorbar()
plt.title('Probability of positive class')
plt.scatter(plus_class[:, 0].numpy(), plus_class[:, 1].numpy(), color='tomato', s=50, edgecolor='black')
plt.scatter(negative_class[:, 0].numpy(), negative_class[:, 1].numpy(), color='cornflowerblue', s=50, edgecolor='black')
plt.axis(False)
plt.tight_layout()
def plot_model_probs_with_component_lines(model, plus_class, negative_class, W, b, layers):
x = torch.linspace(-2, 3, 100)
y = torch.linspace(-2, 2, 100)
X, Y = torch.meshgrid(x, y, indexing='ij')
meshgrid_inputs = torch.stack((X.flatten(), Y.flatten()), dim=1)
with torch.no_grad():
meshgrid_outputs = torch.softmax(model(meshgrid_inputs), dim=1)[:, 1]
plt.figure(figsize=(8, 6))
plt.contourf(X.numpy(), Y.numpy(), meshgrid_outputs.reshape(100, 100).numpy(), cmap='RdBu_r', levels=100)
plt.colorbar()
plt.title('Probability of positive class')
plt.scatter(plus_class[:, 0].numpy(), plus_class[:, 1].numpy(), color='tomato', s=50, edgecolor='black')
plt.scatter(negative_class[:, 0].numpy(), negative_class[:, 1].numpy(), color='cornflowerblue', s=50, edgecolor='black')
plt.axis(False)
plt.tight_layout()
# Plot the hidden layer linear decision boundaries
#W = model.weight_matrix1.weight.detach().numpy() # Shape: (L, 2)
#b = model.weight_matrix1.bias.detach().numpy() # Shape: (L,)
x_vals = np.linspace(-2, 3, 200)
for i in range(layers):
w = W[i]
bias = b[i]
# Line equation: w0 * x + w1 * y + b = 0 ⟹ y = (-w0 * x - b) / w1
if np.abs(w[1]) > 1e-6: # avoid division by zero
y_vals = (-w[0] * x_vals - bias) / w[1]
plt.plot(x_vals, y_vals, linestyle='--', color='black', alpha=0.5)
plt.xlim(-2, 3)
plt.ylim(-2, 2)
def multiclass_model_accuracy(model, input_data, labels):
predictions = model(input_data) # no need to squeeze/unsqueeze dimensions now!
predicted_classes = torch.argmax(predictions, dim=1) # find highest scoring class along the columns
n_correct = torch.sum(torch.eq(predicted_classes, labels))
return n_correct
def print_model_params(model):
for name, param in model.named_parameters():
if param.requires_grad:
print(name, param.data)
# visualize example of two moons dataset
N = 100
sigma = 0.1
dataset = TwoMoonsDataset(sigma, N)
plus_data = dataset.positive_data
negative_data = dataset.negative_data
plt.figure(figsize=(8, 8))
plt.scatter(plus_data[:, 0].numpy(), plus_data[:, 1].numpy(), color='tomato', s=50, edgecolor='black', label='Positive Class')
plt.scatter(negative_data[:, 0].numpy(), negative_data[:, 1].numpy(), color='cornflowerblue', s=50, edgecolor='black', label='Negative Class')
plt.title('Two Moons Dataset')
plt.legend()
plt.tight_layout()
Now, let's try training a logistic regression model on this data. For the purposes of this exercise, we will not worry about creating separate training, validation, and testing splits.
Let's recap on what our logistic perceptron move is visually:

# loss function, model, and optimizer
criterion = nn.CrossEntropyLoss(reduction='mean') # cross-entropy loss, use mean of loss
lr = 1e-2 # learning rate
M = 2 # two classes
N = 2 # data is two-dimensional
model = MulticlassLogisticRegression(M, N)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.99, weight_decay=1e-3)
# create training dataloader
batch_size = 16
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# training loop
n_epoch = 200
loss_values, accuracies = [], []
for n in range(n_epoch):
epoch_loss, epoch_acc = 0, 0
for x_batch, y_batch in train_loader:
# zero out gradients
optimizer.zero_grad()
# pass batch to model
predictions = model(x_batch)
# calculate loss
loss = criterion(predictions, y_batch)
# backpropagate and update
loss.backward() # backprop
optimizer.step()
# logging to update epoch_loss (add loss value) and epoch_acc (add current batch accuracy)
epoch_loss += loss.item()
epoch_acc += multiclass_model_accuracy(model, x_batch, y_batch)
loss_values.append(epoch_loss/len(train_loader))
accuracies.append(epoch_acc/len(dataset))
# plot model probabilities
plot_model_probs(model, dataset.positive_data, dataset.negative_data)
# plot loss values
plt.figure(figsize=(12,6))
plt.subplot(121)
plt.semilogy(loss_values)
plt.grid(True)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.subplot(122)
plt.plot(accuracies)
plt.grid(True)
plt.title('Classification Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
Text(0, 0.5, 'Accuracy')
plot_model_probs_with_component_lines(model, dataset.positive_data, dataset.negative_data, model.weight_matrix.weight.detach().numpy(), model.weight_matrix.bias.detach().numpy(), 2)
print_model_params(model)
weight_matrix.weight tensor([[ 0.8102, -1.9734], [-0.8240, 2.0188]]) weight_matrix.bias tensor([-0.5790, 0.5570])
The provided two moons dataset is clearly separable by some curvy line, thus we should be able to find some representation for the class boundary that perfectly separates the two classes. However, logistic regression is incapable of performing this separation because the weights of a logistic regression model parameterize a straight line or hyperplane class boundary. In other words, we can only achieve linear class boundaries with logistic regression while this toy dataset requires a non-linear class boundary.
How can we represent more complex functions?¶
We would like to implement barriers that aren't just straight lines, but are also curves. So let's forget training for a moment and try to figure out how to create a curved barrier:
class MulticlassLogisticRegression_Test1(nn.Module):
def __init__(self):
super().__init__()
self.N = 2 # input dimension
self.M = 2 # number of classes
self.weight_matrix = nn.Linear(N, M, bias=True)
# Set custom initial weights and biases
custom_weights_1 = torch.tensor([[-0.6106, -2.2457],
[ 0.6192, 2.2808]])
custom_biases_1 = torch.tensor([0.2708, 0.2180])
custom_weights_2 = torch.tensor([[0.6106, -2.2457],
[ -0.6192, 2.2808]])
custom_biases_2 = torch.tensor([-0.2708, 0.2180])
with torch.no_grad():
self.weight_matrix.weight.copy_(custom_weights_1)
self.weight_matrix.bias.copy_(custom_biases_1)
def forward(self, x):
return self.weight_matrix(x)
model_test = MulticlassLogisticRegression_Test1()
plot_model_probs_with_component_lines(model_test, dataset.positive_data, dataset.negative_data, model_test.weight_matrix.weight.detach().numpy(), model_test.weight_matrix.bias.detach().numpy(), 2)
print_model_params(model_test)
weight_matrix.weight tensor([[-0.6106, -2.2457], [ 0.6192, 2.2808]]) weight_matrix.bias tensor([0.2708, 0.2180])
print(model_test(torch.tensor([1.0,1.0])))
tensor([-2.5855, 3.1180], grad_fn=<ViewBackward0>)
What if we try to combine multiple linear separation planes? Can that give us the curvature we want?

class MulticlassLogisticRegression_Test2(nn.Module):
def __init__(self):
super().__init__()
self.N = 2 # input dimension
self.M = 2 # number of classes
self.weight_matrix = nn.Linear(N, 2, bias=True)
self.weight_matrix_2 = nn.Linear(2, M, bias=True)
# Set custom initial weights and biases
custom_weights_1 = torch.tensor([[-0.6106, -2.2457],
[ 0.6106, -2.2457]])
custom_biases_1 = torch.tensor([0.2708, -0.2708])
# Set custom initial weights and biases
custom_weights_2 = torch.tensor([[1, 1],
[ -1, -1]])
custom_biases_2 = torch.tensor([0.0, 0.0])
with torch.no_grad():
self.weight_matrix.weight.copy_(custom_weights_1)
self.weight_matrix.bias.copy_(custom_biases_1)
self.weight_matrix_2.weight.copy_(custom_weights_2)
self.weight_matrix_2.bias.copy_(custom_biases_2)
def forward(self, x):
x = self.weight_matrix(x)
x = torch.nn.functional.relu(x)
z = self.weight_matrix_2(x)
return z
#return self.weight_matrix_1(x)
model_test = MulticlassLogisticRegression_Test2()
plot_model_probs_with_component_lines(model_test, dataset.positive_data, dataset.negative_data, model_test.weight_matrix.weight.detach().numpy(), model_test.weight_matrix.bias.detach().numpy(), 2)
print_model_params(model_test)
weight_matrix.weight tensor([[-0.6106, -2.2457], [ 0.6106, -2.2457]]) weight_matrix.bias tensor([ 0.2708, -0.2708]) weight_matrix_2.weight tensor([[ 1., 1.], [-1., -1.]]) weight_matrix_2.bias tensor([0., 0.])
What's with the ReLU?¶
Well let's try without ReLu real quick:
class MulticlassLogisticRegression_Test3(nn.Module):
def __init__(self):
super().__init__()
self.N = 2 # input dimension
self.M = 2 # number of classes
self.weight_matrix = nn.Linear(N, 2, bias=True)
self.weight_matrix_2 = nn.Linear(2, M, bias=True)
# Set custom initial weights and biases
custom_weights_1 = torch.tensor([[-0.6106, -2.2457],
[ 0.6106, -2.2457]])
custom_biases_1 = torch.tensor([0.2708, -0.2708])
# Set custom initial weights and biases
custom_weights_2 = torch.tensor([[1, 1],
[ -1, -1]])
custom_biases_2 = torch.tensor([0.0, 0.0])
with torch.no_grad():
self.weight_matrix.weight.copy_(custom_weights_1)
self.weight_matrix.bias.copy_(custom_biases_1)
self.weight_matrix_2.weight.copy_(custom_weights_2)
self.weight_matrix_2.bias.copy_(custom_biases_2)
def forward(self, x):
x = self.weight_matrix(x)
#x = torch.nn.functional.relu(x)
z = self.weight_matrix_2(x)
return z
#return self.weight_matrix_1(x)
model_test = MulticlassLogisticRegression_Test3()
plot_model_probs_with_component_lines(model_test, dataset.positive_data, dataset.negative_data, model_test.weight_matrix.weight.detach().numpy(), model_test.weight_matrix.bias.detach().numpy(), 2)
print_model_params(model_test)
weight_matrix.weight tensor([[-0.6106, -2.2457], [ 0.6106, -2.2457]]) weight_matrix.bias tensor([ 0.2708, -0.2708]) weight_matrix_2.weight tensor([[ 1., 1.], [-1., -1.]]) weight_matrix_2.bias tensor([0., 0.])
Without ReLU, the two Linear layers collapse into a single linear transformation:
$$ \text{Output} = W_2(W_1 x + b_1) + b_2 = (W_2 W_1) x + (W_2 b_1 + b_2) $$This is still just a linear function of the input, and the decision boundary (where the output logits for two classes are equal) will be defined by a linear equation—i.e., a straight line in 2D.
So, in effect, you’re just doing one big matrix multiply with a bias—like a single-layer linear classifier.
We need to break up the linearity of multiplying these matrices by injecting a simple element-wise non-linear function:
Let $\textrm{ReLU}(z)$ denote the rectified linear unit (ReLU) function where $$ \textrm{ReLU}(z)=\max\{0, z\} $$ simply thresholds negative numbers to zero. This function is non-linear and allows us to create a now function: $$ f(x) = z = W_2\textrm{ReLU}(W_1x). $$
When you introduce ReLU, it applies an element-wise non-linearity between the two linear layers. So now, your network looks like this:
$$ x_1 = W_1 x + b_1 \ x_2 = \text{ReLU}(x_1) \ z = W_2 x_2 + b_2 $$The ReLU “bends” the input space by zeroing out negative components of the intermediate vector. Because of this:
- The network can behave differently in different regions of the input space.
- It becomes a piecewise linear function instead of a single linear one.
- These different linear regions (from ReLU activating or deactivating certain nodes) combine to form non-linear, curved decision boundaries.
This is what allows neural networks to learn complex shapes for classification—even though each layer is linear, the ReLU introduces crucial non-linear behavior.
Let's return to our dual half-moon example and see how good we can classify:
class TwoLayerModel(nn.Module):
def __init__(self, N, L, M):
super().__init__()
self.N = N # input dimension
self.M = M # number of classes
self.weight_matrix1 = nn.Linear(N, L, bias=True) # N input dimensions, L hidden dimensions
self.weight_matrix2 = nn.Linear(L, M)
def forward(self, x):
x = self.weight_matrix1(x)
x = torch.nn.functional.relu(x)
z = self.weight_matrix2(x)
return z
# loss function, model, and optimizer
criterion = nn.CrossEntropyLoss(reduction='mean') # cross-entropy loss, use mean of loss
lr = 1e-2 # learning rate
M = 2 # two classes
N = 2 # data is two-dimensional
L = 8 # number of hidden features
fancy_new_model = TwoLayerModel(N, L, M)
optimizer = torch.optim.SGD(fancy_new_model.parameters(), lr=lr, momentum=0.99, weight_decay=1e-3)
# create training dataloader
batch_size = 16
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# training loop
n_epoch = 200
loss_values, accuracies = [], []
for n in range(n_epoch):
epoch_loss, epoch_acc = 0, 0
for x_batch, y_batch in train_loader:
# zero out gradients
optimizer.zero_grad()
# pass batch to model
predictions = fancy_new_model(x_batch)
# calculate loss
loss = criterion(predictions, y_batch)
# backpropagate and update
loss.backward() # backprop
optimizer.step()
# logging to update epoch_loss (add loss value) and epoch_acc (add current batch accuracy)
epoch_loss += loss.item()
epoch_acc += multiclass_model_accuracy(fancy_new_model, x_batch, y_batch)
loss_values.append(epoch_loss/len(train_loader))
accuracies.append(epoch_acc/len(dataset))
# plot model probabilities
plot_model_probs(fancy_new_model, dataset.positive_data, dataset.negative_data)
# plot loss values
plt.figure(figsize=(12,6))
plt.subplot(121)
plt.semilogy(loss_values)
plt.grid(True)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.subplot(122)
plt.plot(accuracies)
plt.grid(True)
plt.title('Classification Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
Text(0, 0.5, 'Accuracy')
plot_model_probs_with_component_lines(fancy_new_model, dataset.positive_data, dataset.negative_data, fancy_new_model.weight_matrix1.weight.detach().numpy(), fancy_new_model.weight_matrix1.bias.detach().numpy(), L)
for name, param in fancy_new_model.named_parameters():
if param.requires_grad:
print(name, param.data)
weight_matrix1.weight tensor([[-0.0984, -0.0491], [ 1.0634, -0.3411], [ 0.1274, -1.2722], [-0.0078, 0.0162], [ 1.9571, -0.5915], [-0.0313, -0.0305], [-3.8406, 0.6589], [-2.2012, -1.3838]]) weight_matrix1.bias tensor([-0.1895, -1.0570, -0.4503, -0.0288, -1.9345, -0.1282, 0.1808, 1.8989]) weight_matrix2.weight tensor([[-0.0573, 1.0909, 0.9631, 0.0124, 1.9715, 0.0703, -2.7555, 2.2609], [ 0.0421, -1.0758, -0.9455, -0.0158, -1.9905, -0.0833, 2.7518, -2.2542]]) weight_matrix2.bias tensor([-1.9892, 1.9645])
Activation Functions¶
We commonly refer to these element-wise non-linearities as activation functions. Examples of activation functions, including ReLU, are as follows:
$$ \sigma(z) = \begin{cases} z,~&z\geq 0\\ 0,~&z<0 \end{cases}=\max\{0, z\} $$ $$ \sigma(z) = \frac{1}{1+e^{-z}} $$ $$ \sigma(z) = \frac{e^z-e^{-z}}{e^{z}+e^{-z}} $$ $$ \sigma_\tau(z) = \begin{cases} z,~&z\geq 0\\ -\tau z~&z < 0 \end{cases} $$- and many more may be found here
Multi-layer Perceptron¶
The model we created above is known as a two-layer perceptron. One individual weight matrix that transforms an input vector to another vector is commonly referred to as a perceptron. The concatenation of multiple perceptrons separated by non-linear activation functions is known as a multi-layer perceptron (MLP) or fully-connected network (we will skip the abbreviation since FCN is commonly used for something else in machine learning).
Multi-layer perceptrons are our first example of a deep neural network or deep net in this course! As mentioned earlier, we may stack arbitrarily many perceptrons and non-linearities to form deeper neural nets. Each layer has an input and output dimension that is a hyperparameter of the model architecture. Every layer that is followed by an activation function is referred to as a hidden layer as it separates the inputs from the outputs of the model. The below figure depicts a three-layer MLP with two hidden layers. Each arrow represents a weight multiplying one entry of an input vector. The result of this multiplication is passed to a node in the next layer where the result at each node is the summation of all incoming arrows followed by an activation function.

Let's create another example MLP.
Larger Deepnet: Image Classification with MLP¶
To conclude this lecture, we will experiment with creating MLP models to perform image classification with the FashionMNIST dataset. This dataset contains $28\times 28$ grayscale images of clothing sorted into ten classes. We have created a dataset to download the data from the torchvision
package (which you may need to install, as well as the tqdm
package for tracking progress bars during training).
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Define a transform to convert the data to tensor and normalize it
transform = transforms.ToTensor()
# Download the FashionMNIST training dataset
fashion_mnist = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
# Create a DataLoader to iterate through the dataset
data_loader = torch.utils.data.DataLoader(fashion_mnist, batch_size=8, shuffle=True)
# Define class labels for FashionMNIST
class_labels = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]
# Get a batch of data
images, labels = next(iter(data_loader))
# Plot the images
plt.figure(figsize=(10, 4))
for i in range(len(images)):
plt.subplot(2, 4, i+1)
plt.imshow(images[i].squeeze(), cmap='gray')
plt.title(class_labels[labels[i]])
plt.axis('off')
plt.tight_layout()
plt.show()
import torch
import torch.nn as nn
class MyMLPModel(nn.Module):
# add hidden dimension sizes to constructor as you see fit!
def __init__(self, input_dim, h1, output_dim, activation_fn):
super().__init__()
# create layers using nn.Linear(input_dimension, output_dimension) objects
self.fc1 = nn.Linear(input_dim, h1)
self.fc2 = nn.Linear(h1, output_dim)
# assign non-linear activation function to class
self.activation = activation_fn
def forward(self, x):
# implement forward pass
x = self.fc1(x)
x = self.activation(x)
z = self.fc2(x)
return z
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
# set up data
class FashionMNIST(Dataset):
def __init__(self):
self.data = torchvision.datasets.FashionMNIST(root='./',
download=True,
train=True,
transform=transforms.Compose([transforms.ToTensor()]))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image, label = self.data[idx]
return image.reshape(-1), label
N = 1000
N_train = 600
N_val = 200
N_test = 200
dataset = FashionMNIST()
indices = np.random.choice(np.arange(len(dataset)), size=N, replace=False)
np.random.shuffle(indices)
train_indices = indices[:N_train]
val_indices = indices[N_train:N_train+N_val]
test_indices = indices[N_train+N_val]
batch_size = 8
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_indices))
# initialize model
input_dim = 784 # dimension of images after being vectorized
h1 = 128
output_dim = 10 # number of classes
activation_fn = nn.ReLU()
model = MyMLPModel(input_dim, h1, output_dim, activation_fn) # fill this is based on your implementation!
# initialize loss function and optimizer
criterion = nn.CrossEntropyLoss()
lr = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.99, weight_decay=1e-4)
# logging info
loss_values, train_accuracies, val_accuracies = [], [], []
n_epoch = 300 # set this value
for n in tqdm(range(n_epoch)):
epoch_loss, epoch_acc = 0, 0
for x_batch, y_batch in train_loader:
# zero out gradients
optimizer.zero_grad()
# pass batch to model, no need to worry about using squeeze/unsqueeze now
predictions = model(x_batch)
# calculate loss
loss = criterion(predictions, y_batch)
# backpropagate and update
loss.backward() # backprop
optimizer.step()
# logging to update epoch_loss (add loss value) and epoch_acc (add current batch accuracy)
epoch_loss += loss.item()
epoch_acc += multiclass_model_accuracy(model, x_batch, y_batch)
loss_values.append(epoch_loss/len(train_loader))
train_accuracies.append(epoch_acc/N_train)
# validation performance
val_acc = 0
for x_batch, y_batch in val_loader:
# don't compute gradients since we are only evaluating the model
with torch.no_grad():
# validation batch accuracy
val_acc += multiclass_model_accuracy(model, x_batch, y_batch)
val_accuracies.append(val_acc/N_val)
plt.figure(figsize=(12,6))
plt.subplot(131)
plt.semilogy(loss_values)
plt.grid(True)
plt.title('Loss values')
plt.xlabel('Epoch')
plt.subplot(132)
plt.plot(train_accuracies)
plt.grid(True)
plt.title('Training Accuracy')
plt.xlabel('Epoch')
plt.subplot(133)
plt.plot(val_accuracies)
plt.grid(True)
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
0%| | 0/300 [00:00<?, ?it/s]
0%|▌ | 1/300 [00:00<01:20, 3.71it/s]
1%|█ | 2/300 [00:00<01:15, 3.93it/s]
1%|█▋ | 3/300 [00:00<01:23, 3.54it/s]
1%|██▏ | 4/300 [00:01<01:58, 2.50it/s]
2%|██▊ | 5/300 [00:01<02:10, 2.25it/s]
2%|███▎ | 6/300 [00:02<01:55, 2.54it/s]
2%|███▊ | 7/300 [00:02<01:44, 2.79it/s]
3%|████▍ | 8/300 [00:02<01:38, 2.97it/s]
3%|████▉ | 9/300 [00:03<01:35, 3.06it/s]
3%|█████▍ | 10/300 [00:03<01:34, 3.06it/s]
4%|██████ | 11/300 [00:03<01:31, 3.15it/s]
4%|██████▌ | 12/300 [00:04<01:28, 3.25it/s]
4%|███████ | 13/300 [00:04<01:30, 3.16it/s]
5%|███████▋ | 14/300 [00:04<01:28, 3.24it/s]
5%|████████▏ | 15/300 [00:04<01:26, 3.31it/s]
5%|████████▋ | 16/300 [00:05<01:24, 3.34it/s]
6%|█████████▎ | 17/300 [00:05<01:23, 3.40it/s]
6%|█████████▊ | 18/300 [00:05<01:22, 3.43it/s]
6%|██████████▍ | 19/300 [00:06<01:20, 3.47it/s]
7%|██████████▉ | 20/300 [00:06<01:21, 3.44it/s]
7%|███████████▍ | 21/300 [00:06<01:24, 3.29it/s]
7%|████████████ | 22/300 [00:07<01:24, 3.30it/s]
8%|████████████▌ | 23/300 [00:07<01:24, 3.27it/s]
8%|█████████████ | 24/300 [00:07<01:26, 3.18it/s]
8%|█████████████▋ | 25/300 [00:08<01:30, 3.06it/s]
9%|██████████████▏ | 26/300 [00:08<01:31, 2.98it/s]
9%|██████████████▊ | 27/300 [00:08<01:31, 2.98it/s]
9%|███████████████▎ | 28/300 [00:09<01:29, 3.04it/s]
10%|███████████████▊ | 29/300 [00:09<01:32, 2.93it/s]
10%|████████████████▍ | 30/300 [00:09<01:35, 2.82it/s]
10%|████████████████▉ | 31/300 [00:10<01:46, 2.53it/s]
11%|█████████████████▍ | 32/300 [00:10<01:54, 2.35it/s]
11%|██████████████████ | 33/300 [00:11<02:04, 2.14it/s]
11%|██████████████████▌ | 34/300 [00:11<02:07, 2.08it/s]
12%|███████████████████▏ | 35/300 [00:12<02:08, 2.06it/s]
12%|███████████████████▋ | 36/300 [00:12<02:01, 2.18it/s]
12%|████████████████████▏ | 37/300 [00:13<02:00, 2.18it/s]
13%|████████████████████▊ | 38/300 [00:13<02:07, 2.06it/s]
13%|█████████████████████▎ | 39/300 [00:14<02:15, 1.93it/s]
13%|█████████████████████▊ | 40/300 [00:14<02:09, 2.00it/s]
14%|██████████████████████▍ | 41/300 [00:15<02:20, 1.84it/s]
14%|██████████████████████▉ | 42/300 [00:15<02:11, 1.96it/s]
14%|███████████████████████▌ | 43/300 [00:16<01:57, 2.19it/s]
15%|████████████████████████ | 44/300 [00:16<01:47, 2.39it/s]
15%|████████████████████████▌ | 45/300 [00:16<01:42, 2.48it/s]
15%|█████████████████████████▏ | 46/300 [00:17<01:36, 2.63it/s]
16%|█████████████████████████▋ | 47/300 [00:17<01:33, 2.72it/s]
16%|██████████████████████████▏ | 48/300 [00:17<01:38, 2.56it/s]
16%|██████████████████████████▊ | 49/300 [00:18<01:34, 2.66it/s]
17%|███████████████████████████▎ | 50/300 [00:18<01:31, 2.74it/s]
17%|███████████████████████████▉ | 51/300 [00:18<01:27, 2.85it/s]
17%|████████████████████████████▍ | 52/300 [00:19<01:24, 2.94it/s]
18%|████████████████████████████▉ | 53/300 [00:19<01:23, 2.95it/s]
18%|█████████████████████████████▌ | 54/300 [00:20<01:29, 2.75it/s]
18%|██████████████████████████████ | 55/300 [00:20<01:29, 2.73it/s]
19%|██████████████████████████████▌ | 56/300 [00:20<01:32, 2.64it/s]
19%|███████████████████████████████▏ | 57/300 [00:21<01:30, 2.68it/s]
19%|███████████████████████████████▋ | 58/300 [00:21<01:27, 2.77it/s]
20%|████████████████████████████████▎ | 59/300 [00:21<01:26, 2.80it/s]
20%|████████████████████████████████▊ | 60/300 [00:22<01:24, 2.84it/s]
20%|█████████████████████████████████▎ | 61/300 [00:22<01:25, 2.79it/s]
21%|█████████████████████████████████▉ | 62/300 [00:22<01:24, 2.83it/s]
21%|██████████████████████████████████▍ | 63/300 [00:23<01:23, 2.83it/s]
21%|██████████████████████████████████▉ | 64/300 [00:23<01:22, 2.88it/s]
22%|███████████████████████████████████▌ | 65/300 [00:23<01:19, 2.95it/s]
22%|████████████████████████████████████ | 66/300 [00:24<01:20, 2.89it/s]
22%|████████████████████████████████████▋ | 67/300 [00:24<01:21, 2.87it/s]
23%|█████████████████████████████████████▏ | 68/300 [00:25<01:24, 2.75it/s]
23%|█████████████████████████████████████▋ | 69/300 [00:25<01:25, 2.70it/s]
23%|██████████████████████████████████████▎ | 70/300 [00:25<01:33, 2.46it/s]
24%|██████████████████████████████████████▊ | 71/300 [00:26<01:36, 2.37it/s]
24%|███████████████████████████████████████▎ | 72/300 [00:26<01:39, 2.30it/s]
24%|███████████████████████████████████████▉ | 73/300 [00:27<01:42, 2.21it/s]
25%|████████████████████████████████████████▍ | 74/300 [00:27<01:39, 2.26it/s]
25%|█████████████████████████████████████████ | 75/300 [00:28<01:38, 2.28it/s]
25%|█████████████████████████████████████████▌ | 76/300 [00:28<01:35, 2.36it/s]
26%|██████████████████████████████████████████ | 77/300 [00:28<01:29, 2.49it/s]
26%|██████████████████████████████████████████▋ | 78/300 [00:29<01:26, 2.56it/s]
26%|███████████████████████████████████████████▏ | 79/300 [00:29<01:21, 2.70it/s]
27%|███████████████████████████████████████████▋ | 80/300 [00:29<01:18, 2.82it/s]
27%|████████████████████████████████████████████▎ | 81/300 [00:30<01:18, 2.79it/s]
27%|████████████████████████████████████████████▊ | 82/300 [00:30<01:22, 2.65it/s]
28%|█████████████████████████████████████████████▎ | 83/300 [00:31<01:23, 2.60it/s]
28%|█████████████████████████████████████████████▉ | 84/300 [00:31<01:26, 2.51it/s]
28%|██████████████████████████████████████████████▍ | 85/300 [00:31<01:22, 2.61it/s]
29%|███████████████████████████████████████████████ | 86/300 [00:32<01:22, 2.59it/s]
29%|███████████████████████████████████████████████▌ | 87/300 [00:32<01:25, 2.49it/s]
29%|████████████████████████████████████████████████ | 88/300 [00:33<01:24, 2.52it/s]
30%|████████████████████████████████████████████████▋ | 89/300 [00:33<01:20, 2.63it/s]
30%|█████████████████████████████████████████████████▏ | 90/300 [00:33<01:16, 2.73it/s]
30%|█████████████████████████████████████████████████▋ | 91/300 [00:34<01:14, 2.82it/s]
31%|██████████████████████████████████████████████████▎ | 92/300 [00:34<01:10, 2.93it/s]
31%|██████████████████████████████████████████████████▊ | 93/300 [00:34<01:09, 2.96it/s]
31%|███████████████████████████████████████████████████▍ | 94/300 [00:35<01:08, 2.99it/s]
32%|███████████████████████████████████████████████████▉ | 95/300 [00:35<01:06, 3.06it/s]
32%|████████████████████████████████████████████████████▍ | 96/300 [00:35<01:04, 3.16it/s]
32%|█████████████████████████████████████████████████████ | 97/300 [00:36<01:02, 3.24it/s]
33%|█████████████████████████████████████████████████████▌ | 98/300 [00:36<01:00, 3.34it/s]
33%|██████████████████████████████████████████████████████ | 99/300 [00:36<00:57, 3.50it/s]
33%|██████████████████████████████████████████████████████▎ | 100/300 [00:36<00:55, 3.64it/s]
34%|██████████████████████████████████████████████████████▉ | 101/300 [00:37<00:57, 3.48it/s]
34%|███████████████████████████████████████████████████████▍ | 102/300 [00:37<00:58, 3.37it/s]
34%|███████████████████████████████████████████████████████▉ | 103/300 [00:37<00:57, 3.41it/s]
35%|████████████████████████████████████████████████████████▌ | 104/300 [00:38<00:59, 3.30it/s]
35%|█████████████████████████████████████████████████████████ | 105/300 [00:38<01:02, 3.10it/s]
35%|█████████████████████████████████████████████████████████▌ | 106/300 [00:38<01:02, 3.10it/s]
36%|██████████████████████████████████████████████████████████▏ | 107/300 [00:39<01:00, 3.18it/s]
36%|██████████████████████████████████████████████████████████▋ | 108/300 [00:39<00:59, 3.22it/s]
36%|███████████████████████████████████████████████████████████▏ | 109/300 [00:39<00:57, 3.33it/s]
37%|███████████████████████████████████████████████████████████▊ | 110/300 [00:39<00:55, 3.39it/s]
37%|████████████████████████████████████████████████████████████▎ | 111/300 [00:40<00:54, 3.49it/s]
37%|████████████████████████████████████████████████████████████▊ | 112/300 [00:40<00:52, 3.59it/s]
38%|█████████████████████████████████████████████████████████████▍ | 113/300 [00:40<00:51, 3.63it/s]
38%|█████████████████████████████████████████████████████████████▉ | 114/300 [00:40<00:51, 3.58it/s]
38%|██████████████████████████████████████████████████████████████▍ | 115/300 [00:41<00:55, 3.30it/s]
39%|███████████████████████████████████████████████████████████████ | 116/300 [00:41<00:55, 3.34it/s]
39%|███████████████████████████████████████████████████████████████▌ | 117/300 [00:41<00:56, 3.25it/s]
39%|████████████████████████████████████████████████████████████████ | 118/300 [00:42<00:58, 3.11it/s]
40%|████████████████████████████████████████████████████████████████▋ | 119/300 [00:42<00:58, 3.08it/s]
40%|█████████████████████████████████████████████████████████████████▏ | 120/300 [00:42<00:57, 3.12it/s]
40%|█████████████████████████████████████████████████████████████████▋ | 121/300 [00:43<00:57, 3.11it/s]
41%|██████████████████████████████████████████████████████████████████▎ | 122/300 [00:43<00:56, 3.16it/s]
41%|██████████████████████████████████████████████████████████████████▊ | 123/300 [00:43<00:55, 3.17it/s]
41%|███████████████████████████████████████████████████████████████████▎ | 124/300 [00:44<00:57, 3.04it/s]
42%|███████████████████████████████████████████████████████████████████▉ | 125/300 [00:44<00:56, 3.10it/s]
42%|████████████████████████████████████████████████████████████████████▍ | 126/300 [00:44<01:01, 2.81it/s]
42%|█████████████████████████████████████████████████████████████████████ | 127/300 [00:45<01:01, 2.79it/s]
43%|█████████████████████████████████████████████████████████████████████▌ | 128/300 [00:45<01:01, 2.79it/s]
43%|██████████████████████████████████████████████████████████████████████ | 129/300 [00:45<00:58, 2.94it/s]
43%|██████████████████████████████████████████████████████████████████████▋ | 130/300 [00:46<00:56, 2.98it/s]
44%|███████████████████████████████████████████████████████████████████████▏ | 131/300 [00:46<00:55, 3.03it/s]
44%|███████████████████████████████████████████████████████████████████████▋ | 132/300 [00:46<00:54, 3.09it/s]
44%|████████████████████████████████████████████████████████████████████████▎ | 133/300 [00:47<00:51, 3.27it/s]
45%|████████████████████████████████████████████████████████████████████████▊ | 134/300 [00:47<00:48, 3.43it/s]
45%|█████████████████████████████████████████████████████████████████████████▎ | 135/300 [00:47<00:46, 3.52it/s]
45%|█████████████████████████████████████████████████████████████████████████▉ | 136/300 [00:48<00:46, 3.55it/s]
46%|██████████████████████████████████████████████████████████████████████████▍ | 137/300 [00:48<00:44, 3.67it/s]
46%|██████████████████████████████████████████████████████████████████████████▉ | 138/300 [00:48<00:43, 3.76it/s]
46%|███████████████████████████████████████████████████████████████████████████▌ | 139/300 [00:48<00:42, 3.75it/s]
47%|████████████████████████████████████████████████████████████████████████████ | 140/300 [00:49<00:42, 3.76it/s]
47%|████████████████████████████████████████████████████████████████████████████▌ | 141/300 [00:49<00:42, 3.78it/s]
47%|█████████████████████████████████████████████████████████████████████████████▏ | 142/300 [00:49<00:42, 3.71it/s]
48%|█████████████████████████████████████████████████████████████████████████████▋ | 143/300 [00:49<00:43, 3.63it/s]
48%|██████████████████████████████████████████████████████████████████████████████▏ | 144/300 [00:50<00:42, 3.63it/s]
48%|██████████████████████████████████████████████████████████████████████████████▊ | 145/300 [00:50<00:42, 3.61it/s]
49%|███████████████████████████████████████████████████████████████████████████████▎ | 146/300 [00:50<00:41, 3.67it/s]
49%|███████████████████████████████████████████████████████████████████████████████▊ | 147/300 [00:50<00:41, 3.66it/s]
49%|████████████████████████████████████████████████████████████████████████████████▍ | 148/300 [00:51<00:41, 3.65it/s]
50%|████████████████████████████████████████████████████████████████████████████████▉ | 149/300 [00:51<00:41, 3.64it/s]
50%|█████████████████████████████████████████████████████████████████████████████████▌ | 150/300 [00:51<00:41, 3.65it/s]
50%|██████████████████████████████████████████████████████████████████████████████████ | 151/300 [00:52<00:40, 3.71it/s]
51%|██████████████████████████████████████████████████████████████████████████████████▌ | 152/300 [00:52<00:41, 3.61it/s]
51%|███████████████████████████████████████████████████████████████████████████████████▏ | 153/300 [00:52<00:40, 3.67it/s]
51%|███████████████████████████████████████████████████████████████████████████████████▋ | 154/300 [00:52<00:39, 3.67it/s]
52%|████████████████████████████████████████████████████████████████████████████████████▏ | 155/300 [00:53<00:38, 3.73it/s]
52%|████████████████████████████████████████████████████████████████████████████████████▊ | 156/300 [00:53<00:38, 3.74it/s]
52%|█████████████████████████████████████████████████████████████████████████████████████▎ | 157/300 [00:53<00:37, 3.77it/s]
53%|█████████████████████████████████████████████████████████████████████████████████████▊ | 158/300 [00:53<00:37, 3.75it/s]
53%|██████████████████████████████████████████████████████████████████████████████████████▍ | 159/300 [00:54<00:37, 3.79it/s]
53%|██████████████████████████████████████████████████████████████████████████████████████▉ | 160/300 [00:54<00:36, 3.84it/s]
54%|███████████████████████████████████████████████████████████████████████████████████████▍ | 161/300 [00:54<00:35, 3.87it/s]
54%|████████████████████████████████████████████████████████████████████████████████████████ | 162/300 [00:54<00:36, 3.79it/s]
54%|████████████████████████████████████████████████████████████████████████████████████████▌ | 163/300 [00:55<00:36, 3.77it/s]
55%|█████████████████████████████████████████████████████████████████████████████████████████ | 164/300 [00:55<00:35, 3.80it/s]
55%|█████████████████████████████████████████████████████████████████████████████████████████▋ | 165/300 [00:55<00:36, 3.70it/s]
55%|██████████████████████████████████████████████████████████████████████████████████████████▏ | 166/300 [00:56<00:38, 3.44it/s]
56%|██████████████████████████████████████████████████████████████████████████████████████████▋ | 167/300 [00:56<00:40, 3.28it/s]
56%|███████████████████████████████████████████████████████████████████████████████████████████▎ | 168/300 [00:56<00:39, 3.33it/s]
56%|███████████████████████████████████████████████████████████████████████████████████████████▊ | 169/300 [00:57<00:39, 3.35it/s]
57%|████████████████████████████████████████████████████████████████████████████████████████████▎ | 170/300 [00:57<00:37, 3.46it/s]
57%|████████████████████████████████████████████████████████████████████████████████████████████▉ | 171/300 [00:57<00:36, 3.52it/s]
57%|█████████████████████████████████████████████████████████████████████████████████████████████▍ | 172/300 [00:57<00:36, 3.50it/s]
58%|█████████████████████████████████████████████████████████████████████████████████████████████▉ | 173/300 [00:58<00:36, 3.51it/s]
58%|██████████████████████████████████████████████████████████████████████████████████████████████▌ | 174/300 [00:58<00:35, 3.58it/s]
58%|███████████████████████████████████████████████████████████████████████████████████████████████ | 175/300 [00:58<00:34, 3.61it/s]
59%|███████████████████████████████████████████████████████████████████████████████████████████████▋ | 176/300 [00:58<00:34, 3.63it/s]
59%|████████████████████████████████████████████████████████████████████████████████████████████████▏ | 177/300 [00:59<00:33, 3.67it/s]
59%|████████████████████████████████████████████████████████████████████████████████████████████████▋ | 178/300 [00:59<00:32, 3.73it/s]
60%|█████████████████████████████████████████████████████████████████████████████████████████████████▎ | 179/300 [00:59<00:32, 3.74it/s]
60%|█████████████████████████████████████████████████████████████████████████████████████████████████▊ | 180/300 [01:00<00:31, 3.75it/s]
60%|██████████████████████████████████████████████████████████████████████████████████████████████████▎ | 181/300 [01:00<00:31, 3.74it/s]
61%|██████████████████████████████████████████████████████████████████████████████████████████████████▉ | 182/300 [01:00<00:30, 3.81it/s]
61%|███████████████████████████████████████████████████████████████████████████████████████████████████▍ | 183/300 [01:00<00:31, 3.67it/s]
61%|███████████████████████████████████████████████████████████████████████████████████████████████████▉ | 184/300 [01:01<00:32, 3.61it/s]
62%|████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 185/300 [01:01<00:32, 3.54it/s]
62%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 186/300 [01:01<00:31, 3.58it/s]
62%|█████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 187/300 [01:01<00:31, 3.63it/s]
63%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 188/300 [01:02<00:30, 3.64it/s]
63%|██████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 189/300 [01:02<00:30, 3.65it/s]
63%|███████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 190/300 [01:02<00:30, 3.61it/s]
64%|███████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 191/300 [01:03<00:29, 3.66it/s]
64%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 192/300 [01:03<00:30, 3.60it/s]
64%|████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 193/300 [01:03<00:29, 3.68it/s]
65%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 194/300 [01:03<00:30, 3.48it/s]
65%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 195/300 [01:04<00:29, 3.61it/s]
65%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 196/300 [01:04<00:28, 3.67it/s]
66%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 197/300 [01:04<00:27, 3.79it/s]
66%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 198/300 [01:04<00:26, 3.87it/s]
66%|████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 199/300 [01:05<00:25, 3.91it/s]
67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 200/300 [01:05<00:26, 3.80it/s]
67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 201/300 [01:05<00:26, 3.74it/s]
67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 202/300 [01:06<00:26, 3.65it/s]
68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 203/300 [01:06<00:26, 3.71it/s]
68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 204/300 [01:06<00:25, 3.73it/s]
68%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 205/300 [01:06<00:25, 3.66it/s]
69%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 206/300 [01:07<00:25, 3.62it/s]
69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 207/300 [01:07<00:26, 3.57it/s]
69%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 208/300 [01:07<00:25, 3.66it/s]
70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 209/300 [01:07<00:25, 3.59it/s]
70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 210/300 [01:08<00:24, 3.70it/s]
70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 211/300 [01:08<00:23, 3.75it/s]
71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 212/300 [01:08<00:23, 3.76it/s]
71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 213/300 [01:09<00:23, 3.69it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 214/300 [01:09<00:22, 3.76it/s]
72%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 215/300 [01:09<00:22, 3.78it/s]
72%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 216/300 [01:09<00:21, 3.83it/s]
72%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 217/300 [01:10<00:21, 3.80it/s]
73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 218/300 [01:10<00:21, 3.73it/s]
73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 219/300 [01:10<00:21, 3.73it/s]
73%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 220/300 [01:10<00:21, 3.75it/s]
74%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 221/300 [01:11<00:22, 3.54it/s]
74%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 222/300 [01:11<00:21, 3.59it/s]
74%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 223/300 [01:11<00:21, 3.63it/s]
75%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 224/300 [01:11<00:20, 3.67it/s]
75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 225/300 [01:12<00:20, 3.68it/s]
75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 226/300 [01:12<00:19, 3.70it/s]
76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 227/300 [01:12<00:19, 3.69it/s]
76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 228/300 [01:13<00:19, 3.78it/s]
76%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 229/300 [01:13<00:18, 3.84it/s]
77%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 230/300 [01:13<00:17, 3.91it/s]
77%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 231/300 [01:13<00:17, 3.94it/s]
77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 232/300 [01:14<00:17, 3.91it/s]
78%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 233/300 [01:14<00:16, 3.96it/s]
78%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 234/300 [01:14<00:16, 3.98it/s]
78%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 235/300 [01:14<00:16, 3.98it/s]
79%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 236/300 [01:15<00:16, 3.99it/s]
79%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 237/300 [01:15<00:16, 3.92it/s]
79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 238/300 [01:15<00:15, 3.96it/s]
80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 239/300 [01:15<00:15, 4.00it/s]
80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 240/300 [01:16<00:14, 4.01it/s]
80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 241/300 [01:16<00:14, 4.01it/s]
81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 242/300 [01:16<00:14, 4.03it/s]
81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 243/300 [01:16<00:14, 4.02it/s]
81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 244/300 [01:17<00:14, 3.97it/s]
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 245/300 [01:17<00:14, 3.88it/s]
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 246/300 [01:17<00:14, 3.84it/s]
82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 247/300 [01:17<00:13, 3.85it/s]
83%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 248/300 [01:18<00:13, 3.87it/s]
83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 249/300 [01:18<00:12, 3.92it/s]
83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 250/300 [01:18<00:12, 3.94it/s]
84%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 251/300 [01:18<00:12, 3.94it/s]
84%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 252/300 [01:19<00:12, 3.89it/s]
84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 253/300 [01:19<00:11, 3.93it/s]
85%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 254/300 [01:19<00:11, 3.87it/s]
85%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 255/300 [01:19<00:12, 3.57it/s]
85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 256/300 [01:20<00:12, 3.60it/s]
86%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 257/300 [01:20<00:11, 3.67it/s]
86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 258/300 [01:20<00:11, 3.76it/s]
86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 259/300 [01:21<00:10, 3.83it/s]
87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 260/300 [01:21<00:10, 3.82it/s]
87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 261/300 [01:21<00:10, 3.85it/s]
87%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 262/300 [01:21<00:09, 3.91it/s]
88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 263/300 [01:22<00:09, 3.95it/s]
88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 264/300 [01:22<00:09, 3.95it/s]
88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 265/300 [01:22<00:08, 3.98it/s]
89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 266/300 [01:22<00:08, 3.92it/s]
89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 267/300 [01:23<00:08, 3.80it/s]
89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 268/300 [01:23<00:08, 3.68it/s]
90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 269/300 [01:23<00:08, 3.74it/s]
90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 270/300 [01:23<00:07, 3.82it/s]
90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 271/300 [01:24<00:07, 3.86it/s]
91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 272/300 [01:24<00:07, 3.91it/s]
91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 273/300 [01:24<00:06, 3.96it/s]
91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 274/300 [01:24<00:06, 3.98it/s]
92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 275/300 [01:25<00:06, 3.93it/s]
92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 276/300 [01:25<00:06, 3.87it/s]
92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 277/300 [01:25<00:05, 3.89it/s]
93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 278/300 [01:25<00:05, 3.90it/s]
93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 279/300 [01:26<00:05, 3.86it/s]
93%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 280/300 [01:26<00:05, 3.90it/s]
94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 281/300 [01:26<00:04, 3.98it/s]
94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 282/300 [01:26<00:04, 4.03it/s]
94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 283/300 [01:27<00:04, 4.06it/s]
95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 284/300 [01:27<00:04, 3.97it/s]
95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 285/300 [01:27<00:03, 3.86it/s]
95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 286/300 [01:27<00:03, 3.66it/s]
96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 287/300 [01:28<00:03, 3.64it/s]
96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 288/300 [01:28<00:03, 3.65it/s]
96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 289/300 [01:28<00:02, 3.70it/s]
97%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 290/300 [01:29<00:02, 3.77it/s]
97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 291/300 [01:29<00:02, 3.68it/s]
97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 292/300 [01:29<00:02, 3.44it/s]
98%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 293/300 [01:29<00:01, 3.56it/s]
98%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 294/300 [01:30<00:01, 3.67it/s]
98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 295/300 [01:30<00:01, 3.72it/s]
99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 296/300 [01:30<00:01, 3.76it/s]
99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 297/300 [01:30<00:00, 3.82it/s]
99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 298/300 [01:31<00:00, 3.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍| 299/300 [01:31<00:00, 3.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [01:31<00:00, 3.90it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [01:31<00:00, 3.27it/s]
Text(0.5, 0, 'Epoch')
There's a problem¶
Deep nets seem pretty great right? So why do we need anything more? Well.....
Issue with fully connected layers:
- Suppose the input is an image of size $256 \times 256$
- Let the output of this layer have identical size
- How many weights are necessary?
- That's as big as a lot of smaller LLMs and I can promise you it can't do nearly as much... Better methods needed ... which we'll discuss next time.
That's all for today¶
- Homework and midterm grades will be posted in the next couple days.
- Project description will be posted within the week.
- Will talk about convolutional neural nets next time.