Optimization

Lecture 21

Dr. Colin Rundel

Optimization

Optimization problems underlie nearly everything we do in Machine Learning and Statistics.

Most models can be formulated as

\[ P \; : \; \underset{x \in \boldsymbol{D}}{\text{arg min}} \; f(x) \]

  • Formulating a problem \(P\) is not the same as being able to solve \(P\) in practice

  • Many different algorithms exist for optimization but their performance varies widely depending on the exact nature of the problem

Gradient Descent

Naive Gradient Descent

The basic idea behind this approach is that the gradient of a function points in the direction of steepest ascent, so the negative gradient points in the direction of steepest descent. Therefore, to find a minimum we take our next step in the direction of the negative gradient, which locally decreases the function most quickly.


Given an \(n\)-dimensional function \(f(x_1, \ldots, x_n)\) and a current position \(x_k\), our update rule is,

\[ x_{k+1} = x_{k} - \alpha \nabla f(x_k) \]

Here \(\alpha\) is the step length (or learning rate), which determines how big a step we take.

Implementation

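import warnings
import numpy as np

def grad_desc_1d(x, f, grad, step, max_step=100, tol = 1e-6):
  # Store the trajectory of positions and function values
  res = {"x": [x], "f": [f(x)]}

  try:
    for i in range(max_step): 
      # Move against the gradient with a fixed step size
      x = x - grad(x) * step
      # Stop once the position has essentially stopped changing
      if np.abs(x - res["x"][-1]) < tol: 
        break

      res["f"].append( f(x) )
      res["x"].append( x )
      
  except OverflowError as err:
    # A diverging trajectory can overflow when evaluating f or grad
    print(f"{type(err).__name__}: {err}")
  
  # Reaching the final iteration is treated as failure to converge
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
  
  return res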

A basic example

\[ \begin{aligned} f(x) &= x^2 \\ \nabla f(x) &= 2x \end{aligned} \]

f = lambda x: x**2
grad = lambda x: 2*x
opt = grad_desc_1d(-2., f, grad, step=0.25)
plot_1d_traj( (-2, 2), f, opt )

opt = grad_desc_1d(-2., f, grad, step=0.5)
plot_1d_traj( (-2, 2), f, opt )

Where can it go wrong?

If you pick a bad step size …

opt = grad_desc_1d(-2, f, grad, step=0.9)
plot_1d_traj( (-2,2), f, opt )

opt = grad_desc_1d(-2, f, grad, step=1)
plot_1d_traj( (-2,2), f, opt )
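For \(f(x) = x^2\) the effect of the step size can be worked out exactly, since each update multiplies the current position by a constant factor:

\[ \begin{aligned} x_{k+1} &= x_k - \alpha \, (2 x_k) = (1 - 2\alpha)\, x_k \\ x_k &= (1 - 2\alpha)^k \, x_0 \end{aligned} \]

So the iterates converge to 0 if and only if \(|1 - 2\alpha| < 1\), i.e. \(0 < \alpha < 1\). With \(\alpha = 0.25\) the factor is \(0.5\); \(\alpha = 0.5\) gives \(0\), reaching the minimum in a single step; \(\alpha = 0.9\) gives \(-0.8\), which oscillates around the minimum while still converging; and \(\alpha = 1\) gives \(-1\), so the iterates bounce between \(-2\) and \(2\) forever.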

Local minima

The function below has multiple local minima; both the starting point and the step size affect which solution we obtain,

\[ \begin{aligned} f(x) &= x^4 + x^3 -x^2 - x \\ \nabla f(x) &= 4x^3 + 3x^2 - 2x - 1 \end{aligned} \]

f = lambda x: x**4 + x**3 - x**2 - x 
grad = lambda x: 4*x**3 + 3*x**2 - 2*x - 1
opt = grad_desc_1d(-1.5, f, grad, step=0.2)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = grad_desc_1d(-1.5, f, grad, step=0.25)
plot_1d_traj( (-1.5, 1.5), f, opt)
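The critical points of this quartic can be found exactly, which helps interpret the trajectories. Since \(\nabla f(x) = (x+1)(4x^2 - x - 1)\), they are \(x = -1\) and \(x = (1 \pm \sqrt{17})/8\). A short numpy check (the names crit and f2 are just for illustration) that finds them as polynomial roots and classifies each by the sign of \(f''(x) = 12x^2 + 6x - 2\):

```python
import numpy as np

# Coefficients of f'(x) = 4x^3 + 3x^2 - 2x - 1, highest power first
crit = np.sort(np.roots([4, 3, -2, -1]).real)

f = lambda x: x**4 + x**3 - x**2 - x
f2 = lambda x: 12*x**2 + 6*x - 2   # second derivative of f

for c in crit:
    kind = "local min" if f2(c) > 0 else "local max"
    print(f"x = {c: .4f}  f(x) = {f(c): .4f}  {kind}")
```

This gives local minima at \(x = -1\) (where \(f = 0\)) and \(x \approx 0.640\) (where \(f \approx -0.62\), the global minimum), separated by a local maximum near \(x \approx -0.390\).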

Alternative starting points

opt = grad_desc_1d(1.5, f, grad, step=0.2)
plot_1d_traj( (-1.75, 1.5), f, opt )

opt = grad_desc_1d(1.25, f, grad, step=0.2)
plot_1d_traj( (-1.75, 1.5), f, opt)

Problematic step sizes

If the step size is too large, the iterates can grow without bound until the computation overflows,

opt = grad_desc_1d(-1.5, f, grad, step=0.75)
OverflowError: (34, 'Result too large')
plot_1d_traj( (-3, 3), f, opt)

opt['x']
[-1.5, 2.0625, -29.986083984375, 78789.99556875888, -1467366557235808.0, 9.478445237313853e+45]
opt['f']
[0.9375, 20.552993774414062, 780666.4923959533, 3.853805712579921e+19, 4.636117851941789e+60, 8.071391646153008e+183]

Gradient Descent w/ backtracking

As we have just seen, too large a step can be problematic; one solution is to allow the step size to adapt.

Backtracking involves checking whether the proposed move is downhill (i.e. \(f(x_k-\alpha \nabla f(x_k)) < f(x_k)\)),

  • If it is downhill then accept \(x_{k+1} = x_k-\alpha \nabla f(x_k)\).

  • If not, shrink \(\alpha\) by a factor \(\tau\) (e.g. \(\alpha \leftarrow \tau \alpha\) with \(\tau = 0.5\)) and check again.

Pick a larger \(\alpha\) to start (but not so large as to overflow) and then let the backtracking tune things.

Implementation

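import warnings
import numpy as np

def grad_desc_1d_bt(x, f, grad, step, tau=0.5, max_step=100, max_back=10, tol = 1e-6):
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_step):
    grad_f = grad(x)
    # Backtracking: shrink the step until the proposed move goes downhill
    for j in range(max_back):
      x = res["x"][-1] - step * grad_f
      f_x = f(x)
      if (f_x < res["f"][-1]): 
        break
      step = step * tau
    # The reduced step carries over to later iterations, and if no downhill
    # move is found within max_back tries the last proposal is kept anyway
    
    if np.abs(x - res["x"][-1]) < tol: 
      break
    res["x"].append(x)
    res["f"].append(f_x)
    
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
  
  return res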

opt = grad_desc_1d_bt(
  -1.5, f, grad, step=0.75, tau=0.5
)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = grad_desc_1d_bt(
  1.5, f, grad, step=0.25, tau=0.5
)
plot_1d_traj( (-1.5, 1.5), f, opt)

A 2d cost function

We will be using mk_quad() to create quadratic functions with varying conditioning (as specified by the epsilon parameter).

\[ \begin{aligned} f(x,y) &= 0.33(x^2 + \epsilon^2 y^2 ) \\ \nabla f(x,y) &= \left[ \begin{matrix} 0.66 \, x \\ 0.66 \, \epsilon^2 \, y \end{matrix} \right] \\ \nabla^2 f(x,y) &= \left[\begin{array}{cc} 0.66 & 0 \\ 0 & 0.66 \, \epsilon^2 \end{array}\right] \end{aligned} \]
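mk_quad() itself is not shown in this section. A hypothetical version consistent with the formulas above might look like the following; the name mk_quad_sketch and the convention of passing the point as a length-2 tuple or array are assumptions, not the lecture's actual helper.

```python
import numpy as np

def mk_quad_sketch(epsilon):
    """Hypothetical stand-in for mk_quad(): f, gradient and Hessian of
    f(x, y) = 0.33 (x^2 + epsilon^2 y^2)."""
    def f(x):
        x1, x2 = x
        return 0.33 * (x1**2 + epsilon**2 * x2**2)

    def grad(x):
        x1, x2 = x
        return np.array([0.66 * x1, 0.66 * epsilon**2 * x2])

    def hess(x):
        # The Hessian of a quadratic does not depend on the point x
        return np.array([[0.66, 0.0], [0.0, 0.66 * epsilon**2]])

    return f, grad, hess

f, grad, hess = mk_quad_sketch(0.7)
print(f((1.6, 1.1)), grad((1.6, 1.1)))
```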

Examples

f, grad, hess = mk_quad(0.7)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7"
)

f, grad, hess = mk_quad(0.05)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05"
)

\(n\)-d gradient descent w/ backtracking

import warnings
import numpy as np

def grad_desc(x, f, grad, step, tau=0.5, max_step=100, max_back=10, tol = 1e-8):
  x = np.asarray(x, dtype=float)   # accept tuples such as (1.6, 1.1)
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_step):
    grad_f = grad(x)
    
    # Backtracking line search along the negative gradient
    for j in range(max_back):
      x = res["x"][-1] - grad_f * step
      f_x = f(x)
      if (f_x < res["f"][-1]): 
        break
      step = step * tau

    # Stop when the Euclidean length of the move falls below tol
    if np.linalg.norm(x - res["x"][-1]) < tol: 
      break  
    
    res["x"].append(x)
    res["f"].append(f_x)
      
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
    
  return res

Well conditioned cost function

f, grad, hess = mk_quad(0.7)
opt = grad_desc((1.6, 1.1), f, grad, step=1)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7", traj=opt
)

f, grad, hess = mk_quad(0.7)
opt = grad_desc((1.6, 1.1), f, grad, step=2)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7", traj=opt
)

Ill-conditioned cost function

f, grad, hess = mk_quad(0.05)
opt = grad_desc((1.6, 1.1), f, grad, step=1)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05", traj=opt
)

f, grad, hess = mk_quad(0.05)
opt = grad_desc((1.6, 1.1), f, grad, step=2)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05", traj=opt
)

Aside - Ill-conditioned functions

A function is ill-conditioned when the Hessian has eigenvalues that differ by orders of magnitude. The condition number \(\kappa = \lambda_\text{max} / \lambda_\text{min}\) measures this:

  • \(\kappa \approx 1\) - well-conditioned, contours are nearly circular, GD converges quickly

  • \(\kappa \gg 1\) - ill-conditioned, contours are highly elongated, GD zigzags slowly

For mk_quad(epsilon), the Hessian eigenvalues are \(0.66\) and \(0.66\epsilon^2\), giving \(\kappa = 1/\epsilon^2\) (for \(\epsilon < 1\)). With \(\epsilon = 0.05\) we have \(\kappa = 400\).
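These quantities are easy to compute numerically. The sketch below (quad_hessian is a hypothetical helper that just returns the Hessian above) computes \(\kappa\) from the eigenvalues, along with the standard worst-case per-step contraction factor \((\kappa - 1)/(\kappa + 1)\) for gradient descent on a quadratic using the best fixed step size \(\alpha = 2/(\lambda_\text{min} + \lambda_\text{max})\):

```python
import numpy as np

def quad_hessian(epsilon):
    # Hessian of f(x, y) = 0.33 (x^2 + eps^2 y^2); constant in (x, y)
    return np.array([[0.66, 0.0], [0.0, 0.66 * epsilon**2]])

for eps in (0.7, 0.05):
    lam = np.linalg.eigvalsh(quad_hessian(eps))   # eigenvalues, ascending
    kappa = lam[-1] / lam[0]
    # Worst-case error contraction per GD step with alpha = 2 / (lam_min + lam_max)
    rate = (kappa - 1) / (kappa + 1)
    print(f"eps={eps}: kappa={kappa:.2f}, contraction={rate:.4f}")
```

For \(\epsilon = 0.7\) this gives \(\kappa \approx 2.04\) and a factor of about 0.34, while for \(\epsilon = 0.05\) the factor is \(399/401 \approx 0.995\), so each step removes only about half a percent of the error along the worst direction.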

Rosenbrock’s function

Another classic ill-conditioned function commonly used to benchmark optimization algorithms,

\[ f(u, v) = \tfrac{1}{2}(1 - u)^2 + (v - u^2)^2 \]

Our scaled version uses \(u = 4x+1\) and \(v = 4y+3\),

\[ f(x, y) = 8x^2 + \big(4y + 3 - (4x+1)^2\big)^2 \]

  • Minimum at \((u^*, v^*) = (1, 1)\), i.e. \((x^*, y^*) = (0, -\tfrac{1}{2})\) where \(f = 0\)

  • Has a narrow, curved (banana-shaped) valley - easy to find the valley, hard to follow it to the minimum

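  • Condition number of the Hessian at the minimum is \(\approx 58\) (for comparison, the standard Rosenbrock function \((1-u)^2 + 100(v-u^2)^2\) has \(\kappa \approx 2500\) at its minimum)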
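The condition number at the minimum can be checked directly from the scaled formula. A minimal numpy sketch; rosen_scaled and rosen_scaled_hess are hypothetical names, and the Hessian is derived by hand from \(f(x, y)\) above:

```python
import numpy as np

def rosen_scaled(p):
    # f(x, y) = 8 x^2 + (4y + 3 - (4x + 1)^2)^2
    x, y = p
    g = 4*y + 3 - (4*x + 1)**2
    return 8*x**2 + g**2

def rosen_scaled_hess(p):
    # Analytic Hessian of rosen_scaled
    x, y = p
    g = 4*y + 3 - (4*x + 1)**2
    return np.array([
        [16 - 64*g + 128*(4*x + 1)**2, -64*(4*x + 1)],
        [-64*(4*x + 1),                 32.0],
    ])

p_star = np.array([0.0, -0.5])                 # the minimizer
lam = np.linalg.eigvalsh(rosen_scaled_hess(p_star))
kappa = lam[-1] / lam[0]
print(rosen_scaled(p_star), kappa)
```

At \((0, -\tfrac{1}{2})\) the Hessian is \(\begin{bmatrix} 144 & -64 \\ -64 & 32 \end{bmatrix}\), with eigenvalues \(88 \pm \sqrt{7232}\), giving \(\kappa \approx 58\). The difficulty of this function comes largely from the curvature of the valley away from the minimum, not only from \(\kappa\) at the minimum.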

Rosenbrock function (very ill conditioned)

f, grad, hess = mk_rosenbrock()
opt = grad_desc((1.6, 1.1), f, grad, step=0.25)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = grad_desc((-0.5, 0), f, grad, step=0.25)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Some regression examples

from sklearn.datasets import make_regression
X, y, coef = make_regression(
  n_samples=200, n_features=20, n_informative=4, 
  bias=3, noise=1, random_state=1234, coef=True
)
y
array([ -36.2252,    9.6357,   66.4583,   48.9574,   24.1885,  -13.2444,
         18.1455, -135.047 ,  116.5772,   60.2524,   30.9319,  107.148 ,
         21.6209,   66.2401, -132.8878,   58.636 ,   22.186 ,   60.3852,
        -85.0383,   55.1704,  -31.3817,  -57.0697,   67.3215,    2.878 ,
        -29.5613,  -41.3973,  -30.3048,  -41.5597,   52.7531,  -63.5633,
          3.5671,   63.712 ,    9.9833,   78.4881,  -76.126 ,   13.4331,
        122.6162,   79.0354,   91.2171,   48.7344,  103.6366,   52.5964,
         35.0064,  -65.8423,  -47.3045,  -25.6876,    1.8359,   35.4113,
         28.0687,   56.3528,    3.6755,  -72.3309,   57.143 ,  -16.9438,
         54.1445,   72.6828,   -5.0538, -180.6135,  -44.6205,    9.2071,
         -5.5324,  -29.6013,  135.3656,  114.241 ,  -97.4878,   15.0648,
         14.7958,   71.503 ,   -4.6583,  -36.791 ,   -5.3845, -119.8073,
         11.174 ,   36.3008,   82.5499,  -20.0869,   14.7146,  -59.0765,
         39.4171,   48.4013,  -61.9613,   -5.6247,  103.2374,   41.2613,
       -129.2273,   10.5113,   32.4936,   78.6921,    5.2956,   64.4473,
         88.8358,   39.4851,  -11.4866,  -52.5082,  112.7248,   -9.7006,
         13.8393,  -36.4004,   68.4865,   19.5335,  -75.447 ,  -87.9538,
         79.4784,  -75.094 ,   25.6229,   84.9034,   71.2779,  -66.4093,
         77.6444,   40.8875,   31.3165,  -22.7143,   84.562 ,    6.8075,
          9.778 ,  -65.9149,  106.6952,   -3.1901,   41.1555,   32.6265,
        -36.5738,   38.9966,  -78.664 ,  -56.0434,    2.9191,   42.6286,
         51.3644,  -21.8072,  -21.9779,  -15.7102,  -23.5586,    1.3801,
         20.4269,   55.7188,  -45.6388,  -55.1542,   74.6067,   -7.2716,
        -31.1045,   48.1571,   14.7487,   41.6956,  -59.6062,  -33.0811,
         81.0177,   -9.4896,  164.1317,   25.3507,    6.0141,   46.3718,
         84.2983,  -63.2593,  -17.4733,  -26.2977,  -56.4681,   17.003 ,
         53.1867,  -94.5398,  -18.2541,  -49.343 ,   40.8724,  -90.5986,
         27.9392,   41.7287,   49.8082,   -9.6384,  -66.7551,  122.9159,
        -41.3566,  -98.6863,  -45.0718,    9.9327,  -22.0927,   10.6199,
        -12.2831,    7.4184,   57.6091,  -27.3456,  -36.4045,  -51.659 ,
         28.8175,  -23.9402,  -51.0637,    4.3618,   10.8402,  -11.087 ,
        -29.9801,  113.6633,   66.5601,    1.3808,  -19.4875,   40.812 ,
         43.0652,   35.4802,   77.0732,  -49.7352,   65.7192,   73.8539,
        -59.4116,   72.9501])
X
array([[-0.6465,  2.0803,  0.1412, -0.8419, -0.1595,  1.3321, -0.4262,
        -0.0351, -0.1938, -0.6093, -0.3433,  0.6126,  0.3777, -1.2062,
        -0.2277, -0.8896, -0.4674, -1.3566,  1.4989, -0.7468],
       [-0.3834, -0.3631, -1.2196,  0.6   ,  0.3315,  1.1056,  0.2662,
        -0.7239,  0.0259, -0.2172, -0.6841,  0.0991,  0.2794, -1.208 ,
        -0.7818, -1.7348, -1.3397, -0.5723, -0.5882,  0.2717],
       [-0.1637, -0.8118,  0.9551,  0.5711,  0.8719, -0.9619,  1.9846,
        -1.1806, -1.1261,  0.297 ,  1.2499,  0.7109, -0.1183,  0.6708,
         0.6895,  1.4705,  0.0634, -0.3079, -2.2512, -0.0216],
       [-0.9292, -0.4897, -2.1196, -1.142 ,  1.266 , -0.2988,  1.0016,
        -2.1969, -1.0739, -0.1149,  0.5122,  0.302 , -0.0974,  1.3461,
         0.1909,  1.1223,  0.6268,  2.2035, -0.5135,  2.0118],
       [ 0.1645, -0.5847,  0.2708, -3.5635,  0.1526,  0.5283,  0.7674,
         1.392 , -0.0819,  1.3211,  0.4644, -1.0279,  0.9849, -1.069 ,
        -0.4301,  0.0798, -0.5119, -0.3448,  0.8166, -0.4   ],
       [ 0.4134,  1.9511, -0.5013, -1.4894,  0.4191, -1.4104,  0.2617,
        -0.6981,  0.0368, -1.151 ,  2.0752,  0.5001, -0.2428,  0.45  ,
         0.7176,  1.3846,  0.5155,  0.4459, -0.2784, -0.2864],
       [-0.0628, -1.424 , -1.1023,  0.1445, -0.4836,  1.4795, -0.5921,
         1.6423, -0.5013,  0.4435,  2.0044,  0.6221,  0.0747, -1.4117,
        -0.202 , -1.3071, -0.8656, -1.311 ,  0.0424,  0.7255],
       [-0.6642,  1.4317, -0.0658, -0.7379, -0.9153,  0.8653,  0.7143,
         1.0912, -1.3773, -2.6022, -0.2955, -0.3985,  0.0918,  0.3851,
         0.502 , -0.4665,  1.6432, -0.2438, -0.4943,  1.4753],
       [ 1.5247, -1.3419, -0.4453, -0.6141,  2.0632, -1.0742, -1.4419,
        -1.4923,  0.3135,  0.7691,  0.5383, -0.9741,  0.8457, -0.0014,
         0.3895,  0.2118,  1.0977, -0.4036, -1.5496, -0.3672],
       [ 1.5524, -1.1109, -0.5624, -0.9106, -0.0506,  0.8533, -0.5452,
        -1.7836, -0.8365,  2.171 , -0.6158,  0.2523, -1.8707,  0.6142,
         0.7962,  0.0706, -0.2386, -0.4144, -0.0898, -0.4745],
       [-0.7206, -2.0213,  0.0157, -1.191 , -0.3127,  0.2891,  0.8596,
        -2.2427,  0.0021,  1.4327,  0.4714,  0.9533, -0.6365,  1.3212,
         0.8872,  1.15  , -1.5469,  0.4055, -0.3341,  0.9919],
       [ 0.2328,  0.5523, -0.2356, -1.2547,  0.6686, -2.1204, -0.186 ,
         1.4915, -1.1353,  2.3889,  0.3449, -0.6703, -0.2358, -2.1923,
        -0.4635, -0.9962, -0.1116,  0.0605,  0.0027, -1.439 ],
       [ 1.8092, -1.5857, -0.9765,  1.6171,  0.368 , -0.2947,  1.5897,
        -0.8878,  1.0547, -0.0427, -1.187 ,  0.7605,  1.2381, -0.5014,
         1.0201, -0.5773, -0.632 , -0.502 , -1.6914,  0.803 ],
       [ 0.1935,  0.5289, -0.7559, -0.1047, -0.3334, -1.0275,  1.0327,
        -0.8811,  0.0483,  1.8504,  1.5727,  0.3325, -1.7398, -0.2383,
        -0.4967,  0.3939,  1.9322,  0.062 , -1.1205, -0.95  ],
       [ 0.6914, -2.0308,  1.1554, -0.4219, -1.6257, -0.1138, -0.9225,
        -1.9216,  1.2995, -1.5084, -0.863 ,  0.2528,  1.3636,  0.2059,
         0.0381,  1.1124,  1.73  ,  0.4496, -0.1806,  0.7681],
       [-0.0862, -0.2131, -0.5343, -0.1066, -0.8403,  1.3862,  0.5885,
        -1.089 , -0.8571,  2.0178,  2.6078, -0.5807, -0.3466, -0.5166,
        -0.7863,  0.2918, -0.1904, -0.8012, -1.6868,  0.2538],
       [-1.0714, -0.4582,  0.4255,  0.5657, -0.1743,  2.0978, -0.8453,
        -0.9807, -0.0414,  0.5851,  0.2645, -0.3602,  0.4151,  1.2829,
        -0.0485, -0.4278,  0.2703,  0.821 , -1.338 ,  1.4986],
       [ 0.2126,  2.1145, -0.1471,  1.7549,  0.9465, -1.3906, -1.0954,
        -0.5224,  0.5338,  0.0591, -0.2671,  1.5731,  0.3903,  0.6137,
        -0.5277,  0.6306,  0.7467,  1.7232,  0.62  ,  2.0249],
       [ 0.7085,  1.312 , -0.6134,  0.8665, -1.4706,  0.2597, -0.1606,
        -0.7118,  0.2154, -0.7415, -0.608 , -0.3412,  1.0772,  0.4695,
        -0.1285,  0.0654,  0.4922, -0.6707, -1.8229, -0.4215],
       [ 1.2418, -1.9068, -0.6066,  0.1639,  0.986 , -0.1853, -0.0303,
         1.152 , -0.161 ,  0.0226,  0.8991,  0.9874, -0.802 , -0.7241,
         0.2466,  0.747 , -0.9682, -1.1908,  0.4313, -1.2039],
       [-0.6263,  0.2757,  0.9388,  1.3835, -0.5935,  0.4409, -1.4681,
         0.0114, -0.3643, -0.3373, -1.3341,  0.0036,  0.5513, -0.1016,
         0.6814, -1.4258, -1.3869, -2.0679, -1.6482,  1.0062],
       [-0.2397,  0.6481,  1.7758,  0.0166, -1.7724, -0.1862,  1.118 ,
        -0.8409,  0.6136,  0.5269, -0.2908, -0.2294,  0.1747, -0.3881,
        -0.2667, -0.7601,  0.4313, -0.7488, -0.7594, -0.4084],
       [-1.0989,  1.1887, -0.5288,  1.6782,  0.3827,  0.4309, -1.3949,
         0.6801, -1.2572,  0.6585,  0.7674, -1.5397,  1.1786,  1.2429,
        -1.1094,  1.2524, -0.7556,  0.4051, -0.3198,  0.6704],
       [ 0.0582,  0.1247,  0.1058, -0.4947, -0.1381,  1.3226,  0.3375,
         0.0445,  1.2923,  0.67  , -1.3132, -0.7997, -0.1669,  1.5938,
        -0.7805, -0.3689, -2.5977, -1.2921, -1.2897, -0.074 ],
       [-0.9515, -1.0973,  1.5675,  0.0103, -1.1347,  0.165 ,  0.0289,
        -0.6242, -1.3193,  0.2246,  0.7557, -0.9032,  2.1041, -0.6316,
        -0.1271, -0.4006, -0.8671, -0.5601, -0.0713, -1.1371],
       [ 0.4599,  0.5513,  1.6362, -1.2392, -0.3352,  1.0237,  1.7626,
        -0.5441,  1.3217, -1.2237,  2.5112, -1.7501, -0.0857,  0.8239,
        -0.6406, -1.05  , -0.635 , -2.1445,  1.4129,  0.2546],
       [ 0.1871, -2.2206,  1.2475,  1.2345, -1.5021,  1.1434, -1.0406,
         0.0709,  1.2826,  0.5946, -0.176 ,  0.0639, -1.4364, -0.3326,
        -0.4648,  0.0733, -1.5075,  0.7799, -0.6549,  0.2562],
       [ 0.084 ,  0.9564,  0.366 , -0.6843, -0.6239,  0.3233, -0.4753,
        -0.7024, -0.8606, -0.8089,  1.7968, -0.9079, -0.1103, -0.8212,
         1.328 , -1.2039, -2.1219, -0.8672, -1.345 , -1.0769],
       [-0.807 , -0.037 ,  0.7597,  0.9556,  0.1334, -0.0225,  1.9088,
        -0.423 ,  0.267 ,  0.7138,  0.996 ,  0.0679,  0.1559,  0.1314,
        -0.342 ,  0.1817,  0.4344,  1.383 , -0.1708,  0.2745],
       [-0.3675, -0.93  ,  1.2117,  1.0203, -1.1554, -0.0461,  0.827 ,
         1.793 , -1.0029, -0.7901,  0.0797,  0.992 , -0.5725,  1.3592,
         1.2639,  1.3791,  0.021 , -2.5727, -0.2494,  2.0499],
       ...,
       [ 1.6243, -1.2413, -0.4177,  0.2389, -0.2734, -0.6785, -1.0147,
        -0.2772, -1.711 , -1.1543,  0.2933,  1.487 ,  0.7526,  0.6561,
         0.4132,  0.1095,  0.1406, -0.6598,  1.2687,  1.2148],
       [ 0.6854, -0.7399, -1.0681,  0.2991,  0.0382, -0.9321, -0.8341,
         0.0215,  0.0612, -0.129 ,  0.8795, -0.1681,  0.8851,  1.2921,
         0.3478,  1.5717,  2.4181, -0.0638,  1.3938,  0.884 ],
       [ 2.2326, -1.7645,  1.9779, -1.6875, -0.8401,  0.1057,  1.1688,
         0.3301, -0.5216,  1.207 , -1.5042,  1.6341, -1.0896, -0.7015,
        -1.7587,  1.4814,  0.6081, -0.7485,  2.1342, -0.4016],
       [ 2.2171, -0.6177,  0.1949, -1.0798,  0.586 , -0.859 ,  2.5508,
        -0.8039,  0.1503, -0.1069, -0.6496, -0.2479,  0.1649,  0.765 ,
         0.8986, -0.3648,  0.6722, -0.2408,  0.7112,  0.1551],
       [ 1.1797,  2.0229,  0.2965, -0.4986,  0.6617,  0.8841, -0.8252,
        -0.3799, -1.1173, -1.3918,  0.9206, -0.076 , -0.8812, -1.7954,
         0.2918,  0.7677,  0.4183,  1.236 , -0.1036, -0.0952],
       [-1.4325,  1.5241,  0.4914,  0.4466, -1.5217, -0.5697, -0.1623,
        -0.0357, -1.3161,  1.733 ,  0.7872,  1.4468,  1.8372,  1.0749,
        -2.0308, -0.2996, -1.1323,  0.6271,  0.7217,  0.8836],
       [ 0.4932, -1.2439,  0.5748, -0.1409,  1.7359, -1.1483, -0.4902,
        -0.5052, -0.4267, -0.6533,  0.242 ,  0.7283,  0.2963, -1.2347,
         0.6998,  0.7025, -0.5894, -2.7557, -1.1078, -0.5546],
       [ 0.7622, -1.2367, -0.9891,  1.7989,  0.1187, -1.8608, -0.559 ,
         0.775 ,  0.0616, -1.6046,  0.2385, -1.5886, -0.1833, -0.7817,
         1.8364, -0.5933, -0.3687,  0.3881,  1.2738,  1.2086],
       [-0.7038, -1.2434,  0.5626,  0.3224,  0.3713,  0.7815,  1.957 ,
         0.4423,  0.6326, -1.956 ,  1.1085,  0.1665,  0.841 ,  2.1472,
         0.5566, -0.2651, -0.9084,  2.0134,  0.3486,  1.2223],
       [-0.8597, -1.2391,  0.9525, -0.7438, -0.9162,  0.1223,  0.6288,
         0.9881, -0.223 , -0.3202,  0.5368, -0.9382,  0.1865, -1.4094,
         0.226 , -0.0726,  1.423 ,  2.1237,  0.1397, -0.5506],
       [-0.0956,  0.2007, -0.3942,  0.812 ,  0.6777, -2.5834, -1.3294,
        -0.6009, -1.0962, -0.359 ,  0.0455, -0.5706,  0.0263, -0.9308,
        -0.0649,  0.6586, -0.0469, -1.0283,  0.3524, -0.3676],
       [-1.4168, -0.0501, -0.5261, -0.3774, -0.9222,  0.253 , -0.2044,
         0.0524,  0.8973,  0.0268,  1.6289,  0.4335, -2.1098, -0.488 ,
         0.8635, -1.7442,  0.6374, -1.3713, -0.6505,  0.0552],
       [-0.1955, -0.6314,  1.0877, -0.9238, -1.2701,  0.3528,  0.9894,
         0.4388, -0.2405,  0.3558, -0.2266,  0.5029,  1.3886, -1.8156,
        -0.4634, -0.9616, -0.9101,  0.5856, -0.7043,  1.2456],
       [ 0.1021, -0.9809,  0.3537,  1.6762, -0.7037, -0.2284, -0.278 ,
        -0.4083,  0.666 ,  0.681 , -0.9055,  1.054 , -0.0522,  0.3645,
         1.1951, -1.8104, -1.5148,  1.0655,  0.3521, -0.9033],
       [ 0.2733, -0.0657, -0.5118, -0.9471, -0.4646,  1.4039,  0.5143,
         0.0839, -1.1759,  0.5828,  1.7003, -0.4077,  0.5941, -0.7209,
        -0.2797,  0.6193,  1.001 ,  1.6007, -0.6754,  0.2706],
       [ 1.0588, -2.1785,  0.2624,  0.1752, -0.0675,  0.0093, -1.8534,
        -1.7246,  1.5724, -0.5295,  0.4144,  0.6353, -0.7018, -1.1284,
        -0.1179,  0.2766, -1.2591,  2.0028,  0.312 ,  1.073 ],
       [ 0.428 , -0.2036, -1.3212,  1.1555, -0.439 , -1.6271, -0.1925,
         0.622 , -1.5984, -0.5565, -0.5401,  0.0358,  0.0133, -1.4583,
         2.0069,  0.1263,  0.4161, -0.9847,  0.7966,  0.486 ],
       [-0.1173,  0.6381, -0.2812, -0.3578,  1.9299,  2.1058,  2.3049,
         1.0974,  0.5727,  0.9601, -0.211 ,  0.6298,  0.7401, -1.7709,
        -1.0439,  0.8751, -1.5677,  0.3412, -0.1276, -0.8195],
       [ 0.2991, -0.7227, -0.4858,  0.2055,  0.8665,  0.5538,  0.3632,
         0.3877,  0.9835,  0.313 ,  1.5294, -0.3187,  1.8937,  0.3538,
         1.0765,  0.0236, -0.2756,  0.0235,  0.1774, -0.6602],
       [ 0.2811,  0.94  ,  0.2862,  0.2612, -0.3128,  1.8086, -0.1915,
         0.5764,  0.5697,  0.8199, -2.1527, -1.3712, -0.6986,  0.3284,
        -0.4288,  0.2109, -0.4925, -0.3614,  0.4904, -2.0268],
       [-0.1443, -1.0635, -1.0286, -0.5993, -0.0703,  1.1204, -0.0037,
        -0.4246, -0.3002, -0.7602,  1.4965,  2.5779,  0.7647,  0.1203,
         0.1931,  0.7629,  1.2026, -0.8532,  0.1837,  0.5151],
       [ 0.4857,  0.4299,  0.1458,  0.7444, -0.0979,  1.0103,  0.6701,
        -0.2955,  0.0965,  1.4683, -1.7762,  0.9793, -0.2762, -0.0483,
         0.0385, -0.5835,  0.6415,  0.0514,  0.3048,  1.6422],
       [ 0.4843,  0.7492,  1.2246,  0.5665,  0.2853, -1.8185, -0.7811,
        -1.2811, -0.1822,  0.5036,  0.2912, -0.4508, -0.468 ,  0.0471,
         1.3635,  0.8755,  0.3948,  0.6807, -0.2039, -1.7107],
       [-1.0209, -0.6634, -0.9163,  1.3832,  0.0976,  1.3228,  0.1802,
         0.2148,  1.6929,  0.4813, -0.2126,  0.7982, -1.1308, -0.6505,
         0.0008,  1.4194, -0.0373,  0.3867,  0.2306, -0.5345],
       [-0.7579,  1.397 ,  1.337 , -1.8439,  2.0066,  1.6138, -1.5925,
        -0.2433,  0.1118,  0.11  , -0.0658,  0.3186,  0.2924, -0.2974,
         1.016 , -0.231 ,  1.639 ,  0.4316, -0.8798, -0.3389],
       [ 0.7059,  1.752 , -0.0646, -0.0381, -0.2057,  0.6298, -0.0253,
        -1.205 , -0.1709, -0.6655, -2.0803,  0.4152, -0.1783,  0.8111,
        -2.6128, -3.8809,  2.1338,  0.7489,  0.485 ,  0.9745],
       [-1.1806, -0.3993, -0.2226, -2.3027,  1.3672, -0.1536, -1.9393,
         0.2231, -0.0869,  0.9348, -0.4875, -0.8112,  0.6287,  0.2276,
        -0.2909,  0.0516,  1.3023,  0.5602,  0.515 ,  0.4992],
       [-0.6427, -0.3821,  0.1691,  0.1282,  0.4937,  0.4471, -0.2522,
         0.8716, -1.5433,  1.6954, -0.9341, -0.5208, -1.0103,  0.9605,
        -1.877 ,  0.8052, -1.7248, -1.3744, -0.4322,  1.1537],
       [ 0.8806, -0.7055, -0.0158,  0.172 , -0.4213,  0.8811,  0.7577,
        -0.3878, -0.6382, -1.365 ,  0.1341,  1.7316, -0.6366, -0.6532,
        -1.4726,  0.8897, -1.32  ,  0.7008, -1.2858,  1.1342],
       [-0.5563, -1.0968, -0.9132, -1.2012,  1.8192, -1.1012,  0.805 ,
         2.0475,  2.122 ,  0.0836, -0.1508, -0.6059,  1.1388, -1.2726,
        -1.0138, -0.8948,  0.8435,  1.1215,  0.2176,  1.0591]],
      shape=(200, 20))
coef.reshape((-1,1))
array([[ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 9.6106],
       [43.4239],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [34.453 ],
       [ 9.2929],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ],
       [ 0.    ]])

A jax implementation of GD

import warnings
import jax
import jax.numpy as jnp

def grad_desc_jax(x, f, step, tau=0.5, max_step=100, max_back=10, tol = 1e-8):
  grad_f = jax.grad(f)   # gradient function via automatic differentiation
  f_x = f(x)
  converged = False

  for i in range(max_step):
    grad_f_x = grad_f(x)
    
    # Backtracking: shrink the step until the move decreases f
    for j in range(max_back):
      new_x = x - grad_f_x * step
      new_f_x = f(new_x)
      if (new_f_x < f_x): 
        break
      step *= tau

    # Length of the move, used as the convergence criterion
    cur_tol = jnp.sqrt(jnp.sum((x - new_x)**2))

    x = new_x 
    f_x = new_f_x

    if cur_tol < tol: 
      converged = True
      break
  
  if not converged:
    warnings.warn("Failed to converge!", RuntimeWarning)

  # n_iter is the 0-based index of the last iteration performed
  return {
    "x": x,
    "n_iter": i,
    "converged": converged,
    "final_tol": float(cur_tol),
    "final_step": step
  }

Linear regression

from sklearn.linear_model import LinearRegression

lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0616, -0.0121, -0.0096,  0.096 ,  9.6955, 43.406 ,  0.0253,
        0.0284,  0.0962,  0.1069, 34.4884,  9.3445, -0.0165, -0.0147,
       -0.0396,  0.0969, -0.1057, -0.0943,  0.11  , -0.0096, -0.0875])
def jax_linear_regression(X, y, beta):
  Xb = jnp.c_[jnp.ones(X.shape[0]), X]
  return jnp.sum((y - Xb @ beta)**2)

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_linear_regression(X,y,beta), 
  step = 1, tau = 0.5
)
{'x': Array([ 3.0617, -0.0121, -0.0097,  0.0961,  9.6958, 43.4061,  0.0255,
        0.0285,  0.0963,  0.1066, 34.488 ,  9.3445, -0.0166, -0.0146,
       -0.0396,  0.0966, -0.1056, -0.0941,  0.1101, -0.0099, -0.0879],      dtype=float32), 'n_iter': 18, 'converged': True, 'final_tol': 1.31708899342442e-09, 'final_step': 7.450580596923828e-09}

Ridge regression

from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

r = GridSearchCV(Ridge(), param_grid = {"alpha": np.logspace(-3,0)}).fit(X,y)
r.best_estimator_
Ridge(alpha=np.float64(0.05963623316594643))
np.r_[r.best_estimator_.intercept_, r.best_estimator_.coef_]
array([ 3.0631, -0.0122, -0.0102,  0.0957,  9.6928, 43.3936,  0.0265,
        0.0282,  0.0959,  0.1074, 34.478 ,  9.3399, -0.0179, -0.0134,
       -0.0403,  0.0956, -0.106 , -0.0933,  0.1101, -0.0108, -0.0877])
def jax_ridge(X, y, beta, alpha):
  Xb = jnp.c_[jnp.ones(X.shape[0]), X]
  ls_loss = jnp.sum((y - Xb @ beta)**2)
  coef_loss = jnp.sum(beta[1:]**2)
  return ls_loss + alpha * coef_loss

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_ridge(X, y, beta, r.best_estimator_.alpha), 
  step = 1, tau = 0.5, tol=1e-8
)
{'x': Array([ 3.0633, -0.0122, -0.0104,  0.0957,  9.6931, 43.3938,  0.0267,
        0.0282,  0.0961,  0.1071, 34.4775,  9.3398, -0.0179, -0.0134,
       -0.0404,  0.0952, -0.1059, -0.093 ,  0.1102, -0.011 , -0.0882],      dtype=float32), 'n_iter': 18, 'converged': True, 'final_tol': 0.0, 'final_step': 1.862645149230957e-09}
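Because the ridge objective is quadratic in \(\beta\), it also has a closed-form minimizer, which gives an independent check on the GD result. Setting the gradient of \(\lVert y - X_b \beta \rVert^2 + \alpha \sum_{j \ge 1} \beta_j^2\) to zero gives \((X_b^T X_b + \alpha D)\,\beta = X_b^T y\) with \(D = \text{diag}(0, 1, \ldots, 1)\). A minimal numpy sketch, assuming the same parameterization as jax_ridge (intercept in \(\beta_0\), left unpenalized); ridge_closed_form is a hypothetical helper and the synthetic data is only for illustration:

```python
import numpy as np

def ridge_closed_form(X, y, alpha):
    """Hypothetical helper: ridge regression with an unpenalized intercept."""
    n, p = X.shape
    Xb = np.c_[np.ones(n), X]      # prepend a column of ones for the intercept
    D = np.eye(p + 1)
    D[0, 0] = 0.0                  # leave the intercept unpenalized
    # Normal equations: (Xb^T Xb + alpha D) beta = Xb^T y
    return np.linalg.solve(Xb.T @ Xb + alpha * D, Xb.T @ y)

# Illustrative synthetic data (not the make_regression data above)
rng = np.random.default_rng(1234)
X_demo = rng.standard_normal((200, 20))
y_demo = 3 + 10 * X_demo[:, 3] + rng.standard_normal(200)
beta_hat = ridge_closed_form(X_demo, y_demo, alpha=0.06)
print(beta_hat.round(4))
```

Applied to the X, y and r.best_estimator_.alpha above, we would expect this to agree with both the sklearn and GD estimates up to the tolerance used.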

Lasso

from sklearn.linear_model import Lasso

ls = GridSearchCV(Lasso(), param_grid = {"alpha": np.logspace(-3,0)}).fit(X,y)
ls.best_estimator_
Lasso(alpha=np.float64(0.10481131341546852))
np.r_[ls.best_estimator_.intercept_, ls.best_estimator_.coef_]
array([ 3.0555, -0.    , -0.    ,  0.    ,  9.5849, 43.312 ,  0.    ,
        0.    ,  0.    ,  0.    , 34.3712,  9.2233, -0.    ,  0.    ,
       -0.    ,  0.    , -0.    , -0.    ,  0.    , -0.    , -0.    ])
def jax_lasso(X, y, beta, alpha):
  n = X.shape[0]
  Xb = jnp.c_[jnp.ones(n), X]
  ls_loss = (1/(2*n))*jnp.sum((y - Xb @ beta)**2)
  coef_loss = jnp.sum(jnp.abs(beta[1:]))
  return ls_loss + alpha * coef_loss

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_lasso(X, y, beta, ls.best_estimator_.alpha), 
  step = 1, tau = 0.5, tol=1e-10
)
{'x': Array([ 3.0623,  0.    , -0.    ,  0.    ,  9.5842, 43.303 ,  0.    ,
       -0.    ,  0.0031,  0.0098, 34.3752,  9.2158,  0.    ,  0.    ,
       -0.    ,  0.    , -0.0144, -0.0124,  0.0191, -0.    , -0.    ],      dtype=float32), 'n_iter': 44, 'converged': True, 'final_tol': 7.919517472365634e-12, 'final_step': 1.4551915228366852e-11}

Limitations of gradient descent

  • Converges to a local minimum - sensitive to starting location

    • Convergence to the global minimum is only guaranteed for convex functions (with a suitable step size)
  • Requires gradient computation at every step - expensive when \(n\) is large or the gradient has no closed form

  • Sensitive to the choice of learning rate - too large diverges, too small converges slowly

  • Uses a scalar step size applied uniformly in all directions - performs poorly on ill-conditioned problems where the optimal step varies by direction

Newton’s Method

Newton’s Method in 1d

Let’s assume we have a 1d function \(f(x)\) that we are trying to optimize. Our current guess is \(x\), and we want to know how to generate our next step \(\Delta x\).

We start by constructing the 2nd order Taylor approximation of our function around \(x\), evaluated at \(x+\Delta x\),

\[ f(x + \Delta x) \approx \widehat{f}(x + \Delta x) = f(x) + \Delta x f'(x) + \frac{1}{2} (\Delta x)^2 f''(x) \]

Finding the Newton step

Our optimal step then becomes the value of \(\Delta x\) that minimizes the quadratic given by our Taylor approximation (this quadratic only has a minimum when \(f''(x) > 0\)).

\[ \begin{aligned} \frac{\partial}{\partial \Delta x} \widehat{f}(x+\Delta x) &= 0 \\ \frac{\partial}{\partial \Delta x} \left(f(x) + \Delta x f'(x) + \frac{1}{2} (\Delta x)^2 f''(x) \right) &= 0 \\ f'(x) + \Delta x f''(x) &= 0 \\ \Delta x &= -\frac{f'(x)}{f''(x)} \end{aligned} \]

this suggests an iterative update rule of

\[ x_{k+1} = x_{k} -\frac{f'(x_k)}{f''(x_k)} \]
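A minimal sketch of this scalar update (newton_1d is a hypothetical helper, not the lecture's implementation), applied to the quartic \(f(x) = x^4 + x^3 - x^2 - x\) from the gradient descent examples. Printing the iterates shows the characteristic quadratic convergence near the solution, where the number of correct digits roughly doubles each step:

```python
def newton_1d(x, fprime, fsecond, max_iter=50, tol=1e-12):
    """Iterate x <- x - f'(x) / f''(x), returning all iterates."""
    xs = [x]
    for _ in range(max_iter):
        x_new = x - fprime(x) / fsecond(x)
        xs.append(x_new)
        if abs(x_new - x) < tol:
            break
        x = x_new
    return xs

# Derivatives of the quartic used earlier in the lecture
fprime = lambda x: 4*x**3 + 3*x**2 - 2*x - 1
fsecond = lambda x: 12*x**2 + 6*x - 2

xs = newton_1d(1.5, fprime, fsecond)
for k, xk in enumerate(xs):
    print(k, xk)
```

Starting from \(x = 1.5\), the iterates approach the minimum at \(x = (1 + \sqrt{17})/8 \approx 0.640\).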

Generalizing to \(n\)d

Based on the same argument, for a function \(f : \mathbb{R}^n \to \mathbb{R}\) the second order Taylor approximation is

\[ f(x + \Delta x) \approx \widehat{f}(x + \Delta x) = f(x) + \Delta x^T \nabla f(x) + \frac{1}{2} \Delta x^T \, \nabla^2 f(x) \,\Delta x \]

then setting its derivative with respect to \(\Delta x\) to zero gives

\[ \begin{aligned} \frac{\partial}{\partial \Delta x} \widehat{f}(x + \Delta x) &= 0 \\ \nabla f(x) + \nabla^2 f(x) \, \Delta x &= 0 \\ \Delta x &= -\left(\nabla^2 f(x)\right)^{-1} \nabla f(x) \end{aligned} \]

where

  • \(\nabla f(x)\) is the \(n \times 1\) gradient vector

  • and \(\nabla^2 f(x)\) is the \(n \times n\) Hessian matrix.

Based on these results, our \(n\)d update rule is

\[ x_{k+1} = x_{k} - (\nabla^2 f(x_k))^{-1} \, \nabla f(x_k) \]
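As a quick check of this update rule, consider a quadratic \(f(x) = \tfrac{1}{2} x^T \boldsymbol{H} x\) with a constant positive definite Hessian \(\boldsymbol{H}\); the mk_quad() functions are of this form with \(\boldsymbol{H} = \text{diag}(0.66,\, 0.66\epsilon^2)\). Here \(\nabla f(x) = \boldsymbol{H} x\), so a single Newton step from any starting point \(x_0\) lands exactly on the minimum, regardless of the conditioning:

\[ x_1 = x_0 - \boldsymbol{H}^{-1} \nabla f(x_0) = x_0 - \boldsymbol{H}^{-1} \boldsymbol{H} x_0 = 0 \]

This contrasts with gradient descent, whose progress on the same function slows as \(\kappa\) grows.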

Implementation

import numpy as np

def newtons_method(x, f, grad, hess, max_iter=100, tol=1e-8):
  x = np.array(x)
  s = x.shape
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_iter):
    # Newton step: solve H dx = grad rather than forming the inverse Hessian
    x = x - np.linalg.solve(hess(x), grad(x))
    x = x.reshape(s)   # keep the original shape (e.g. scalar inputs)

    if np.sqrt(np.sum((x - res["x"][-1])**2)) < tol:
      break

    res["x"].append(x)
    res["f"].append(f(x))
  
  return res

A basic example

\[ \begin{aligned} f(x) &= x^2 \\ \nabla f(x) &= 2x \\ \nabla^2 f(x) &= 2 \end{aligned} \]

f = lambda x: np.array(x**2)
grad = lambda x: np.array([2*x])
hess = lambda x: np.array([[2]])
opt = newtons_method(-2., f, grad, hess)
plot_1d_traj( (-2, 2), f, opt )

opt = newtons_method(1., f, grad, hess)
plot_1d_traj( (-2, 2), f, opt )

1d Quartic

\[ \begin{aligned} f(x) &= x^4 + x^3 -x^2 - x \\ \nabla f(x) &= 4x^3 + 3x^2 - 2x - 1 \\ \nabla^2 f(x) &= 12x^2 + 6x - 2 \end{aligned} \]

f = lambda x: x**4 + x**3 - x**2 - x 
grad = lambda x: np.array([4*x**3 + 3*x**2 - 2*x - 1])
hess = lambda x: np.array([[12*x**2 + 6*x - 2]])
opt = newtons_method(-1.5, f, grad, hess)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = newtons_method(1.5, f, grad, hess)
plot_1d_traj( (-1.5, 1.5), f, opt)

2d quadratic cost function

f, grad, hess = mk_quad(0.7)
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

f, grad, hess = mk_quad(0.05)
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

Rosenbrock function

f, grad, hess = mk_rosenbrock()
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = newtons_method((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Damped / backtracking implementation

import numpy as np

def newtons_method_damped(
  x, f, grad, hess, max_iter=100, max_back=10, tol=1e-8,
  alpha=0.5, beta=0.75
):
  # alpha: sufficient-decrease parameter, beta: backtracking shrink factor
  x = np.asarray(x, dtype=float)
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_iter):
    grad_f = grad(x)
    f_x = f(x)
    # Newton direction: solve H step = -grad
    step = - np.linalg.solve(hess(x), grad_f) 
    t = 1
    for j in range(max_back):
      # Armijo (sufficient decrease) condition
      if f(x + t*step) < f_x + alpha * t * grad_f @ step:
        break
      t = t * beta
    
    x = x + t * step

    if np.linalg.norm(x - res["x"][-1]) < tol:
      break

    res["x"].append(x)
    res["f"].append(f(x))
  
  return res

2d quadratic cost function

f, grad, hess = mk_quad(0.7)
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

f, grad, hess = mk_quad(0.05)
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

Rosenbrock function

f, grad, hess = mk_rosenbrock()
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = newtons_method_damped((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Limitations of Newton’s Method

  • Requires both the gradient and Hessian - second derivatives may be unavailable or expensive to compute

  • Storing the \(n \times n\) Hessian costs \(O(n^2)\) memory and inverting it costs \(O(n^3)\) per step - impractical for large \(n\)

  • Can diverge or behave poorly if the Hessian is indefinite (not positive definite) away from the minimum

  • Like GD, finds local minima and is sensitive to starting location

  • Damping (backtracking) improves robustness but adds extra function evaluations per step and more tuning parameters (\(\alpha\) and \(\beta\))
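To see why damping matters, here is a small 1-d illustration. It uses a test function chosen only for this purpose, \(f(x) = \sqrt{1+x^2}\) (not one of the lecture's examples). Its curvature decays quickly away from the minimum at \(x=0\), and for this \(f\) the pure Newton update simplifies to \(x_{k+1} = -x_k^3\), so it diverges from any \(|x_0| > 1\). `newton_pure` and `newton_damped` are hypothetical 1-d re-sketches of the idea, not the lecture's functions. The damped version assumes \(\alpha = 0.25\) and a non-strict Armijo inequality, so that full Newton steps are accepted once the iterate is close to the minimum.

```python
# Hypothetical 1-d test function (not from the lecture): minimum at x = 0,
# with curvature f''(x) that decays rapidly as |x| grows.
import math

def f(x):
    return math.sqrt(1 + x**2)

def df(x):   # f'(x)
    return x / math.sqrt(1 + x**2)

def d2f(x):  # f''(x)
    return (1 + x**2) ** -1.5

def newton_pure(x, n_iter=4):
    """Undamped Newton: always take the full step -f'(x)/f''(x)."""
    xs = [x]
    for _ in range(n_iter):
        x = x - df(x) / d2f(x)      # algebraically equal to -x**3 for this f
        xs.append(x)
    return xs

def newton_damped(x, max_iter=50, max_back=10, tol=1e-10, alpha=0.25, beta=0.75):
    """Newton with backtracking: shrink t until the Armijo condition holds."""
    xs = [x]
    for _ in range(max_iter):
        g = df(x)
        step = -g / d2f(x)          # Newton direction
        t = 1.0
        for _ in range(max_back):
            # Armijo (sufficient decrease) condition
            if f(x + t * step) <= f(x) + alpha * t * g * step:
                break
            t *= beta
        x = x + t * step
        xs.append(x)
        if abs(t * step) < tol:     # stop once the steps become negligible
            break
    return xs

pure = newton_pure(1.5)
damped = newton_damped(1.5)
print(pure)        # |x| grows at every iteration
print(damped[-1])  # essentially 0
```

From \(x_0 = 1.5\) the first damped step backtracks to \(t \approx 0.42\), after which full Newton steps are accepted and convergence is rapid.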

Conjugate gradients

Conjugate gradients

Conjugate gradients (CG) is a general approach for solving a system of linear equations of the form \(\boldsymbol{A} x=\boldsymbol{b}\), where \(\boldsymbol{A}\) is an \(n \times n\) symmetric positive definite matrix, \(\boldsymbol{b}\) is an \(n \times 1\) vector, and \(x\) is the unknown vector of interest.


This type of problem can also be expressed as a quadratic minimization problem of the form,

\[ \underset{x}{\text{arg min}} \; f(x) = \underset{x}{\text{arg min}} \; \frac{1}{2} x^T \, \boldsymbol{A} \, x - x^T \boldsymbol{b} \]

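since \(\boldsymbol{A}\) is positive definite, \(f\) is strictly convex, so its unique minimizer is the point where the gradient vanishes,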

\[ \nabla f(x) = \boldsymbol{A}x - \boldsymbol{b} = 0 \]
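A quick numerical check of this equivalence, using a small symmetric positive definite \(\boldsymbol{A}\) and a \(\boldsymbol{b}\) made up for illustration: the solution of the linear system zeroes the gradient, and random perturbations of it only increase \(f\) (since \(f(x^\star + d) - f(x^\star) = \tfrac{1}{2} d^T \boldsymbol{A} d > 0\)).

```python
import numpy as np

# Illustrative SPD matrix and right-hand side (made up for this example)
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
b = np.array([1.0, 2.0])

def f(x):
    """Quadratic objective 1/2 x^T A x - x^T b."""
    return 0.5 * x @ A @ x - x @ b

def grad(x):
    return A @ x - b

x_star = np.linalg.solve(A, b)       # solution of A x = b

# Evaluate f at 100 random points near x_star
rng = np.random.default_rng(0)
perturbed = [f(x_star + 0.1 * rng.standard_normal(2)) for _ in range(100)]

print(x_star)                        # [1/11, 7/11] ≈ [0.0909, 0.6364]
print(grad(x_star))                  # ≈ [0, 0]
print(min(perturbed) > f(x_star))    # True: x_star is the minimizer
```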

A? Conjugate?

Taking things one step further, we can also see that the matrix \(\boldsymbol{A}\) is the Hessian of \(f(x)\)

\[ \nabla^2 f(x) = \boldsymbol{A} \]


Additionally, recall that two non-zero vectors \(\boldsymbol{u}\) and \(\boldsymbol{v}\) are conjugate with respect to \(\boldsymbol{A}\) if

\[\boldsymbol{u}^{T} \boldsymbol{A} \boldsymbol{v} = 0\]

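Our goal, then, is to find \(n\) mutually conjugate vectors \(P = \{\boldsymbol{p}_1, \ldots, \boldsymbol{p}_n\}\) (i.e. \(\boldsymbol{p}_i^T \boldsymbol{A} \boldsymbol{p}_j = 0\) for all \(i \neq j\)) to use as “optimal” step directions for traversing our objective function.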
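A small check of the definition, with an illustrative SPD \(\boldsymbol{A}\). The eigenvectors of a symmetric \(\boldsymbol{A}\) are one set of mutually conjugate directions (they are both orthogonal and \(\boldsymbol{A}\)-conjugate), but conjugacy does not require orthogonality. The second half builds a vector \(\boldsymbol{v}\) that is \(\boldsymbol{A}\)-conjugate to \(\boldsymbol{u}\) by removing one \(\boldsymbol{A}\)-projection, and shows it is not orthogonal to \(\boldsymbol{u}\).

```python
import numpy as np

A = np.array([[4.0, 1.0],
              [1.0, 3.0]])   # illustrative SPD matrix

# Eigenvectors of a symmetric A are both orthogonal and A-conjugate
_, Q = np.linalg.eigh(A)
q1, q2 = Q[:, 0], Q[:, 1]
print(q1 @ q2, q1 @ A @ q2)   # both ≈ 0

# Conjugate is not the same as orthogonal: remove the A-projection
# of w onto u (a single Gram-Schmidt step in the A inner product)
u = np.array([1.0, 0.0])
w = np.array([0.0, 1.0])
v = w - (w @ A @ u) / (u @ A @ u) * u   # v = [-0.25, 1]
print(u @ A @ v)   # 0.0   -> u and v are A-conjugate
print(u @ v)       # -0.25 -> but they are not orthogonal
```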

The big picture

The core problem with gradient descent on ill-conditioned functions is wasteful repetition - each step ignores the information gathered by all previous steps, so the algorithm tends to zigzag.

Conjugate gradients fixes this by choosing search directions that are non-interfering: once you have minimized along a direction \(\boldsymbol{p}_k\), the next step \(\boldsymbol{p}_{k+1}\) is chosen so that it does not undo that progress.

This is achieved by making directions conjugate with respect to the curvature (\(\boldsymbol{A}\)) rather than just orthogonal - think of it as orthogonality that has been warped to match the shape of the objective.

The payoff: for an \(n\)-dimensional quadratic, CG finds the exact minimum in at most \(n\) steps (in exact arithmetic), regardless of conditioning - something GD could take thousands of steps to achieve.
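To put a number on "thousands of steps", the sketch below runs plain gradient descent with a fixed step of \(1/L\) on an illustrative 2-d quadratic with condition number 1000. Here \(L\) is the largest eigenvalue of \(\boldsymbol{A}\), a standard safe step size for quadratics and an assumption of this example. By the claim above, CG would need at most \(n = 2\) steps on the same problem.

```python
import numpy as np

# Illustrative ill-conditioned quadratic f(x) = 1/2 x^T A x - x^T b
A = np.diag([1.0, 1000.0])          # condition number 1000
b = np.array([1.0, 1.0])
x_star = np.linalg.solve(A, b)

def gd_steps(x, step, tol=1e-6, max_iter=100_000):
    """Count fixed-step gradient descent iterations until ||x - x*|| < tol."""
    for k in range(max_iter):
        if np.linalg.norm(x - x_star) < tol:
            return k
        x = x - step * (A @ x - b)  # gradient of f is A x - b
    return max_iter

L = 1000.0                          # largest eigenvalue of A
n_gd = gd_steps(np.zeros(2), step=1 / L)
print(n_gd)   # on the order of 14,000 iterations
```

The error along the low-curvature direction only shrinks by a factor of \(1 - 1/1000\) per step, which is what makes GD so slow here.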

Algorithm Sketch

For the \(k\)th step:

  • Define the residual \(\boldsymbol{r}_k = \boldsymbol{b} - \boldsymbol{A} x_k\)

  • Define conjugate vector \(\boldsymbol{p}_k\) using current residual and all previous search directions \[ \boldsymbol{p}_k = \boldsymbol{r}_k - \sum_{i<k} \frac{\boldsymbol{r}_k^T \boldsymbol{A} \boldsymbol{p}_i}{\boldsymbol{p}_i^T \boldsymbol{A} \boldsymbol{p}_i} \boldsymbol{p}_i \]

  • Define step size \(\alpha_k\) using \[ \alpha_k = \frac{\boldsymbol{p}_k^T\boldsymbol{r}_k}{\boldsymbol{p}_k^T \boldsymbol{A} \boldsymbol{p}_k} \]

  • Update \(x_{k+1} = x_k + \alpha_k \boldsymbol{p}_k\)
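The sketch above translates almost line for line into NumPy. `cg_sketch` is a hypothetical helper that keeps every previous direction and conjugates against all of them, exactly as in the formula for \(\boldsymbol{p}_k\); it is meant for checking the algebra on a small random SPD system (made up here), not for efficiency. After \(n\) steps, the directions should be mutually conjugate and \(x_n\) should solve \(\boldsymbol{A}x = \boldsymbol{b}\).

```python
import numpy as np

def cg_sketch(A, b, x0):
    """Conjugate directions via explicit Gram-Schmidt against all previous p_i."""
    n = len(b)
    x = x0.astype(float)
    P = []                                   # search directions so far
    for k in range(n):
        r = b - A @ x                        # residual r_k
        p = r.copy()
        for p_i in P:                        # remove the A-projection onto each p_i
            p -= (r @ A @ p_i) / (p_i @ A @ p_i) * p_i
        alpha = (p @ r) / (p @ A @ p)        # step size alpha_k
        x = x + alpha * p                    # x_{k+1} = x_k + alpha_k p_k
        P.append(p)
    return x, np.array(P)

# Small random SPD test system (illustrative only)
rng = np.random.default_rng(1)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)                  # symmetric positive definite
b = rng.standard_normal(n)

x_n, P = cg_sketch(A, b, np.zeros(n))
G = P @ A @ P.T                              # G[i, j] = p_i^T A p_j
print(np.allclose(x_n, np.linalg.solve(A, b)))                  # True
print(np.allclose(G - np.diag(np.diag(G)), 0, atol=1e-8))       # True: off-diagonals ≈ 0
```

Storing all previous directions is what makes this version expensive; the next slide's recurrence avoids that.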

Algorithm in practice

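Given \(x_0\), we set the following initial values (note that from here on \(r_k\) denotes the gradient \(\nabla f(x_k) = \boldsymbol{A}x_k - \boldsymbol{b}\), i.e. the negative of the residual defined above), \[ \begin{aligned} r_0 &= \nabla f(x_0) \\ p_0 &= -r_0 \\ k &= 0 \end{aligned} \]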

while \(\|r_k\|_2 > \text{tol}\), \[ \begin{aligned} \alpha_k &= -\frac{r_k^T \, p_k}{p_k^T \, \nabla^2 f(x_k) \, p_k} \\ x_{k+1} &= x_k + \alpha_k \, p_k \\ r_{k+1} &= \nabla f(x_{k+1}) \\ \beta_{k+1} &= \frac{ r^T_{k+1} \, \nabla^2 f(x_k) \, p_{k} }{p_k^T \, \nabla^2 f(x_k) \, p_k} \\ p_{k+1} &= -r_{k+1} + \beta_{k+1} \, p_k \\ k &= k+1 \end{aligned} \]

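import numpy as np

def conjugate_gradient(x, f, grad, hess, max_iter=100, tol=1e-8):
    res = {"x": [x], "f": [f(x)]}
    r = grad(x)   # r_k is the gradient here (A x - b for a quadratic)
    p = -r        # first search direction: steepest descent
    
    for i in range(max_iter):
      H = hess(x)
      a = - r.T @ p / (p.T @ H @ p)       # exact minimizer along p (for a quadratic)
      x = x + a * p
      r = grad(x)
      b = (r.T @ H @ p) / (p.T @ H @ p)   # makes the new p conjugate to the old p w.r.t. H
      p = -r + b * p
      
      # Record the new point before testing convergence so that the
      # final (converged) iterate is included in the trajectory
      res["x"].append(x) 
      res["f"].append(f(x))
      
      if np.sqrt(np.sum(r**2)) < tol:
        break
  
    return res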
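As a concrete check, the same recurrences can be run directly on a 2 × 2 linear system, with the gradient \(\boldsymbol{A}x - \boldsymbol{b}\) and Hessian \(\boldsymbol{A}\) written out by hand. The matrix, right-hand side, and starting point are illustrative, and `cg_linear` is a hypothetical helper rather than part of the lecture code. Since \(n = 2\), the iteration should reach the exact solution \(x = (1/11, 7/11)\) after two steps, up to round-off.

```python
import numpy as np

def cg_linear(A, b, x, tol=1e-10, max_iter=50):
    """CG with r_k = gradient = A x_k - b, following the recurrences above."""
    xs = [x]
    r = A @ x - b
    p = -r
    for _ in range(max_iter):
        if np.linalg.norm(r) < tol:
            break
        Ap = A @ p                      # the only Hessian use: a matrix-vector product
        alpha = -(r @ p) / (p @ Ap)
        x = x + alpha * p
        r = A @ x - b
        beta = (r @ Ap) / (p @ Ap)
        p = -r + beta * p
        xs.append(x)
    return xs

A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
b = np.array([1.0, 2.0])
xs = cg_linear(A, b, np.array([2.0, 1.0]))

for k, x in enumerate(xs):
    print(k, x)
# x_1 ≈ (0.2356, 0.3384), i.e. (78/331, 112/331)
# x_2 ≈ (0.0909, 0.6364), i.e. (1/11, 7/11)
```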

Trajectory

f, grad, hess = mk_quad(0.7)
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, title="$\\epsilon=0.7$", traj=opt)

f, grad, hess = mk_quad(0.05)
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, title="$\\epsilon=0.05$", traj=opt)

Rosenbrock’s function

f, grad, hess = mk_rosenbrock()
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = conjugate_gradient((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Limitations of Conjugate Gradients

  • Requires the Hessian to be symmetric positive definite - not applicable to general non-convex problems without modification

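  • Like Newton’s method, the implementation above forms the full Hessian: storing it costs \(O(n^2)\) memory and each product \(\boldsymbol{A}\boldsymbol{p}\) costs \(O(n^2)\) per step (though, unlike Newton’s method, there is no \(O(n^3)\) linear solve)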

  • Guaranteed to find the exact solution in at most \(n\) steps only for quadratic objectives - on non-quadratic functions it becomes an iterative method with no finite-step guarantee, typically combined with restarts and a line search

  • Sensitive to round-off error in practice; periodic restarts (resetting \(\boldsymbol{p}_k = -\boldsymbol{r}_k\)) are often used to maintain numerical stability

  • Performance degrades when the Hessian changes significantly between steps (highly non-quadratic regions)

Summary

Method Comparison

| Method | Convergence | Cost per step | Needs Hessian? | Notes |
|---|---|---|---|---|
| Gradient Descent | Linear | \(O(n)\) | No | Simple but slow on ill-conditioned problems |
| Newton’s Method | Quadratic | \(O(n^3)\) | Yes | Fast near minimum; expensive for large \(n\) |
| Conjugate Gradient | Superlinear | \(O(n^2)\) | Yes | Between GD and Newton; exact in \(n\) steps for quadratics |

Convergence rates matter in practice:

  • Linear convergence means the error reduces by a constant factor each step - fine for well-conditioned problems, slow otherwise

  • Quadratic convergence means the number of correct digits roughly doubles each step - Newton’s method achieves this near the minimum

  • Superlinear convergence means the ratio of successive errors goes to zero - faster than any linear rate, but not necessarily as fast as quadratic
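The difference between linear and quadratic rates is easy to see numerically. The sketch below uses an illustrative 1-d function, \(f(x) = x - \log x\) (minimum at \(x = 1\)). It is chosen because its Newton update simplifies to \(x_{k+1} = 2x_k - x_k^2\), so the error obeys \(e_{k+1} = -e_k^2\) exactly. Gradient descent with a fixed step of 0.5 (an arbitrary choice) instead shrinks the error by a roughly constant factor each step.

```python
# Illustrative function: f(x) = x - log(x), minimized at x = 1
df  = lambda x: 1 - 1 / x        # f'(x)
d2f = lambda x: 1 / x**2         # f''(x)

def iterate(update, x, n):
    """Return the errors |x_k - 1| for k = 0..n."""
    errs = [abs(x - 1)]
    for _ in range(n):
        x = update(x)
        errs.append(abs(x - 1))
    return errs

newton_errs = iterate(lambda x: x - df(x) / d2f(x), 1.5, 5)   # quadratic
gd_errs     = iterate(lambda x: x - 0.5 * df(x), 1.5, 5)      # linear

for k, (e_n, e_g) in enumerate(zip(newton_errs, gd_errs)):
    print(f"{k}: Newton {e_n:.1e}   GD {e_g:.1e}")
```

After five steps the Newton error is around \(10^{-10}\), while the gradient descent error is still a few percent.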

None of these are what we use in practice - more on what we do use next time.