pytorch - nn

Lecture 25

Dr. Colin Rundel

Odds & Ends

Torch models

Implementation details:

  • Models are implemented as a class inheriting from torch.nn.Module

  • Must implement constructor and forward() method

    • __init__() should call parent constructor via super()

      • Use torch.nn.Parameter() to indicate model parameters
    • forward() should implement the model - constants + parameters -> return predictions

Fitting procedure:

  • For each iteration of solver:

    • Get current predictions via a call to forward() or equivalent.

    • Calculate a (scalar) loss or equivalent

    • Call backward() method on loss

    • Use built-in optimizer (step() and then zero_grad() if necessary)

From last time

class Model(torch.nn.Module):
    """Linear model (y_hat = X @ beta) fit by gradient descent on a torch loss.

    Parameters
    ----------
    X : 2d tensor of predictors (n_obs, n_features)
    y : 1d tensor of responses (n_obs,)
    beta : optional initial coefficient tensor (n_features,); zeros if None
    """

    def __init__(self, X, y, beta=None):
        super().__init__()
        self.X = X
        self.y = y
        if beta is None:
            beta = torch.zeros(X.shape[1])
        # nn.Parameter already sets requires_grad=True, so there is no need
        # to flip the flag on `beta` first — doing so also mutated the
        # caller's tensor and raised a RuntimeError for non-leaf tensors.
        self.beta = torch.nn.Parameter(beta)

    def forward(self, X):
        """Return predictions for the design matrix X."""
        return X @ self.beta

    def fit(self, opt, n=1000, loss_fn=torch.nn.MSELoss()):
        """Run `n` optimizer iterations; return the per-iteration losses.

        opt     : a torch optimizer constructed over self.parameters()
        n       : number of gradient steps
        loss_fn : loss module comparing predictions to self.y
        """
        losses = []
        for _ in range(n):
            loss = loss_fn(
                self(self.X).squeeze(),
                self.y.squeeze(),
            )
            loss.backward()
            opt.step()
            opt.zero_grad()
            losses.append(loss.item())

        return losses

What is self(self.X)?

This is (mostly) just shorthand for calling self.forward(X) to generate the output tensor from the current values of the parameters.

This is done via the __call__() method in the torch.nn.Module class.

__call__() allows Python classes to be invoked like functions.


class greet:
    """A callable object: stores a greeting and, when called with a name,
    returns the greeting applied to that name (demonstrates __call__)."""

    def __init__(self, greeting):
        self.greeting = greeting

    def __call__(self, name):
        # Invoked when the instance itself is called like a function.
        return " ".join([self.greeting, name])
hello = greet("Hello")
hello("Jane")
'Hello Jane'
gm = greet("Good morning")
gm("Bob")
'Good morning Bob'

MNIST & Logistic models

MNIST handwritten digits - simplified

from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
X.shape
(1797, 64)
X[0:2]
array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,
         0.,  0.,  0., 13., 15., 10., 15.,
         5.,  0.,  0.,  3., 15.,  2.,  0.,
        11.,  8.,  0.,  0.,  4., 12.,  0.,
         0.,  8.,  8.,  0.,  0.,  5.,  8.,
         0.,  0.,  9.,  8.,  0.,  0.,  4.,
        11.,  0.,  1., 12.,  7.,  0.,  0.,
         2., 14.,  5., 10., 12.,  0.,  0.,
         0.,  0.,  6., 13., 10.,  0.,  0.,
         0.],
       [ 0.,  0.,  0., 12., 13.,  5.,  0.,
         0.,  0.,  0.,  0., 11., 16.,  9.,
         0.,  0.,  0.,  0.,  3., 15., 16.,
         6.,  0.,  0.,  0.,  7., 15., 16.,
        16.,  2.,  0.,  0.,  0.,  0.,  1.,
        16., 16.,  3.,  0.,  0.,  0.,  0.,
         1., 16., 16.,  6.,  0.,  0.,  0.,
         0.,  1., 16., 16.,  6.,  0.,  0.,
         0.,  0.,  0., 11., 16., 10.,  0.,
         0.]])
y = digits.target
y.shape
(1797,)
y[0:10]
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Example digits

Train / test split

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, shuffle=True, random_state=1234
)
X_train.shape
(1437, 64)
y_train.shape
(1437,)
X_test.shape
(360, 64)
y_test.shape
(360,)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
lr = LogisticRegression(
  penalty=None
).fit(
  X_train, y_train
)
accuracy_score(y_train, lr.predict(X_train))
1.0
accuracy_score(y_test, lr.predict(X_test))
0.9583333333333334

As Torch tensors

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test)
X_train.shape
torch.Size([1437, 64])
y_train.shape
torch.Size([1437])
X_test.shape
torch.Size([360, 64])
y_test.shape
torch.Size([360])
X_train.dtype
torch.float32
y_train.dtype
torch.int64
X_test.dtype
torch.float32
y_test.dtype
torch.int64

PyTorch Model

