import numpy as np

class Network(object):

    def __init__(self, sizes):
        """The list ``sizes`` contains the number of neurons in the
        respective layers of the network.  For example, if the list
        was [2, 3, 1] then it would be a three-layer network, with the
        first layer containing 2 neurons, the second layer 3 neurons,
        and the third layer 1 neuron.  The biases and weights for the
        network are initialized randomly, using a Gaussian
        distribution with mean 0, and variance 1.  Note that the first
        layer is assumed to be an input layer, and by convention we
        won't set any biases for those neurons, since biases are only
        ever used in computing the outputs from later layers."""
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x)
                        for x, y in zip(sizes[:-1], sizes[1:])]
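A note on the shapes (my annotation, following the convention implied by the code): np.random.randn(y, x) gives each layer a weight matrix of shape (neurons in that layer, neurons in the previous layer), so the feedforward step for layer l is a single matrix-vector product:

a^{l} = \sigma\left(w^{l} a^{l-1} + b^{l}\right)

with w^{l} of shape (sizes[l], sizes[l-1]) and b^{l} of shape (sizes[l], 1).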
def update_mini_batch(self, mini_batch, eta):
    """Update the network's weights and biases by applying
    gradient descent using backpropagation to a single mini batch.
    The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
    is the learning rate."""
    nabla_b = [np.zeros(b.shape) for b in self.biases]
    nabla_w = [np.zeros(w.shape) for w in self.weights]
    for x, y in mini_batch:
        delta_nabla_b, delta_nabla_w = self.backprop(x, y)
        nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
        nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
    self.weights = [w-(eta/len(mini_batch))*nw
                    for w, nw in zip(self.weights, nabla_w)]
    self.biases = [b-(eta/len(mini_batch))*nb
                   for b, nb in zip(self.biases, nabla_b)]
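Written out, the update applied in the last four lines is the standard mini-batch gradient-descent rule (this just restates the code above): for a mini-batch of m examples,

w \leftarrow w - \frac{\eta}{m} \sum_{x \in \text{mini\_batch}} \nabla_w C_x,
\qquad
b \leftarrow b - \frac{\eta}{m} \sum_{x \in \text{mini\_batch}} \nabla_b C_x

where \nabla C_x is the gradient that backprop returns for a single training example x and m = len(mini_batch).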
use ndarray::Array2;

#[derive(Debug)]
struct Network {
    num_layers: usize,
    sizes: Vec<usize>,
    biases: Vec<Array2<f64>>,
    weights: Vec<Array2<f64>>,
}
use rand::distributions::StandardNormal;
use ndarray::{Array, Array2};
use ndarray_rand::RandomExt;

impl Network {
    fn new(sizes: &[usize]) -> Network {
        let num_layers = sizes.len();
        let mut biases: Vec<Array2<f64>> = Vec::new();
        let mut weights: Vec<Array2<f64>> = Vec::new();
        for i in 1..num_layers {
            biases.push(Array::random((sizes[i], 1), StandardNormal));
            weights.push(Array::random((sizes[i], sizes[i - 1]), StandardNormal));
        }
        Network {
            num_layers,
            sizes: sizes.to_owned(),
            biases,
            weights,
        }
    }
}
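A minimal sketch of how this constructor might be exercised (assuming the ndarray, ndarray-rand, and rand crate setup above; shape() is the ndarray accessor):

fn main() {
    // Build the [2, 3, 1] example network from the Python docstring.
    let net = Network::new(&[2, 3, 1]);
    assert_eq!(net.num_layers, 3);
    // No biases for the input layer: one (3, 1) and one (1, 1) column vector.
    assert_eq!(net.biases[0].shape(), &[3, 1]);
    assert_eq!(net.biases[1].shape(), &[1, 1]);
    // Weight matrices are (neurons in layer, neurons in previous layer).
    assert_eq!(net.weights[0].shape(), &[3, 2]);
    assert_eq!(net.weights[1].shape(), &[1, 3]);
}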
impl Network {
    fn update_mini_batch(
        &mut self,
        training_data: &[MnistImage],
        mini_batch_indices: &[usize],
        eta: f64,
    ) {
        let mut nabla_b: Vec<Array2<f64>> = zero_vec_like(&self.biases);
        let mut nabla_w: Vec<Array2<f64>> = zero_vec_like(&self.weights);
        for i in mini_batch_indices {
            let (delta_nabla_b, delta_nabla_w) = self.backprop(&training_data[*i]);
            for (nb, dnb) in nabla_b.iter_mut().zip(delta_nabla_b.iter()) {
                *nb += dnb;
            }
            for (nw, dnw) in nabla_w.iter_mut().zip(delta_nabla_w.iter()) {
                *nw += dnw;
            }
        }
        let nbatch = mini_batch_indices.len() as f64;
        for (w, nw) in self.weights.iter_mut().zip(nabla_w.iter()) {
            *w -= &nw.mapv(|x| x * eta / nbatch);
        }
        for (b, nb) in self.biases.iter_mut().zip(nabla_b.iter()) {
            *b -= &nb.mapv(|x| x * eta / nbatch);
        }
    }
}
update_mini_batch leans on two small helpers, to_tuple and zero_vec_like, which build zero-initialized arrays with the same shapes as the bias and weight arrays:

fn to_tuple(inp: &[usize]) -> (usize, usize) {
    match inp {
        [a, b] => (*a, *b),
        _ => panic!(),
    }
}

fn zero_vec_like(inp: &[Array2<f64>]) -> Vec<Array2<f64>> {
    inp.iter()
        .map(|x| Array2::zeros(to_tuple(x.shape())))
        .collect()
}
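A quick sanity check of these helpers (a sketch, assuming the same ndarray imports as above):

fn main() {
    let arrs = vec![Array2::<f64>::ones((3, 1)), Array2::<f64>::ones((1, 3))];
    let zeros = zero_vec_like(&arrs);
    // Shapes are preserved, contents are all zero.
    assert_eq!(zeros[0].shape(), &[3, 1]);
    assert_eq!(zeros[1].shape(), &[1, 3]);
    assert!(zeros.iter().all(|a| a.iter().all(|&v| v == 0.0)));
}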
Stepping through the Rust version: for each index in the mini-batch, backprop computes the gradient contribution of that single training example:

let (delta_nabla_b, delta_nabla_w) = self.backprop(&training_data[*i]);

Those per-example gradients are then accumulated in place into nabla_b and nabla_w:

for (nb, dnb) in nabla_b.iter_mut().zip(delta_nabla_b.iter()) {
    *nb += dnb;
}
for (nw, dnw) in nabla_w.iter_mut().zip(delta_nabla_w.iter()) {
    *nw += dnw;
}

Finally, once the whole mini-batch has been processed, the weights and biases are moved against the accumulated gradients, scaled by the learning rate over the mini-batch size, just as in the Python version:

let nbatch = mini_batch_indices.len() as f64;
for (w, nw) in self.weights.iter_mut().zip(nabla_w.iter()) {
    *w -= &nw.mapv(|x| x * eta / nbatch);
}
for (b, nb) in self.biases.iter_mut().zip(nabla_b.iter()) {
    *b -= &nb.mapv(|x| x * eta / nbatch);
}
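To make that scaling concrete, here is a toy run of the last step on a single 2x2 array (values invented for illustration):

use ndarray::array;

fn main() {
    let mut w = array![[1.0, 2.0], [3.0, 4.0]];
    let nw = array![[0.5, 0.5], [0.5, 0.5]];
    let (eta, nbatch) = (3.0, 10.0);
    // Each entry moves by eta/nbatch times its gradient: 0.5 * 3.0 / 10.0 = 0.15.
    w -= &nw.mapv(|x| x * eta / nbatch);
    // w is now approximately [[0.85, 1.85], [2.85, 3.85]].
    println!("{}", w);
}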