Stochastic Gradient Descent (SGD) with a Python Implementation
Table of Contents
- 1. Setting up the samples
- 2. How gradient descent works
- 3. BGD: batch gradient descent
- 3.1 Python implementation
- 3.2 Loss curve
- 3.3 BGD summary
- 4. SGD: stochastic gradient descent
- 4.1 Python implementation
- 4.2 Loss curve
- 4.3 SGD summary
- 5. MBGD: mini-batch gradient descent
- 5.1 Python implementation
- 5.2 Loss curve
- 5.3 MBGD summary
- References
1. Setting up the samples
Suppose we are given the following data samples (the target values are generated from $y = 3x_{1} + 4x_{2}$), where $x_{1}$ and $x_{2}$ are the input values and $y$ is the prediction target:

$x_{1}$ | $x_{2}$ | $y$ |
---|---|---|
1 | 4 | 19 |
2 | 5 | 26 |
5 | 1 | 19 |
4 | 2 | 20 |
We want to fit the data above with a linear function of the following form:
$$h(\Theta)=\Theta_{1} x_{1}+\Theta_{2} x_{2}\tag{1}$$
Our goal is to find the values of $\Theta_{1}$ and $\Theta_{2}$ that make $h(\Theta)$ approximate the target $y$ as closely as possible. This is a linear regression problem; as anyone familiar with linear regression knows, the two parameters can be obtained with either the least-squares method or gradient descent, and deep learning obtains all of its network parameters with the same two approaches. We therefore use this small model to explain BGD, SGD, and MBGD.
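As a quick sanity check (not part of the original post), the least-squares route mentioned above recovers the parameters directly. A minimal sketch, assuming numpy is available:

import numpy as np

# Samples from the table above
X = np.array([[1, 4], [2, 5], [5, 1], [4, 2]], dtype=float)
y = np.array([19, 26, 19, 20], dtype=float)

# Ordinary least squares: minimize ||X @ theta - y||^2
theta, residuals, rank, _ = np.linalg.lstsq(X, y, rcond=None)
print(theta)  # expected to be close to [3., 4.]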
2. How gradient descent works
First we define the loss function (mean squared error):
$$J(\Theta)=\frac{1}{2 m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right]^{2}\tag{2}$$
Here $J(\Theta)$ is the loss function and $m$ is the number of samples used for each training update. With SGD, one sample is drawn at random for each update, so $m=1$; with batch training, $m$ equals the number of samples drawn for that update. $\Theta$ denotes the parameters, i.e. $\Theta_{1}$ and $\Theta_{2}$ in equation (1). Once $\Theta_{1}$ and $\Theta_{2}$ have been found, the expression for $h(\Theta)$ follows:
$$h(\Theta)=\sum_{j} \Theta_{j} x_{j}=\Theta_{1} x_{1}+\Theta_{2} x_{2}\tag{3}$$
Our objective is to minimize the loss $J(\Theta)$. Following gradient descent, we first take the partial derivative of $J(\Theta)$ with respect to $\Theta_{j}$:
$$\frac{\partial J(\Theta)}{\partial \Theta_{j}}=2 \cdot \frac{1}{2 m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right] x_{j}^{i}=\frac{1}{m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right] x_{j}^{i}\tag{4}$$
Since we want to minimize the loss, each parameter $\Theta_{j}$ is updated along the negative gradient direction, where $\alpha$ is the learning rate:

$$\Theta_{j}^{\prime}=\Theta_{j}-\alpha \frac{\partial J(\Theta)}{\partial \Theta_{j}}=\Theta_{j}-\alpha \frac{1}{m} \sum_{i=1}^{m}\left(h_{\Theta}\left(x^{i}\right)-y^{i}\right) x_{j}^{i}\tag{5}$$
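To make equations (4) and (5) concrete before the full scripts below, here is a minimal sketch of a single update step for this model; the helper name grad_step is illustrative and not from the original post:

# One gradient-descent update of theta = [theta1, theta2] on a batch of samples,
# implementing equations (4) and (5).
def grad_step(theta, batch_x, batch_y, lr):
    m = len(batch_x)
    grad = [0.0, 0.0]
    for x, y in zip(batch_x, batch_y):
        pred = theta[0] * x[0] + theta[1] * x[1]   # h_theta(x)
        grad[0] += (pred - y) * x[0] / m           # (1/m) * sum[(h - y) * x1], eq. (4)
        grad[1] += (pred - y) * x[1] / m
    return [theta[0] - lr * grad[0], theta[1] - lr * grad[1]]  # theta - alpha * grad, eq. (5)

# Example: one step over the full data set (this is exactly what BGD does)
theta = grad_step([1, 1], [[1, 4], [2, 5], [5, 1], [4, 2]], [19, 26, 19, 20], 0.01)
print(theta)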
3. BGD: batch gradient descent
BGD (batch gradient descent): every iteration uses all of the samples!
- BGD: every parameter update uses all samples ($m$ = all); 100 updates means traversing the entire data set 100 times.
- Advantage: every iteration feeds in all of the samples, so each update takes the whole data set into account and the optimization heads toward the global optimum (guaranteed for convex losses).
- Disadvantage: a single update computes the gradient over the entire data set, so it is very slow, becomes unwieldy for very large data sets, and cannot incorporate new data to update the model online.
3.1 Python implementation
import random
import matplotlib.pyplot as plt

# Fit y = theta1*x1 + theta2*x2 to the inputs and outputs below
# input1: 1  2  5  4
# input2: 4  5  1  2
# output: 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1]        # initial parameters
loss = 10             # initialize the loss with a large value so the loop starts
lr = 0.01             # learning rate (step size)
eps = 0.0001          # required precision
max_iters = 10000     # maximum number of iterations
error = 0             # per-sample loss
iter_count = 0        # current iteration
err1 = [0, 0, 0, 0]   # intermediate terms for the theta1 gradient
err2 = [0, 0, 0, 0]   # intermediate terms for the theta2 gradient
loss_curve = []
iter_curve = []

while loss > eps and iter_count < max_iters:   # stopping conditions
    loss = 0
    err1sum = 0
    err2sum = 0
    for i in range(4):                          # every iteration trains on all samples
        pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
        err1[i] = (pred_y - output_y[i]) * input_x[i][0]
        err1sum += err1[i]
        err2[i] = (pred_y - output_y[i]) * input_x[i][1]
        err2sum += err2[i]
    theta[0] = theta[0] - lr * err1sum / 4      # update, equation (5)
    theta[1] = theta[1] - lr * err2sum / 4      # update, equation (5)
    for i in range(4):
        pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
        error = (1 / (2 * 4)) * (pred_y - output_y[i])**2  # per-sample loss
        loss = loss + error                      # total loss
    loss_curve.append(loss)
    iter_curve.append(iter_count)
    iter_count += 1
    print('iter_count:', iter_count, 'loss:', loss)

print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
plt.plot(iter_curve, loss_curve, linewidth=3.0, label='loss value')
plt.xlabel('iter_count')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()
- Run output:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/bgd.py
iter_count: 1 loss: 77.30604843750001
iter_count: 2 loss: 51.92155212969726
iter_count: 3 loss: 34.93305894023124
iter_count: 4 loss: 23.55788142744176
iter_count: 5 loss: 15.93612681814619
iter_count: 6 loss: 10.82463082981111
iter_count: 7 loss: 7.392409964576455
iter_count: 8 loss: 5.083959396333441
iter_count: 9 loss: 3.527892640305925
iter_count: 10 loss: 2.47588355411119
iter_count: 11 loss: 1.7618596986778925
iter_count: 12 loss: 1.2747299205213376
iter_count: 13 loss: 0.9401570394404473
iter_count: 14 loss: 0.7083755280604223
iter_count: 15 loss: 0.5460491800945114
iter_count: 16 loss: 0.4308288035166698
iter_count: 17 loss: 0.3477144254831318
iter_count: 18 loss: 0.28662352563217885
iter_count: 19 loss: 0.2407653218442937
iter_count: 20 loss: 0.20555379684871375
iter_count: 21 loss: 0.1778808149288243
iter_count: 22 loss: 0.1556299310282455
iter_count: 23 loss: 0.13735109414029084
iter_count: 24 loss: 0.12204291187159358
iter_count: 25 loss: 0.10900683012823309
iter_count: 26 loss: 0.09774940254077526
iter_count: 27 loss: 0.08791672434030241
iter_count: 28 loss: 0.07925038532524892
iter_count: 29 loss: 0.07155782538929152
iter_count: 30 loss: 0.06469233462112152
iter_count: 31 loss: 0.058539516396843426
iter_count: 32 loss: 0.053008085575962184
iter_count: 33 loss: 0.04802357824982824
iter_count: 34 loss: 0.04352402034121233
iter_count: 35 loss: 0.03945691714912546
iter_count: 36 loss: 0.03577713642499401
iter_count: 37 loss: 0.03244539834116959
iter_count: 38 loss: 0.029427179885821626
iter_count: 39 loss: 0.026691904238495462
iter_count: 40 loss: 0.024212327873428634
iter_count: 41 loss: 0.021964066404403914
iter_count: 42 loss: 0.019925219138275187
iter_count: 43 loss: 0.01807606502771293
iter_count: 44 loss: 0.01639881126834203
iter_count: 45 loss: 0.014877381549280188
iter_count: 46 loss: 0.013497234860471212
iter_count: 47 loss: 0.012245208401310822
iter_count: 48 loss: 0.011109379935017922
iter_count: 49 loss: 0.010078946167772749
iter_count: 50 loss: 0.009144114585438616
iter_count: 51 loss: 0.008296006777324977
iter_count: 52 loss: 0.007526571698840795
iter_count: 53 loss: 0.0068285076286013985
iter_count: 54 loss: 0.006195191797999068
iter_count: 55 loss: 0.005620616837548782
iter_count: 56 loss: 0.005099333311490803
iter_count: 57 loss: 0.004626397711632085
iter_count: 58 loss: 0.0041973253611029835
iter_count: 59 loss: 0.003808047743918568
iter_count: 60 loss: 0.003454873830664867
iter_count: 61 loss: 0.003134455016855624
iter_count: 62 loss: 0.0028437533303267305
iter_count: 63 loss: 0.002580012598749813
iter_count: 64 loss: 0.002340732298900542
iter_count: 65 loss: 0.0021236438364046354
iter_count: 66 loss: 0.0019266890288379284
iter_count: 67 loss: 0.0017480005866892083
iter_count: 68 loss: 0.001585884406131675
iter_count: 69 loss: 0.0014388035050594725
iter_count: 70 loss: 0.001305363449643687
iter_count: 71 loss: 0.0011842991329443218
iter_count: 72 loss: 0.0010744627800306976
iter_count: 73 loss: 0.0009748130657567269
iter_count: 74 loss: 0.0008844052419326704
iter_count: 75 loss: 0.0008023821802307486
iter_count: 76 loss: 0.0007279662458677689
iter_count: 77 loss: 0.000660451924992406
iter_count: 78 loss: 0.0005991991358649235
iter_count: 79 loss: 0.0005436271604005447
iter_count: 80 loss: 0.0004932091385371781
iter_count: 81 loss: 0.0004474670732236192
iter_count: 82 loss: 0.0004059672986700165
iter_count: 83 loss: 0.0003683163688928273
iter_count: 84 loss: 0.0003341573275743529
iter_count: 85 loss: 0.00030316632387097236
iter_count: 86 loss: 0.0002750495420852944
iter_count: 87 loss: 0.0002495404160923516
iter_count: 88 loss: 0.00022639710211123758
iter_count: 89 loss: 0.00020540018586152943
iter_count: 90 loss: 0.00018635060236680122
iter_count: 91 loss: 0.0001690677486837031
iter_count: 92 loss: 0.0001533877716635301
iter_count: 93 loss: 0.000139162014513147
iter_count: 94 loss: 0.00012625560742820846
iter_count: 95 loss: 0.00011454618893581613
iter_count: 96 loss: 0.00010392274582505133
iter_count: 97 loss: 9.428456066652548e-05
final theta: [3.0044552563214433, 3.9955447274498894]
final loss: 9.428456066652548e-05
final iter_count: 97
3.2 Loss curve
- The loss decreases quite smoothly!
(Figure: BGD loss value vs. iteration count)
3.3 BGD summary
Here we only have 4 samples, so training does not take long. Faced with a huge number of samples (say 400,000), however, this training scheme becomes extremely time-consuming. Batch gradient descent converges to the global minimum for convex functions and to a local minimum for non-convex functions.
4. SGD: stochastic gradient descent
SGD (stochastic gradient descent): every iteration uses a single sample!
- SGD was proposed to address BGD's slow training. Plain BGD passes over all of the samples in every iteration and only then applies one gradient update, whereas SGD draws one sample at random, updates the parameters from it, draws another sample, updates again, and so on. When the sample count is extremely large, a model whose loss is within an acceptable range may be obtained before all of the samples have even been seen.
- SGD: every parameter update uses one randomly chosen sample, i.e. $m=1$.
4.1 Python implementation
- The code:
import random

# Fit y = theta1*x1 + theta2*x2 to the inputs and outputs below
# input1: 1  2  5  4
# input2: 4  5  1  2
# output: 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1]        # initial parameters
loss = 10             # initialize the loss with a large value so the loop starts
lr = 0.01             # learning rate (step size)
eps = 0.0001          # required precision
max_iters = 10000     # maximum number of iterations
error = 0             # per-sample loss
iter_count = 0        # current iteration

while loss > eps and iter_count < max_iters:   # stopping conditions
    loss = 0
    # randint(0, 3) returns 0, 1, 2 or 3 (3 included)
    i = random.randint(0, 3)   # pick one random sample from input_x for this update
    pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
    theta[0] = theta[0] - lr * (pred_y - output_y[i]) * input_x[i][0]
    theta[1] = theta[1] - lr * (pred_y - output_y[i]) * input_x[i][1]
    for i in range(4):
        pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
        error = 0.5 * (pred_y - output_y[i])**2  # per-sample loss
        loss = loss + error                       # total loss
    iter_count += 1
    print('iter_count:', iter_count, 'loss:', loss)

print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
- Run output:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/sgd.py
iter_count: 1 loss: 260.8896
iter_count: 2 loss: 204.61415423999998
iter_count: 3 loss: 157.22159618457601
iter_count: 4 loss: 124.88134971162623
iter_count: 5 loss: 111.84245738309359
iter_count: 6 loss: 103.09215224055406
iter_count: 7 loss: 88.76429814860212
iter_count: 8 loss: 60.029919860818374
iter_count: 9 loss: 59.32261643585773
iter_count: 10 loss: 58.80735498266485
iter_count: 11 loss: 53.54691203311109
iter_count: 12 loss: 49.70774185517555
iter_count: 13 loss: 27.482152507395202
iter_count: 14 loss: 29.053473973655233
iter_count: 15 loss: 17.241137945844635
iter_count: 16 loss: 18.27205019986861
iter_count: 17 loss: 18.279734003553486
iter_count: 18 loss: 13.291935986982754
iter_count: 19 loss: 14.080979216135848
iter_count: 20 loss: 14.91676557815299
iter_count: 21 loss: 10.592300543806674
iter_count: 22 loss: 11.235546322899923
iter_count: 23 loss: 11.836768171833583
iter_count: 24 loss: 11.485334037115793
iter_count: 25 loss: 6.810177856305382
iter_count: 26 loss: 4.6959930466488675
iter_count: 27 loss: 3.9105771662118256
iter_count: 28 loss: 3.444454636311172
iter_count: 29 loss: 3.0123913259123145
iter_count: 30 loss: 2.5679800277626557
iter_count: 31 loss: 2.3099925227435065
iter_count: 32 loss: 1.9590397523779994
iter_count: 33 loss: 1.9666361109780515
iter_count: 34 loss: 1.957111930033281
iter_count: 35 loss: 2.051533423257222
iter_count: 36 loss: 1.5311103292311716
iter_count: 37 loss: 1.6083965712408976
iter_count: 38 loss: 1.195994912123778
iter_count: 39 loss: 1.2294713235041943
iter_count: 40 loss: 0.9457370295093082
iter_count: 41 loss: 0.7388360036353944
iter_count: 42 loss: 0.7067997228867577
iter_count: 43 loss: 0.7302176054470183
iter_count: 44 loss: 0.7733523408616976
iter_count: 45 loss: 0.8194031381390366
iter_count: 46 loss: 0.4936013810958017
iter_count: 47 loss: 0.37792875276479576
iter_count: 48 loss: 0.3917742838122076
iter_count: 49 loss: 0.27063075933550573
iter_count: 50 loss: 0.21948139835860636
iter_count: 51 loss: 0.200739559185535
iter_count: 52 loss: 0.16909904922322447
iter_count: 53 loss: 0.14881528969674812
iter_count: 54 loss: 0.139454946004233
iter_count: 55 loss: 0.13791962027556026
iter_count: 56 loss: 0.13940575127970778
iter_count: 57 loss: 0.11836942697279389
iter_count: 58 loss: 0.11781711553567996
iter_count: 59 loss: 0.1193375232933355
iter_count: 60 loss: 0.12138112869067261
iter_count: 61 loss: 0.12332118692472657
iter_count: 62 loss: 0.12467078283377861
iter_count: 63 loss: 0.07858332982329701
iter_count: 64 loss: 0.07002100640810523
iter_count: 65 loss: 0.06684354978608165
iter_count: 66 loss: 0.06681954120255698
iter_count: 67 loss: 0.05813640578322409
iter_count: 68 loss: 0.04822861760481913
iter_count: 69 loss: 0.04278262637099636
iter_count: 70 loss: 0.041198736725976785
iter_count: 71 loss: 0.04122491647644537
iter_count: 72 loss: 0.041360454874579164
iter_count: 73 loss: 0.028430758340382445
iter_count: 74 loss: 0.02673444676702969
iter_count: 75 loss: 0.024607291352450787
iter_count: 76 loss: 0.023584031391060585
iter_count: 77 loss: 0.018509550739981902
iter_count: 78 loss: 0.017020377213216822
iter_count: 79 loss: 0.01639800972323979
iter_count: 80 loss: 0.016217979857842655
iter_count: 81 loss: 0.01173575386765937
iter_count: 82 loss: 0.011095127889579639
iter_count: 83 loss: 0.011025658424848761
iter_count: 84 loss: 0.011162231122708137
iter_count: 85 loss: 0.011281435021340219
iter_count: 86 loss: 0.007484744689449223
iter_count: 87 loss: 0.007443897686619277
iter_count: 88 loss: 0.00753805887938067
iter_count: 89 loss: 0.005266781977291559
iter_count: 90 loss: 0.005050442824531053
iter_count: 91 loss: 0.005061511925216842
iter_count: 92 loss: 0.005082620784454073
iter_count: 93 loss: 0.004233203826419785
iter_count: 94 loss: 0.0033139392411621234
iter_count: 95 loss: 0.0032913233079854493
iter_count: 96 loss: 0.0028481973056087625
iter_count: 97 loss: 0.002706631981691858
iter_count: 98 loss: 0.002492232603268663
iter_count: 99 loss: 0.002221730694417584
iter_count: 100 loss: 0.0018638132079459395
iter_count: 101 loss: 0.0018362099891547448
iter_count: 102 loss: 0.0016114266187654506
iter_count: 103 loss: 0.001557626575305546
iter_count: 104 loss: 0.0014751564038771846
iter_count: 105 loss: 0.0015177552847051115
iter_count: 106 loss: 0.0015527624369436643
iter_count: 107 loss: 0.0010676134405512782
iter_count: 108 loss: 0.0008875592185810756
iter_count: 109 loss: 0.0008989428657822404
iter_count: 110 loss: 0.0009062642743213555
iter_count: 111 loss: 0.0006607681707515327
iter_count: 112 loss: 0.0005677737130591256
iter_count: 113 loss: 0.0005661179132734488
iter_count: 114 loss: 0.0004906413380605921
iter_count: 115 loss: 0.0004485281635358341
iter_count: 116 loss: 0.0004277189872613536
iter_count: 117 loss: 0.0004272993211319688
iter_count: 118 loss: 0.00042831950236645467
iter_count: 119 loss: 0.0004337135792476141
iter_count: 120 loss: 0.0004413233162548959
iter_count: 121 loss: 0.0003468706133226126
iter_count: 122 loss: 0.0003488193444937763
iter_count: 123 loss: 0.00028928141012031104
iter_count: 124 loss: 0.00028583399647533843
iter_count: 125 loss: 0.0002901585029896304
iter_count: 126 loss: 0.0002939322848970706
iter_count: 127 loss: 0.00023730769207386048
iter_count: 128 loss: 0.0001779396656655188
iter_count: 129 loss: 0.00015978049098605478
iter_count: 130 loss: 0.00013564931442016493
iter_count: 131 loss: 0.0001361042816188513
iter_count: 132 loss: 0.0001153612600831637
iter_count: 133 loss: 0.0001132019901158441
iter_count: 134 loss: 0.00010928002673695943
iter_count: 135 loss: 0.00011264583765529378
iter_count: 136 loss: 8.89461228685356e-05
final theta: [3.002053708602476, 3.997626634178193]
final loss: 8.89461228685356e-05
final iter_count: 136
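The script above prints the loss but does not record or plot it; the curve shown in the next subsection can be reproduced with a small variant that tracks the loss per iteration. A minimal sketch, assuming matplotlib is available (not the original author's code):

import random
import matplotlib.pyplot as plt

input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1]
lr, eps, max_iters = 0.01, 0.0001, 10000
loss, iter_count = 10, 0
loss_curve, iter_curve = [], []

while loss > eps and iter_count < max_iters:
    i = random.randint(0, 3)   # one random sample per update
    pred = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]
    theta[0] -= lr * (pred - output_y[i]) * input_x[i][0]
    theta[1] -= lr * (pred - output_y[i]) * input_x[i][1]
    loss = sum(0.5 * (theta[0] * x[0] + theta[1] * x[1] - y) ** 2
               for x, y in zip(input_x, output_y))
    loss_curve.append(loss)    # record the loss for every iteration
    iter_curve.append(iter_count)
    iter_count += 1

plt.plot(iter_curve, loss_curve, linewidth=3.0, label='SGD loss value')
plt.xlabel('iter_count')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()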
4.2 Loss curve
- The descent is somewhat bumpy and not as smooth as before (the effect is not especially pronounced here)!
(Figure: SGD loss value vs. iteration count)
4.3 SGD summary
Stochastic gradient descent updates the parameters once per sample. When the sample count is very large, theta may already have reached a good solution after seeing only part of the data. Compare this with batch gradient descent above: with, say, well over a hundred thousand training samples, a single iteration has to touch every one of them, one iteration cannot reach the optimum, and ten iterations mean traversing the training set ten times. The drawback is that SGD is noisier than BGD, so not every iteration moves toward the overall optimum; training is fast, but accuracy drops and the result is not guaranteed to be the global optimum. Still, despite the randomness, the stochastic gradient equals the true gradient in expectation.
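That last claim is easy to check numerically on this tiny data set: averaging the single-sample gradients over all four samples reproduces the full-batch gradient of equation (4). A small illustrative sketch (not from the original post):

# Compare the average single-sample gradient with the full-batch gradient
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1]

per_sample = []
for x, y in zip(input_x, output_y):
    pred = theta[0] * x[0] + theta[1] * x[1]
    per_sample.append([(pred - y) * x[0], (pred - y) * x[1]])  # SGD gradient, m = 1

avg = [sum(g[0] for g in per_sample) / 4, sum(g[1] for g in per_sample) / 4]
print(avg)  # equals the BGD gradient (1/m) * sum[(h - y) * x_j] with m = 4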
Disadvantages:
- Because SGD updates so frequently, the cost function oscillates heavily.
- BGD converges to a local minimum; SGD's oscillations, on the other hand, may let it jump to a better local minimum.
- When the learning rate is slowly decreased, SGD shows the same convergence behaviour as BGD.
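As an illustration of that last point (not from the original post), a decaying learning rate can be dropped into the SGD loop; the 1/t-style schedule below is just one common choice, and the constants are hypothetical:

# Example: decay the learning rate over the iterations so SGD settles down
lr0 = 0.05      # initial learning rate (hypothetical value)
decay = 0.01    # decay factor (hypothetical value)

def lr_at(iter_count):
    # large steps early on, smaller and smaller steps later
    return lr0 / (1.0 + decay * iter_count)

# inside the training loop one would use:  lr = lr_at(iter_count)
print([round(lr_at(t), 4) for t in (0, 100, 1000, 5000)])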
5. MBGD: mini-batch gradient descent
MBGD (mini-batch gradient descent): every iteration uses a batch of b samples!
- SGD is much faster, but it has its own problem: training on a single sample introduces a lot of noise, so not every SGD step moves toward the overall optimum; convergence may be fast at the start of training and then slow down considerably. Mini-batch gradient descent was proposed on top of this: each update draws a small batch of samples at random rather than a single one.
5.1 Python implementation
import random
import matplotlib.pyplot as plt

# Fit y = theta1*x1 + theta2*x2 to the inputs and outputs below
# input1: 1  2  5  4
# input2: 4  5  1  2
# output: 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1]        # initial parameters
loss = 10             # initialize the loss with a large value so the loop starts
lr = 0.01             # learning rate (step size)
eps = 0.0001          # required precision
max_iters = 10000     # maximum number of iterations
error = 0             # per-sample loss
iter_count = 0        # current iteration
loss_curve = []
iter_curve = []

while loss > eps and iter_count < max_iters:   # stopping conditions
    loss = 0
    # Each update uses a mini-batch of 2 samples: one random index plus its neighbour
    # randint(0, 3) returns 0, 1, 2 or 3 (3 included)
    i = random.randint(0, 3)   # first sample, chosen at random
    j = (i + 1) % 4            # second sample, j = i + 1 (wrapping around)
    pred_y0 = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
    pred_y1 = theta[0] * input_x[j][0] + theta[1] * input_x[j][1]  # prediction
    theta[0] = theta[0] - lr * (1 / 2) * ((pred_y0 - output_y[i]) * input_x[i][0] + (pred_y1 - output_y[j]) * input_x[j][0])  # update, equation (5)
    theta[1] = theta[1] - lr * (1 / 2) * ((pred_y0 - output_y[i]) * input_x[i][1] + (pred_y1 - output_y[j]) * input_x[j][1])
    for i in range(4):
        pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1]  # prediction
        error = (1 / (2 * 2)) * (pred_y - output_y[i])**2  # per-sample loss
        loss = loss + error                                 # total loss
    loss_curve.append(loss)
    iter_curve.append(iter_count)
    iter_count += 1
    print('iter_count:', iter_count, 'loss:', loss)

print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
plt.plot(iter_curve, loss_curve, linewidth=3.0, label='MBGD loss value')
plt.xlabel('iter_count')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()
- Run output:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/MBGD.py
iter_count: 1 loss: 145.51273750000001
iter_count: 2 loss: 93.61212678531253
iter_count: 3 loss: 63.84553483319602
iter_count: 4 loss: 40.78188950700523
iter_count: 5 loss: 26.635296256279755
iter_count: 6 loss: 19.14453037224417
iter_count: 7 loss: 11.863074835686199
iter_count: 8 loss: 7.351282634603321
iter_count: 9 loss: 5.285122186891444
iter_count: 10 loss: 3.5150322864765666
iter_count: 11 loss: 2.5222674128094504
iter_count: 12 loss: 1.726271885060386
iter_count: 13 loss: 1.0782688839189605
iter_count: 14 loss: 0.7734545491420018
iter_count: 15 loss: 0.4847190164962192
iter_count: 16 loss: 0.30268382725301546
iter_count: 17 loss: 0.21710714208249052
iter_count: 18 loss: 0.15636548583181112
iter_count: 19 loss: 0.0951813297472256
iter_count: 20 loss: 0.06014949607215332
iter_count: 21 loss: 0.04555522775364396
iter_count: 22 loss: 0.029448135848294105
iter_count: 23 loss: 0.019291337709467643
iter_count: 24 loss: 0.012850369417312057
iter_count: 25 loss: 0.00873525233733187
iter_count: 26 loss: 0.005002257350216796
iter_count: 27 loss: 0.004675521915666117
iter_count: 28 loss: 0.0034010908555124775
iter_count: 29 loss: 0.003314771376617083
iter_count: 30 loss: 0.002003304782844371
iter_count: 31 loss: 0.002035716129121544
iter_count: 32 loss: 0.001600873625308816
iter_count: 33 loss: 0.0016385823787498255
iter_count: 34 loss: 0.001668187043426684
iter_count: 35 loss: 0.0010776974297530592
iter_count: 36 loss: 0.0009406763734116385
iter_count: 37 loss: 0.0007948887717540301
iter_count: 38 loss: 0.0007110442660078077
iter_count: 39 loss: 0.0007262540103407296
iter_count: 40 loss: 0.0007492868334222326
iter_count: 41 loss: 0.0007731591748400462
iter_count: 42 loss: 0.0005294763719741898
iter_count: 43 loss: 0.0005440182877813707
iter_count: 44 loss: 0.00045517250370294903
iter_count: 45 loss: 0.00040394528620568096
iter_count: 46 loss: 0.0003627723578104964
iter_count: 47 loss: 0.00036950515296137464
iter_count: 48 loss: 0.0003127529335059308
iter_count: 49 loss: 0.0002801383258107215
iter_count: 50 loss: 0.00024161953198902414
iter_count: 51 loss: 0.0002452223039598993
iter_count: 52 loss: 0.0001869776452880176
iter_count: 53 loss: 0.00018574447211436057
iter_count: 54 loss: 0.00015989716877477456
iter_count: 55 loss: 0.00013838733780820178
iter_count: 56 loss: 0.00012025272881104103
iter_count: 57 loss: 0.00012094637362213334
iter_count: 58 loss: 0.00012421504197512097
iter_count: 59 loss: 9.101480560449152e-05
final theta: [3.0025974709772676, 3.996606737971915]
final loss: 9.101480560449152e-05
final iter_count: 59
5.2 Loss curve
(Figure: MBGD loss value vs. iteration count)
5.3 MBGD summary
MBGD (mini-batch gradient descent): every parameter update uses the averaged gradient of $m$ samples, with $1 < m < \text{all}$; the hyper-parameter $m$ is typically set somewhere between 50 and 256.
- MBGD computes each update on a small batch of $m$ samples. On the one hand this reduces the variance of the parameter updates, so convergence is more stable; on the other hand it can make full use of the highly optimized matrix operations in deep-learning libraries to compute gradients efficiently (see the vectorized sketch below).
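On this toy problem the speed-up is invisible, but the idea of drawing a random mini-batch and computing its gradient with vectorized matrix operations can be sketched with numpy; this is illustrative only and not the original author's code:

import random
import numpy as np

X = np.array([[1, 4], [2, 5], [5, 1], [4, 2]], dtype=float)
y = np.array([19, 26, 19, 20], dtype=float)
theta = np.array([1.0, 1.0])
lr, batch_size = 0.01, 2

for step in range(1000):
    idx = random.sample(range(len(X)), batch_size)   # draw a random mini-batch
    Xb, yb = X[idx], y[idx]
    grad = Xb.T @ (Xb @ theta - yb) / batch_size     # (1/m) * X^T (X theta - y), eq. (4)
    theta -= lr * grad                               # eq. (5)

print(theta)  # approaches [3., 4.]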
Disadvantages (two main ones):
- Mini-batch gradient descent does not guarantee good convergence. If the learning rate is chosen too small, convergence is very slow; if it is too large, the loss function keeps oscillating around a minimum or even diverges. (One remedy is to start with a larger learning rate and shrink it once the change between two iterations drops below some threshold, but that threshold has to be fixed in advance and therefore cannot adapt to the characteristics of the data set.) For non-convex functions one also has to avoid getting trapped at local minima or saddle points, where the error is the same in every direction and the gradient is close to 0 in all dimensions, so SGD is easily stuck there. (Near a saddle point or a local minimum the iterate jitters back and forth: with the full training set, i.e. BGD, the optimization would simply stop at that point, whereas with mini-batches or SGD every step sees a different gradient, so the parameters keep bouncing around.)
- SGD applies the same learning rate to every parameter update. If the data are sparse, we would rather apply larger updates to features that occur rarely. Moreover, the learning rate has to shrink gradually as the number of updates grows.
A saddle point of a smooth function is a point where the neighbouring curve, surface, or hypersurface lies on different sides of the tangent at that point. A two-dimensional example is the saddle-shaped surface $z = x^{2} - y^{2}$: it curves upward along the x-axis and downward along the y-axis, and the saddle point is at (0, 0).
(Figure: a saddle-shaped surface with the saddle point at the origin)
References
The following authors' articles were consulted; my thanks to them!
- https://blog.csdn.net/kwame211/article/details/80364079
- https://www.cnblogs.com/guoyaohua/p/8542554.html