class mnist_model(torch.nn.Module):
    """Multinomial (softmax) regression implemented with raw parameters.

    forward() returns unnormalized class scores (logits); the softmax is
    applied implicitly by CrossEntropyLoss during fitting.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # nn.Parameter sets requires_grad=True itself, so passing
        # requires_grad=True to randn() is redundant and has been dropped.
        self.beta = torch.nn.Parameter(torch.randn(input_dim, output_dim))
        self.intercept = torch.nn.Parameter(torch.randn(output_dim))

    def forward(self, X):
        """Return logits of shape (n_obs, output_dim)."""
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000):
        """Run `n` SGD steps on cross-entropy loss; return per-step losses.

        X_test / y_test are accepted for interface consistency with later
        versions of this model but are not used here.
        """
        # Hoisted: constructing the loss module once instead of every
        # iteration (the original re-instantiated CrossEntropyLoss in-loop).
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses = []

        for _ in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

        return losses

Cross entropy loss

From the pytorch documentation:

\[ \ell(x, y)=L=\left\{l_1, \ldots, l_N\right\}^{\top}, \quad l_n=-w_{y_n} \log \frac{\exp \left(x_{n, y_n}\right)}{\sum_{c=1}^C \exp \left(x_{n, c}\right)} 1\left\{y_n \neq \text { ignore_index }\right\} \]

which is then aggregated across the \(N\) observations according to the reduction argument:

\[ \ell(x, y)= \begin{cases}\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot 1\left\{y_n \neq \text { ignore_index }\right\}} l_n, & \text { if reduction }=\text { 'mean' } \\ \sum_{n=1}^N l_n, & \text { if reduction }=\text { 'sum' }\end{cases} \]

This is just the negative log-likelihood of the multinomial regression model,

\[ P(y_n = c \mid x_n) = \frac{\exp(x_n^\top \beta_c)}{\sum_{k=1}^C \exp(x_n^\top \beta_k)}, \quad c = 1, \ldots, C \]

Cross entropy loss

model = mnist_model(64, 10)
l = model.fit(X_train, y_train, X_test, y_test)

Out-of-sample accuracy

model(X_test)
tensor([[-5.0910e+01, -1.5698e+01,
          3.5842e+01,  3.6857e+01,
         -1.7779e+01,  1.7211e+01,
         -4.3528e+00,  7.7909e+01,
         -1.1575e+01,  3.1172e+01],
        [ 3.3251e+01,  6.8993e+01,
         -2.7056e+01,  3.2008e+01,
          9.9695e+00, -8.9016e+00,
         -1.5198e+01, -5.1730e+00,
          3.3374e+01,  5.8726e+01],
        [-6.6533e+01, -1.5383e+01,
         -5.2091e-01,  2.7142e+01,
         -2.6137e+01, -2.1673e+01,
         -3.3514e+01,  6.8855e+01,
         -1.5247e+01,  1.9311e+00],
        [ 1.2693e+00,  1.2027e+01,
          2.1424e+01, -2.2020e+01,
          2.7276e+01,  3.5974e+00,
          6.3090e+01, -4.7289e+01,
          3.5741e+01, -5.4434e+01],
        [ 5.5127e+01,  1.2813e+01,
         -4.4139e-01, -3.5944e+01,
          1.3583e+01, -2.4839e+01,
          4.8678e+00, -2.8788e+01,
          3.0849e+01,  4.0380e+01],
        [-3.7835e+01,  4.4548e+01,
          8.3875e+01,  1.8941e+01,
          6.4439e+00,  1.1204e+01,
         -3.0637e+00, -1.5577e+01,
         -2.0800e+00,  1.6805e+00],
        [ 2.9283e+00, -1.4431e+01,
         -1.9107e+01, -4.7868e+01,
          6.4153e+01,  3.1348e+01,
          5.0017e+01, -3.9188e+00,
         -1.1052e+01, -3.6784e+01],
        [ 1.3177e+01,  1.6293e+01,
          3.8449e+01,  9.8833e+01,
         -6.1536e+01,  1.3897e+01,
         -1.8214e+01,  7.6538e+00,
          2.4648e+01,  4.4004e+01],
        [ 1.3880e+01,  3.4395e+01,
         -1.4862e+01,  1.1186e+00,
         -2.2615e+01, -1.6240e+01,
          7.8503e+01, -4.9984e+01,
          5.2967e+01, -9.6603e+00],
        [-1.4718e+01,  3.4677e+01,
          2.8561e+01,  6.1248e+01,
         -4.1203e+01, -1.5874e+01,
         -3.0135e+01,  6.5283e+00,
          2.3837e+01, -1.7946e+00],
        [-3.5603e+01,  2.2292e+00,
          2.5780e+01,  2.0052e+01,
          2.2534e+01, -2.9703e+01,
          1.8374e+01,  8.5798e+01,
          2.2109e+01, -1.8283e+01],
        [ 1.2303e+01,  5.1233e+01,
          2.6455e+01,  2.0425e+01,
          5.7182e+01,  1.3774e+01,
          2.4755e+01, -2.4699e+01,
          1.0360e+02,  1.9102e+01],
        [-2.5198e+01,  2.7179e+01,
          2.1665e+01,  3.6961e+01,
          6.0677e+00,  1.1290e+01,
          1.1957e+01,  7.9102e+01,
          1.0728e+01,  2.2449e+01],
        [ 2.5018e+01,  1.7624e+01,
          1.0029e+01,  2.6456e+01,
         -1.1258e+01,  1.5435e+01,
         -2.2917e+01, -1.4933e+01,
          2.8880e+01,  7.3959e+01],
        [ 6.6694e+00,  3.3886e+01,
         -5.6576e+01, -2.8641e+01,
          9.7140e+01, -3.4839e+01,
          2.0302e+01, -3.2983e+01,
          3.1235e+01, -6.5002e+00],
        [-1.1734e+00, -4.5216e+00,
          1.4943e+01,  4.7478e+01,
         -3.9329e+01, -5.2468e+00,
         -2.3627e+01,  1.4382e-01,
          3.0478e+01,  2.3987e+01],
        [-1.4882e+01,  3.2971e+01,
          1.6126e+01,  3.4246e+01,
          1.4154e+00, -2.5364e+01,
          3.2972e+00,  1.5488e+01,
          3.6623e+01, -1.7405e+01],
        [-3.2033e+01,  1.2189e+01,
         -6.6836e+00,  3.9826e+01,
         -3.7389e+01,  1.4217e+01,
         -4.2356e+01,  1.1143e+02,
          9.2111e+00,  9.9521e+00],
        [ 1.2560e+00,  3.1731e+01,
          1.6956e+01,  2.2329e+01,
          1.8662e+01,  1.7101e+01,
          1.0070e+01, -2.9944e+01,
          9.6458e+01,  1.6448e+00],
        [ 3.7476e+01, -9.7911e-01,
         -7.3799e+01, -6.6920e+01,
          9.4661e+01, -1.8939e+01,
          2.5634e+01, -6.8201e+01,
          9.7981e+00, -2.6029e+01],
        [ 9.0227e+01, -8.1650e+00,
          2.8885e+01, -3.3427e+00,
          1.9529e+01,  1.9936e+01,
         -7.5909e+00, -2.0029e+01,
          1.7793e+01, -1.6608e+01],
        [ 5.1244e+00, -1.6384e+01,
          5.9660e+00,  7.7560e+01,
         -4.2583e+01,  1.4090e+01,
         -1.2442e+01, -1.7171e+01,
          3.7229e+00,  4.0405e+01],
        [ 2.0481e+01, -1.2045e+01,
         -1.6732e+01,  1.3923e+01,
         -1.7844e+01,  3.2814e+01,
         -2.0890e+01, -6.3442e+01,
          2.1234e+01,  6.4403e+01],
        [-5.3314e+01,  7.7502e+01,
          5.6783e+01,  4.2800e+01,
          5.6050e+00,  4.8668e+01,
          4.6430e+01, -5.0161e+01,
          4.4233e+01, -6.0706e-01],
        [-3.0897e+00,  2.3627e+01,
          2.4166e+01,  8.0352e+01,
          2.1047e+01,  9.7716e+00,
         -2.2670e+01,  2.9227e+01,
          3.5720e+00,  9.0788e+00],
        [ 8.2682e+00,  3.7964e+01,
         -1.6653e+01, -1.9110e+01,
         -2.6808e+01, -1.9096e+01,
          7.7984e+01, -5.1150e+01,
          4.0238e+01,  1.1806e+01],
        [ 2.3750e+00,  4.0138e+00,
          3.0986e+01,  1.1505e+01,
         -1.6478e+01, -1.5184e+01,
          7.5066e+01, -3.5604e+01,
          3.8345e+01, -3.0667e+01],
        [ 6.6238e+01, -2.5477e+01,
          8.1475e+00,  3.2550e+00,
         -3.2202e+01,  5.8469e+00,
         -1.0144e+01, -4.5645e+01,
          7.2084e+00,  1.1770e+01],
        [ 4.0685e+01,  2.0065e+01,
          1.3171e+01,  4.4713e+00,
          1.9001e+01,  6.7819e+01,
          7.2205e+00, -1.4337e+01,
         -1.1090e-02,  5.3937e+01],
        [ 2.9739e+01,  1.2940e+01,
         -3.4188e+01, -6.8207e+01,
          8.0644e+01,  3.6965e+00,
          1.8758e+01, -4.2682e+01,
         -1.6299e+01, -2.7199e+01],
        ...,
        [ 1.0436e+01,  1.0912e+01,
          9.8693e+00,  8.6534e+01,
         -3.2285e+01,  2.9107e+01,
         -3.7913e+00,  1.6108e+01,
          3.2379e+01,  2.8269e+01],
        [ 3.0497e+01,  1.1552e+01,
         -8.5665e+01, -5.9881e+01,
          1.0409e+02, -2.9302e+01,
          1.0195e+01, -3.6689e+01,
          2.9854e+01, -3.5514e+01],
        [-6.1552e+01,  7.0571e+01,
          2.9022e+01,  2.7282e+01,
          5.0426e+01, -9.2069e+00,
          1.0703e+01,  1.5536e+00,
          1.1948e+01,  9.9230e+00],
        [ 2.6957e+01,  2.2832e+01,
          1.0830e+01, -3.6855e+01,
         -4.0475e+00, -3.6833e+00,
          2.0023e+01, -1.5137e+01,
          4.3144e+01, -3.1485e+01],
        [-4.9082e+01,  4.3296e+01,
          1.8573e+01,  2.1748e+01,
          2.7131e+01, -1.3277e+00,
          5.7522e+00,  1.0218e+01,
          1.2415e+01,  9.3222e+00],
        [ 9.0156e+00,  6.7176e+00,
          1.3034e+01,  2.2376e+01,
         -3.0792e+01, -2.6655e+01,
          1.1933e+01, -3.2102e+01,
          6.6797e+01,  1.9298e+01],
        [ 1.2205e+01,  8.4537e+00,
          3.1098e+01,  1.1742e+01,
         -9.0764e+00,  8.2002e+01,
          2.3341e+01, -1.7799e+01,
          2.3246e+01,  1.4242e+01],
        [ 7.4153e+01, -4.9730e+01,
          2.9613e+01, -4.4428e+00,
         -1.0511e+00,  2.5738e+01,
         -2.3271e+01, -4.6989e+01,
         -2.5687e+00, -6.7190e+00],
        [ 2.3908e+01,  2.1738e+01,
         -5.4869e+01, -1.9974e+01,
          2.4730e+01,  4.4281e+01,
         -3.6533e+01, -3.4675e+01,
          1.0965e+01,  4.6620e+01],
        [-4.9751e-01,  4.9938e+01,
          7.3265e+01,  5.6673e+01,
         -4.8626e+00,  4.2019e+01,
          3.8793e+01, -1.0377e+01,
          1.0890e+01,  4.8674e+00],
        [-5.0727e+01, -4.3911e+01,
         -5.2236e+01, -3.5538e+00,
          1.6656e+01, -1.6025e+01,
         -2.2819e+01,  4.6692e+01,
         -3.8738e+01, -1.8398e+01],
        [-4.2116e+01,  4.5190e+01,
          1.0408e+02,  6.4541e+01,
         -2.1890e+01,  1.3946e+01,
         -1.8347e+01,  9.2573e+00,
          3.2436e+01,  1.3198e+01],
        [ 1.0684e+01,  2.3777e+01,
          3.6666e+01,  9.3215e+01,
         -4.5239e+01,  1.4440e+01,
          7.5432e+00,  7.0002e+00,
          1.2379e+01,  4.9116e+01],
        [ 1.2197e+01,  2.9320e+01,
          3.7280e+01,  3.2814e+00,
         -4.2088e+01,  6.3419e+01,
          2.6777e+01,  2.4096e+01,
          3.7431e+01,  1.8139e+01],
        [-3.8736e+01,  3.8583e+01,
          1.0385e+02,  4.2613e+01,
         -1.6339e+01,  7.7191e+00,
         -1.4653e+01,  4.9580e+00,
          3.5580e+01,  1.0053e+00],
        [-2.7031e+01,  4.8101e+01,
         -7.0057e+00, -8.4088e+00,
         -3.5277e+01, -2.1469e+01,
          8.7112e+01, -5.7668e+01,
          2.4828e+01,  1.4495e+01],
        [ 5.7359e+00,  2.0366e+01,
          2.9890e+01,  1.1092e+02,
         -4.9190e+01,  2.8293e+01,
          1.0455e+00,  1.3256e+01,
          2.2653e+01,  3.5847e+01],
        [ 2.7259e+01, -3.0614e+01,
         -4.5591e+01, -4.2728e+01,
          7.4556e+01,  4.0568e+00,
         -4.5453e+00, -2.5068e+01,
         -5.9977e+00, -1.6386e+01],
        [-5.3639e+01,  4.4665e+01,
          8.9341e+00,  1.4254e+01,
          2.9005e+01, -3.8199e+00,
          1.1113e+01,  1.5891e+01,
          1.1505e+01,  4.7313e+00],
        [ 3.3292e+01,  1.8479e+01,
          3.3855e-02,  2.5886e+01,
          2.5043e+01,  7.9764e+01,
         -6.8903e-01, -3.7174e+01,
          2.7411e+01,  5.0382e+01],
        [ 6.4131e+01, -3.0234e+01,
          1.9579e+01, -1.2356e+01,
         -1.1617e+01,  1.5546e+01,
         -1.6034e+01, -5.4266e+01,
          1.5258e+01, -3.6494e+00],
        [-3.9472e+00,  2.1498e+01,
          1.2156e+01,  9.4088e+00,
         -9.4482e+00,  9.7018e+01,
          2.1464e+01,  8.3984e+00,
          5.1724e+01,  5.1681e+00],
        [ 5.3649e+00, -1.2007e+01,
         -7.9431e+01, -2.2981e+01,
          4.1823e+01,  1.9062e+01,
         -2.8082e+01,  1.8240e+01,
         -4.3887e+01,  1.0434e+01],
        [ 2.8336e+01,  4.4525e+00,
          6.8270e+00, -4.1938e+00,
         -8.8925e+00, -4.7393e+00,
          6.8818e+01, -5.7356e+01,
          3.2612e+01, -3.6327e+01],
        [ 2.8115e+01,  3.8016e+01,
          2.2296e+01,  7.8751e+01,
         -2.1160e+01,  3.9483e+01,
          1.2115e+01, -6.2957e+00,
          2.7490e+01,  4.5235e+01],
        [-1.0052e+01,  2.1980e+01,
          9.6560e+01,  4.5243e+01,
         -3.0159e+01,  3.2215e+01,
         -2.1137e+00, -2.2475e+01,
          4.7805e+01,  1.1469e+01],
        [ 6.4488e+00,  6.6870e+00,
         -1.1110e+01,  4.3325e+01,
         -2.5443e+01,  5.7341e+01,
         -4.8446e+00, -1.9325e+01,
          3.3335e+01,  8.0550e-02],
        [ 5.9689e+01,  1.8064e+01,
         -5.0688e+00, -1.1496e+00,
          2.5541e+01,  9.7262e+00,
          2.9136e+00, -1.0322e+01,
          4.4119e+01,  4.3789e+01],
        [-4.0454e+01,  1.7829e+00,
         -4.1219e+01,  4.3238e+00,
          1.1135e+01,  2.3236e+01,
         -2.4114e+01,  6.1651e+01,
         -4.2060e+01,  2.7119e+01],
        [-2.4672e+00,  1.7440e+01,
          3.7937e+01,  8.1788e+01,
         -1.3622e+01,  2.4583e+01,
          3.1543e+01,  2.8278e+00,
          3.0217e+01,  2.3274e+01]],
       grad_fn=<SqueezeBackward0>)
value, index = torch.max(model(X_test), dim=1)
index
tensor([7, 1, 7, 6, 0, 2, 4, 3, 6, 3, 7, 8, 7,
        9, 4, 3, 8, 7, 8, 4, 0, 3, 9, 1, 3, 6,
        6, 0, 5, 4, 1, 2, 1, 2, 3, 2, 7, 6, 4,
        8, 6, 4, 4, 0, 9, 2, 8, 5, 4, 4, 4, 1,
        7, 6, 8, 2, 9, 8, 8, 0, 4, 3, 1, 8, 8,
        1, 3, 9, 4, 3, 9, 6, 9, 5, 2, 1, 9, 2,
        1, 3, 8, 7, 3, 3, 8, 7, 7, 5, 8, 2, 6,
        1, 9, 1, 6, 4, 5, 2, 2, 4, 5, 4, 7, 6,
        9, 9, 2, 4, 1, 0, 7, 6, 1, 2, 9, 5, 2,
        5, 0, 3, 2, 7, 6, 4, 3, 2, 1, 1, 6, 5,
        6, 2, 5, 4, 7, 5, 0, 9, 1, 0, 5, 6, 7,
        6, 3, 8, 3, 2, 0, 4, 4, 8, 5, 4, 6, 1,
        1, 1, 6, 1, 7, 9, 0, 7, 9, 5, 4, 1, 3,
        8, 6, 4, 7, 1, 5, 7, 4, 7, 4, 2, 2, 2,
        7, 1, 4, 4, 3, 5, 6, 9, 4, 5, 5, 9, 3,
        9, 3, 1, 2, 0, 8, 2, 8, 9, 2, 4, 6, 8,
        3, 8, 1, 0, 8, 1, 8, 5, 6, 8, 7, 1, 4,
        3, 4, 9, 7, 0, 5, 5, 6, 1, 3, 0, 5, 8,
        2, 0, 9, 8, 6, 7, 2, 4, 1, 0, 5, 1, 5,
        1, 6, 4, 7, 1, 2, 6, 4, 4, 6, 3, 3, 1,
        2, 6, 5, 2, 9, 7, 7, 0, 1, 0, 4, 3, 1,
        2, 7, 9, 8, 5, 9, 5, 7, 0, 4, 8, 4, 9,
        4, 5, 7, 7, 3, 5, 3, 5, 3, 9, 7, 5, 5,
        2, 7, 0, 1, 9, 1, 7, 9, 8, 5, 0, 2, 0,
        2, 7, 0, 9, 5, 5, 9, 6, 1, 2, 3, 9, 6,
        3, 2, 9, 3, 1, 3, 4, 1, 8, 1, 8, 5, 0,
        9, 2, 7, 2, 3, 5, 2, 6, 3, 4, 1, 5, 0,
        5, 4, 6, 3, 2, 5, 0, 7, 3])
(index == y_test).sum()
tensor(331)
(index == y_test).sum() / len(y_test)
tensor(0.9194)

Calculating Accuracy

class mnist_model(torch.nn.Module):
    """Multinomial regression with raw parameters, tracking train/test
    accuracy every `acc_step` iterations during fitting."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # nn.Parameter sets requires_grad=True itself; the redundant
        # requires_grad=True argument to randn() has been dropped.
        self.beta = torch.nn.Parameter(torch.randn(input_dim, output_dim))
        self.intercept = torch.nn.Parameter(torch.randn(output_dim))

    def forward(self, X):
        """Return logits of shape (n_obs, output_dim)."""
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD on cross-entropy; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations (as 0-dim tensors,
        matching the original behavior).
        """
        # Hoisted: build the loss module once, not every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Evaluation only — no_grad avoids building autograd graphs
                # for the extra forward passes (same results, less memory).
                with torch.no_grad():
                    _, train_pred = torch.max(self(X_train), dim=1)
                    _, test_pred = torch.max(self(X_test), dim=1)

                    train_acc.append((train_pred == y_train).sum() / len(y_train))
                    test_acc.append((test_pred == y_test).sum() / len(y_test))

        return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_model(
  64, 10
).fit(
  X_train, y_train, X_test, y_test, acc_step=10, n=3000
)

NN Layers

class mnist_nn_model(torch.nn.Module):
    """Multinomial regression built from a single torch.nn.Linear layer
    (weights + bias) instead of hand-rolled parameters."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, X):
        """Return logits of shape (n_obs, output_dim)."""
        return self.linear(X)

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD on cross-entropy; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations (as 0-dim tensors,
        matching the original behavior).
        """
        # Hoisted: build the loss module once, not every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Evaluation only — skip autograd bookkeeping.
                with torch.no_grad():
                    _, train_pred = torch.max(self(X_train), dim=1)
                    _, test_pred = torch.max(self(X_test), dim=1)

                    train_acc.append((train_pred == y_train).sum() / len(y_train))
                    test_acc.append((test_pred == y_test).sum() / len(y_test))

        return (losses, train_acc, test_acc)

NN linear layer

Applies a linear transform to the incoming data (\(X\)): \[y = X A^T+b\]

X.shape
(1797, 64)
model = mnist_nn_model(64, 10)
model.parameters()
<generator object Module.parameters at 0x7ff547099300>
list(model.parameters())[0].shape  # A - weights (betas)
torch.Size([10, 64])
list(model.parameters())[1].shape  # b - bias (intercept)
torch.Size([10])

Performance

loss, train_acc, test_acc = model.fit(X_train, y_train, X_test, y_test, n=1000)
train_acc[-5:]
[tensor(0.9923), tensor(0.9923), tensor(0.9930), tensor(0.9930), tensor(0.9930)]
test_acc[-5:]
[tensor(0.9639), tensor(0.9639), tensor(0.9639), tensor(0.9639), tensor(0.9667)]

Feedforward Neural Network

FNN Model

class mnist_fnn_model(torch.nn.Module):
    """Feedforward NN with one hidden layer: linear -> nonlinearity -> linear.

    nl_step is the nonlinear activation module (ReLU by default; a stateless
    module instance as a default argument is safe to share across models).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step=torch.nn.ReLU()):
        super().__init__()
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.nl = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        """Return logits of shape (n_obs, output_dim)."""
        out = self.l1(X)
        out = self.nl(out)
        out = self.l2(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD on cross-entropy; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations as plain floats.
        """
        # Hoisted: build the loss module once, not every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Evaluation only — skip autograd bookkeeping.
                with torch.no_grad():
                    _, train_pred = torch.max(self(X_train), dim=1)
                    _, test_pred = torch.max(self(X_test), dim=1)

                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Non-linear activation functions

\[\text{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x) + \exp(-x)}\]

\[\text{ReLU}(x) = \max(0,x)\]

Model parameters

model = mnist_fnn_model(64,64,10)
len(list(model.parameters()))
4
for i, p in enumerate(model.parameters()):
  print("Param", i, p.shape)
Param 0 torch.Size([64, 64])
Param 1 torch.Size([64])
Param 2 torch.Size([10, 64])
Param 3 torch.Size([10])

Performance - ReLU

loss, train_acc, test_acc = mnist_fnn_model(64,64,10).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9986082115518441, 0.9986082115518441, 0.9986082115518441, 0.9986082115518441, 0.9986082115518441]
test_acc[-5:]
[0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9638888888888889]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn_model(64,64,10, nl_step=torch.nn.Tanh()).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9951287404314544, 0.9951287404314544, 0.9951287404314544, 0.9951287404314544, 0.9951287404314544]
test_acc[-5:]
[0.975, 0.975, 0.975, 0.975, 0.975]

Adding another layer

class mnist_fnn2_model(torch.nn.Module):
    """Feedforward NN with two hidden layers of size `hidden_dim`.

    The same nl_step instance is reused for both activation slots; this is
    fine because the standard activations (ReLU, Tanh, ...) are stateless.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step=torch.nn.ReLU()):
        super().__init__()
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.nl1 = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.nl2 = nl_step
        self.l3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        """Return logits of shape (n_obs, output_dim)."""
        out = self.l1(X)
        out = self.nl1(out)
        out = self.l2(out)
        out = self.nl2(out)
        out = self.l3(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD on cross-entropy; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations as plain floats.
        """
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i + 1) % acc_step == 0:
                # Evaluation only — no_grad avoids building autograd graphs
                # for the extra forward passes (same results, less memory).
                with torch.no_grad():
                    _, train_pred = torch.max(self(X_train), dim=1)
                    _, test_pred = torch.max(self(X_test), dim=1)

                    train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                    test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Performance - relu

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.ReLU()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9902574808629089, 0.9902574808629089, 0.9902574808629089, 0.9902574808629089, 0.9902574808629089]
test_acc[-5:]
[0.9666666666666667, 0.9666666666666667, 0.9666666666666667, 0.9666666666666667, 0.9666666666666667]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.Tanh()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9784272790535838, 0.9784272790535838, 0.9791231732776617, 0.9791231732776617, 0.9791231732776617]
test_acc[-5:]
[0.9527777777777777, 0.9527777777777777, 0.9527777777777777, 0.9527777777777777, 0.9527777777777777]

Convolutional NN

2d convolutions

nn.Conv2d()

cv = torch.nn.Conv2d(
  in_channels=1, out_channels=4, 
  kernel_size=3, 
  stride=1, padding=1
)
list(cv.parameters())[0] # kernel weights
Parameter containing:
tensor([[[[-0.2050, -0.2136, -0.2673],
          [-0.0281, -0.2074, -0.0626],
          [ 0.0474, -0.2465, -0.0176]]],

        [[[-0.2323, -0.2782, -0.0957],
          [ 0.2592,  0.2719, -0.1723],
          [-0.0421,  0.0604,  0.3314]]],

        [[[-0.2179, -0.3243, -0.0961],
          [-0.2825,  0.2366, -0.1071],
          [-0.0350,  0.0302, -0.2904]]],

        [[[ 0.2639, -0.2899, -0.1652],
          [-0.2402,  0.0671, -0.1947],
          [ 0.1718,  0.3028,  0.0522]]]],
       requires_grad=True)
list(cv.parameters())[1] # biases
Parameter containing:
tensor([ 0.3212, -0.1856, -0.1927, -0.0550],
       requires_grad=True)

Applying Conv2d()

X_train[[0]]
tensor([[ 0.,  0.,  0., 10., 11.,  0.,  0.,
          0.,  0.,  0.,  9., 16.,  6.,  0.,
          0.,  0.,  0.,  0., 15., 13.,  0.,
          0.,  0.,  0.,  0.,  0., 14., 10.,
          0.,  0.,  0.,  0.,  0.,  1., 15.,
         12.,  8.,  2.,  0.,  0.,  0.,  0.,
         12., 16., 16., 16., 10.,  1.,  0.,
          0.,  7., 16., 12., 12., 16.,  4.,
          0.,  0.,  0.,  9., 15., 12.,  5.,
          0.]])
X_train[[0]].shape
torch.Size([1, 64])
cv(X_train[[0]])
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 64]
X_train[[0]].view(1,8,8)
tensor([[[ 0.,  0.,  0., 10., 11.,  0.,  0.,
           0.],
         [ 0.,  0.,  9., 16.,  6.,  0.,  0.,
           0.],
         [ 0.,  0., 15., 13.,  0.,  0.,  0.,
           0.],
         [ 0.,  0., 14., 10.,  0.,  0.,  0.,
           0.],
         [ 0.,  1., 15., 12.,  8.,  2.,  0.,
           0.],
         [ 0.,  0., 12., 16., 16., 16., 10.,
           1.],
         [ 0.,  0.,  7., 16., 12., 12., 16.,
           4.],
         [ 0.,  0.,  0.,  9., 15., 12.,  5.,
           0.]]])
cv(X_train[[0]].view(1,8,8))
tensor([[[ 3.2123e-01,  1.6247e-01,
          -2.8059e+00, -6.0653e+00,
          -2.9627e+00,  2.9598e-01,
           3.2123e-01,  3.2123e-01],
         [ 3.2123e-01, -5.0722e-01,
          -9.1470e+00, -1.1196e+01,
          -5.1568e+00, -2.1020e+00,
           3.2123e-01,  3.2123e-01],
         [ 3.2123e-01, -3.2709e+00,
          -1.3430e+01, -1.1465e+01,
          -4.1317e+00, -9.0847e-01,
           3.2123e-01,  3.2123e-01],
         [ 3.0359e-01, -5.0760e+00,
          -1.3749e+01, -1.0386e+01,
          -4.0633e+00,  2.0707e-01,
           4.1593e-01,  3.2123e-01],
         [ 2.5858e-01, -4.7794e+00,
          -1.2473e+01, -1.1754e+01,
          -7.3186e+00, -3.6811e+00,
          -1.4599e+00,  5.4825e-01],
         [ 5.3963e-02, -4.7767e+00,
          -1.1794e+01, -1.5937e+01,
          -1.1564e+01, -8.8123e+00,
          -6.1214e+00, -3.9573e-01],
         [ 3.2123e-01, -3.3245e+00,
          -9.1313e+00, -1.6582e+01,
          -1.7825e+01, -1.5213e+01,
          -9.9320e+00, -2.9847e+00],
         [ 3.2123e-01, -1.5497e+00,
          -6.0142e+00, -1.0545e+01,
          -1.2844e+01, -1.2202e+01,
          -7.9994e+00, -3.9530e+00]],

        [[-1.8560e-01,  2.7968e+00,
           3.9371e+00,  3.2138e+00,
           5.0857e+00,  2.4127e+00,
          -1.8560e-01, -1.8560e-01],
         [-1.8560e-01,  3.2348e+00,
           3.7620e+00,  1.7830e+00,
          -3.3710e-01, -1.1855e+00,
          -1.8560e-01, -1.8560e-01],
         [-1.8560e-01,  1.0086e+00,
           1.7783e+00,  1.3599e-01,
          -2.6228e+00, -1.5792e+00,
          -1.8560e-01, -1.8560e-01],
         [ 1.4578e-01,  9.9837e-01,
           1.3223e+00,  1.8056e+00,
           2.6817e-02, -4.0186e-01,
          -2.6984e-01, -1.8560e-01],
         [-3.5785e-01,  1.3926e-01,
           3.2607e+00,  5.3162e+00,
           8.0266e+00,  6.0373e+00,
           5.9376e-01, -5.4646e-01],
         [-2.8130e-01, -1.6467e+00,
           4.9307e-01,  1.5793e+00,
           4.3792e+00,  9.6961e+00,
           7.8296e+00,  2.2457e+00],
         [-1.8560e-01, -2.5398e+00,
          -2.9248e+00,  6.5747e-01,
          -3.7799e-02, -3.9428e+00,
          -2.1079e-01,  2.2376e+00],
         [-1.8560e-01, -8.5550e-01,
          -5.2142e+00, -7.5468e+00,
          -4.0436e+00, -1.5523e+00,
          -3.3363e+00, -3.7187e+00]],

        [[-1.9274e-01, -2.8059e+00,
          -5.6373e+00, -5.7823e-01,
          -7.9333e-01, -3.5098e+00,
          -1.9274e-01, -1.9274e-01],
         [-1.9274e-01, -5.5118e+00,
          -4.0592e+00, -4.0245e+00,
          -9.4939e+00, -4.2843e+00,
          -1.9274e-01, -1.9274e-01],
         [-1.9274e-01, -6.7290e+00,
          -4.9732e+00, -9.2686e+00,
          -9.6468e+00, -1.5001e+00,
          -1.9274e-01, -1.9274e-01],
         [-4.8309e-01, -7.4588e+00,
          -7.1317e+00, -1.1751e+01,
          -6.6088e+00, -4.1221e-01,
          -2.6271e-01, -1.9274e-01],
         [-2.9982e-01, -6.3923e+00,
          -7.9961e+00, -1.3323e+01,
          -8.8046e+00, -4.9592e+00,
          -1.3057e+00, -5.1240e-01],
         [-2.8886e-01, -5.2763e+00,
          -9.7372e+00, -1.2685e+01,
          -1.1723e+01, -9.0921e+00,
          -3.9869e+00, -3.2198e+00],
         [-1.9274e-01, -2.0957e+00,
          -8.2928e+00, -1.3095e+01,
          -1.6717e+01, -1.3707e+01,
          -7.3194e+00, -6.4439e+00],
         [-1.9274e-01, -8.6559e-01,
          -4.9647e+00, -7.5374e+00,
          -9.0024e+00, -1.0170e+01,
          -1.0588e+01, -6.3886e+00]],

        [[-5.4985e-02,  4.1525e-01,
           1.5596e+00,  5.1789e+00,
           2.8465e+00, -1.6662e+00,
          -5.4985e-02, -5.4985e-02],
         [-5.4985e-02, -1.0232e+00,
           1.0037e+00, -5.1447e-01,
          -1.8123e+00,  1.4071e+00,
          -5.4985e-02, -5.4985e-02],
         [-5.4985e-02, -3.7299e+00,
          -2.0700e+00, -6.0755e-01,
           1.0236e+00,  1.5285e+00,
          -5.4985e-02, -5.4985e-02],
         [-2.7365e-03, -4.1712e+00,
          -2.2182e+00,  4.0712e+00,
           5.5625e+00,  1.9249e+00,
           2.8859e-01, -5.4985e-02],
         [-2.4964e-01, -4.5932e+00,
          -2.8663e+00,  4.1276e+00,
           8.2787e+00,  6.2736e+00,
           5.2934e+00,  1.9657e+00],
         [-2.2016e-01, -4.7927e+00,
          -5.4763e+00,  8.5395e-01,
           1.5872e+00,  3.2914e+00,
           4.2213e+00,  1.5701e+00],
         [-5.4985e-02, -3.3997e+00,
          -8.3518e+00, -3.6048e+00,
          -1.7729e+00, -8.4323e-01,
           2.0910e+00, -4.2130e-01],
         [-5.4985e-02, -1.2112e+00,
          -6.4793e+00, -7.1449e+00,
          -4.7850e+00, -6.7811e+00,
          -4.7345e+00,  1.8071e+00]]],
       grad_fn=<SqueezeBackward1>)

Pooling

x = torch.tensor(
  [[[0,0,0,0],
    [0,1,2,0],
    [0,3,4,0],
    [0,0,0,0]]],
  dtype=torch.float
)
x.shape
torch.Size([1, 4, 4])
torch.nn.MaxPool2d(
  kernel_size=2, stride=1
)(x)
tensor([[[1., 2., 2.],
         [3., 4., 4.],
         [3., 4., 4.]]])
torch.nn.MaxPool2d(
  kernel_size=3, stride=1, padding=1
)(x)
tensor([[[1., 2., 2., 2.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.]]])
torch.nn.AvgPool2d(
  kernel_size=2
)(x)
tensor([[[0.2500, 0.5000],
         [0.7500, 1.0000]]])
torch.nn.AvgPool2d(
  kernel_size=2, padding=1
)(x)
tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 2.5000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

Convolutional model

class mnist_conv_model(torch.nn.Module):
    """Small CNN for the 8x8 MNIST digits: conv -> ReLU -> max-pool -> linear."""

    def __init__(self):
        super().__init__()
        # 1 input channel -> 8 feature maps; padding=1 keeps the 8x8 spatial size
        self.cnn  = torch.nn.Conv2d(
          in_channels=1, out_channels=8,
          kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)  # 8x8 -> 4x4
        self.lin  = torch.nn.Linear(8 * 4 * 4, 10)     # 10 digit classes

    def forward(self, X):
        out = self.cnn(X.view(-1, 1, 8, 8))     # (N, 64) -> (N, 8, 8, 8)
        out = self.relu(out)                    # (N, 8, 8, 8)
        out = self.pool(out)                    # (N, 8, 8, 8) -> (N, 8, 4, 4)
        out = self.lin(out.view(-1, 8 * 4 * 4)) # (N, 128) -> (N, 10)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
      """Train with SGD + momentum for n iterations.

      Returns (losses, train_acc, test_acc); accuracies are recorded
      every `acc_step` iterations.
      """
      loss_fn = torch.nn.CrossEntropyLoss()
      opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
      losses, train_acc, test_acc = [], [], []

      for i in range(n):
          opt.zero_grad()
          loss = loss_fn(self(X_train), y_train)
          loss.backward()
          opt.step()

          losses.append(loss.item())

          if (i+1) % acc_step == 0:
            # Evaluation only -- skip autograd graph construction
            with torch.no_grad():
                _, train_pred = torch.max(self(X_train), dim=1)
                _, test_pred  = torch.max(self(X_test), dim=1)

            train_acc.append( (train_pred == y_train).sum().item() / len(y_train) )
            test_acc.append( (test_pred == y_test).sum().item() / len(y_test) )

      return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_conv_model().fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9965205288796103, 0.9965205288796103, 0.9965205288796103, 0.9965205288796103, 0.9965205288796103]
test_acc[-5:]
[0.9833333333333333, 0.9833333333333333, 0.9833333333333333, 0.9833333333333333, 0.9833333333333333]

Organizing models

class mnist_conv_model2(torch.nn.Module):
    """Same CNN as mnist_conv_model, organized as a single Sequential pipeline."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
          torch.nn.Unflatten(1, (1,8,8)),       # (N, 64) -> (N, 1, 8, 8)
          torch.nn.Conv2d(
            in_channels=1, out_channels=8,
            kernel_size=3, stride=1, padding=1
          ),
          torch.nn.ReLU(),
          torch.nn.MaxPool2d(kernel_size=2),    # 8x8 -> 4x4
          torch.nn.Flatten(),                   # (N, 8, 4, 4) -> (N, 128)
          torch.nn.Linear(8 * 4 * 4, 10)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
      """Train with SGD + momentum for n iterations.

      Returns (losses, train_acc, test_acc); accuracies are recorded
      every `acc_step` iterations as plain floats.
      """
      # Construct the loss once, not on every iteration
      loss_fn = torch.nn.CrossEntropyLoss()
      opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
      losses, train_acc, test_acc = [], [], []

      for i in range(n):
          opt.zero_grad()
          loss = loss_fn(self(X_train), y_train)
          loss.backward()
          opt.step()

          losses.append(loss.item())

          if (i+1) % acc_step == 0:
            # Evaluation only -- skip autograd graph construction
            with torch.no_grad():
                _, train_pred = torch.max(self(X_train), dim=1)
                _, test_pred  = torch.max(self(X_test), dim=1)

            # .item() so the accuracy lists hold floats (consistent with
            # mnist_conv_model), not 0-dim tensors
            train_acc.append( (train_pred == y_train).sum().item() / len(y_train) )
            test_acc.append( (test_pred == y_test).sum().item() / len(y_test) )

      return (losses, train_acc, test_acc)

A bit more on non-linear
activation layers

Non-linear functions

# Load the Gaussian-process example data and convert to tensors
df = pd.read_csv("data/gp.csv")
X = torch.tensor(df["x"], dtype=torch.float32).reshape(-1,1)  # (n, 1) design matrix
y = torch.tensor(df["y"], dtype=torch.float32)                # (n,) response vector

Linear regression

class lin_reg(torch.nn.Module):
    """Linear regression y ~ X as a single Linear layer, fit by SGD on MSE."""

    def __init__(self, X):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features (1 for this example)
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, self.p)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Model results

m1 = lin_reg(X)
loss = m1.fit(X,y, n=2000)

Training loss:

Predictions

Double linear regression

class dbl_lin_reg(torch.nn.Module):
    """Two stacked Linear layers with no activation -- still a linear model."""

    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Model results

m2 = dbl_lin_reg(X, hidden_dim=10)
loss = m2.fit(X,y, n=2000)

Training loss:

Predictions

Non-linear regression w/ ReLU

class lin_reg_relu(torch.nn.Module):
    """One-hidden-layer network with ReLU activation -> piecewise-linear fit."""

    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Model results

Hidden dimensions

Non-linear regression w/ Tanh

class lin_reg_tanh(torch.nn.Module):
    """One-hidden-layer network with Tanh activation -> smooth nonlinear fit."""

    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.Tanh(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Tanh & hidden dimension

Three layers

class three_layers(torch.nn.Module):
    """Three Linear layers with ReLU activations between them."""

    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Model results

Five layers

class five_layers(torch.nn.Module):
    """Five Linear layers with ReLU activations between them."""

    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]  # number of observations
        self.p = X.shape[1]  # number of features
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, hidden_dim),
          torch.nn.ReLU(),
          torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
      """Run n SGD steps minimizing MSE; returns the per-step loss history."""
      losses = []
      # Construct the loss function once, not on every iteration
      loss_fn = torch.nn.MSELoss()
      opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
      for i in range(n):
          loss = loss_fn(self(X).squeeze(), y)
          loss.backward()
          opt.step()
          opt.zero_grad()
          losses.append(loss.item())

      return losses

Model results