Author | beyondma
Reposted from CSDN Blog
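Since the 1st of this month, Shanghai has officially enforced its "strictest-ever" waste-sorting regulations, under which putting trash in the wrong bin can cost up to 200 yuan. Another 46 cities across the country will follow and enter the new era of waste sorting one after another. Meanwhile, jokes about people being driven crazy by waste sorting keep popping up on social media.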
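From an AI perspective, waste sorting is just one application of image classification. Among the ImageNet image classification competitions held since 2012, SENet stands head and shoulders above the rest. It won the ILSVRC 2017 classification task with a top-5 error of 2.25% on the test set, which arguably makes it the strongest image classifier to date.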
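The top-5 error quoted above counts a prediction as correct when the true label is among the model's five highest-scoring classes. Below is a minimal sketch of the metric in plain Python. The function name `top5_error` is illustrative only and does not come from any library.

```python
def top5_error(scores, labels):
    """Fraction of samples whose true class is NOT among the 5 highest scores.

    scores: one list of per-class scores per sample
    labels: the true class index of each sample
    """
    misses = 0
    for row, y in zip(scores, labels):
        # Indices of the five largest scores for this sample
        top5 = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:5]
        if y not in top5:
            misses += 1
    return misses / len(labels)


scores = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]] * 3
print(top5_error(scores, [6, 0, 2]))  # label 0 is ranked last, so 1 of the 3 samples misses
```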
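I just checked the website of Momenta, the company behind SENet. Their latest work has already moved on to 3D object recognition and annotation, with results as follows:
So using the SENet they proposed to classify trash images should pose no problem at all.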
Introduction to SENet
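SENet, jointly proposed by Momenta and the University of Oxford, is a model built on squeeze and excitation operations. Each block uses a "squeeze" operation to embed information from the global receptive field, and an "excitation" operation to selectively strengthen the more informative responses. Most past ImageNet champions gained accuracy by scaling up model size and the number of connections. Against this "brute force works miracles" trend, SENet is a refreshing exception. The paper is available at: http://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf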
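It works as follows:
Squeeze: apply global average pooling to the C×H×W input to obtain a 1×1×C feature map, which can be understood as having a global receptive field. In the paper's own words, each two-dimensional feature channel is turned into a single real number. That number has, to some extent, a global receptive field, and the output dimension matches the number of input feature channels. The result characterizes the global distribution of responses across feature channels, and it lets even layers close to the input obtain a global receptive field.
Excitation: a small fully connected network applies a nonlinear transformation to the squeezed result. The first layer reduces the C channels to C/r, where r is the reduction ratio, and is followed by a ReLU. The second layer restores them to C and is followed by a sigmoid. The mechanism resembles the gates in recurrent neural networks. The parameters w generate a weight for each feature channel, and w is learned to explicitly model the correlations between feature channels.
Feature recalibration (Scale): the Excitation output is used as weights and multiplied onto the input features. These weights can be read as the importance of each feature channel. Multiplying them channel by channel onto the preceding features completes the recalibration of the original features.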
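The three steps can be traced on a toy input in plain Python. This is a minimal sketch, not the paper's or the repository's code. The following are all assumptions made for illustration:

- the function name `se_block`;
- the C×H×W nested-list layout;
- the explicitly passed weights `w1, b1, w2, b2`, which a real network would learn.

```python
import math


def se_block(u, w1, b1, w2, b2):
    """Squeeze-and-Excitation on one C x H x W feature map stored as nested lists.

    w1, b1: reduction FC layer, C -> C//r (w1 has C//r rows of length C)
    w2, b2: expansion FC layer, C//r -> C (w2 has C rows of length C//r)
    Returns the recalibrated feature map and the per-channel weights s.
    """
    # Squeeze: global average pooling turns each H x W channel into one number z_c
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in u]
    # Excitation, layer 1: fully connected C -> C//r, then ReLU
    h = [max(0.0, sum(w * v for w, v in zip(wi, z)) + bi) for wi, bi in zip(w1, b1)]
    # Excitation, layer 2: fully connected C//r -> C, then sigmoid, so every s_c is in (0, 1)
    s = [1.0 / (1.0 + math.exp(-(sum(w * v for w, v in zip(wi, h)) + bi)))
         for wi, bi in zip(w2, b2)]
    # Scale: multiply every value in channel c by its weight s_c
    out = [[[s_c * v for v in row] for row in ch] for s_c, ch in zip(s, u)]
    return out, s


# Two 2x2 channels (C = 2) with reduction ratio r = 2, so the hidden layer has one unit
u = [[[1.0, 2.0], [3.0, 4.0]],   # channel 0, mean 2.5
     [[0.0, 0.0], [0.0, 0.0]]]   # channel 1, mean 0.0
out, s = se_block(u, w1=[[1.0, 0.0]], b1=[0.0], w2=[[1.0], [-1.0]], b2=[0.0, 0.0])
print(s)  # channel 0 is boosted above 0.5, channel 1 is suppressed below 0.5
```

Nothing here goes beyond averaging, two small matrix products, ReLU, sigmoid and a multiply. That is why the block can be bolted onto an existing network.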
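The model architecture is shown below:
SENet is very simple to construct and easy to deploy. It needs no new functions or layers, only standard operations: global pooling, fully connected layers, ReLU, sigmoid and channel-wise multiplication. Its Caffe model can be downloaded from Baidu Pan (https://pan.baidu.com/s/1o7HdfAE?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=)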
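The claim of small overhead can be put into numbers. An SE block on a C-channel feature map adds only its two fully connected layers, that is 2C²/r weights plus C/r + C biases. The helpers below are hypothetical names written for this post.

- r = 4 matches `reduction_ratio` in the TensorFlow code later in this post; the SENet paper's default is r = 16.
- The channel widths 384, 1152 and 2144 are the ones that code's Inception-ResNet blocks produce.

```python
def se_param_count(channels, ratio, bias=True):
    """Parameters added by one SE block: FC C -> C//r, then FC C//r -> C."""
    hidden = channels // ratio
    weights = channels * hidden + hidden * channels  # 2 * C^2 / r
    biases = hidden + channels if bias else 0
    return weights + biases


def conv3x3_param_count(channels, bias=True):
    """For comparison: one 3x3 convolution mapping C channels to C channels."""
    return 3 * 3 * channels * channels + (channels if bias else 0)


for c in (384, 1152, 2144):
    se, conv = se_param_count(c, 4), conv3x3_param_count(c)
    print(f"C={c}: SE block {se:,} params, 3x3 conv {conv:,} params ({se / conv:.1%})")
```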
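Applying SENet
If you have Caffe installed, simply download the model above and load it. If you don't have Caffe but do have TensorFlow, that's fine too. As noted above, SENet introduces no new functions or layers, so it is easy to implement in TensorFlow.
Download the image dataset: after searching around, I found the dataset below (TrashNet). It is small, roughly 2,500 images in six classes: cardboard, glass, metal, paper, plastic and trash. So it doesn't really let SENet show its strengths, but it is convenient to use: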
https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip
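Downloading and unpacking can also be scripted. Below is a minimal sketch that uses only the Python standard library. It assumes the archive unpacks to a `dataset-resized` folder with one sub-folder per class (cardboard, glass, ...). The helper names `download_trashnet` and `count_images_per_class` are hypothetical.

```python
import os
import urllib.request
import zipfile

TRASHNET_URL = "https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip"


def download_trashnet(dest_dir):
    """Download dataset-resized.zip into dest_dir, extract it there and return the zip path."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, "dataset-resized.zip")
    urllib.request.urlretrieve(TRASHNET_URL, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    return zip_path


def count_images_per_class(root, exts=(".jpg", ".jpeg", ".png")):
    """Count the image files in each class sub-folder of an extracted dataset."""
    counts = {}
    for name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, name)
        if os.path.isdir(class_dir):
            counts[name] = sum(1 for f in os.listdir(class_dir) if f.lower().endswith(exts))
    return counts
```

Usage:

- Call `download_trashnet("e:/test")`.
- Then call `count_images_per_class("e:/test/dataset-resized")`. It should list the six classes.

The folder that directly contains the class sub-folders is the one the `data_dir` setting in `prepare_data` has to point to.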
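Build the SENet model: an open-source TensorFlow implementation is already on GitHub at https://github.com/taki0112/SENet-Tensorflow. It is written for the CIFAR-10 dataset, but that doesn't matter. After cloning it with git, simply modify the prepare_data function in its cifar10.py as follows. The network's output layer and label placeholder are sized by class_num from cifar10.py, so also change that constant from 10 to 6 to match the six trash classes.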
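import os
import numpy as np
from PIL import Image  # assumption: Pillow is installed; the original cifar10.py does not import it

def prepare_data(test_ratio=0.2):
    print("======Loading data======")
    # download_data() fetches CIFAR-10 and is not needed for this dataset
    data_dir = 'e:/test/'  # change to your folder: the one that contains the six class sub-folders
    # CIFAR-10's batches.meta / data_batch_* / test_batch pickles do not exist here;
    # this dataset stores one sub-folder of images per class, so load_data() cannot be reused
    label_names = ['cardboard', 'glass', 'metal', 'trash', 'paper', 'plastic']
    label_count = len(label_names)

    images, labels = [], []
    for idx, name in enumerate(label_names):
        class_dir = os.path.join(data_dir, name)
        for fname in sorted(os.listdir(class_dir)):
            if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            img = Image.open(os.path.join(class_dir, fname)).convert('RGB')
            # Shrink to the network input size defined in cifar10.py (image_size = 32)
            img = img.resize((image_size, image_size))
            images.append(np.asarray(img, dtype=np.uint8))
            labels.append(idx)

    data = np.stack(images)  # [N, image_size, image_size, 3]
    one_hot = np.eye(label_count, dtype=np.float32)[labels]  # one-hot labels, [N, label_count]
    print("======Load finished======")

    print("======Shuffling data======")
    indices = np.random.permutation(len(data))
    data, one_hot = data[indices], one_hot[indices]

    # There is no separate test batch, so hold out a fraction (20% by default, an arbitrary choice)
    n_test = int(len(data) * test_ratio)
    test_data, test_labels = data[:n_test], one_hot[:n_test]
    train_data, train_labels = data[n_test:], one_hot[n_test:]

    print("Train data:", np.shape(train_data), np.shape(train_labels))
    print("Test data :", np.shape(test_data), np.shape(test_labels))
    print("======Prepare Finished======")

    return train_data, train_labels, test_data, test_labels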
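The core modeling code follows. Its main job is simply to implement the SENet structure, here inserted into an Inception-ResNet-v2 backbone (the SE_Inception_resnet_v2 class):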
import tensorflow as tf
from tflearn.layers.conv import global_avg_pool
from tensorflow.contrib.layers import batch_norm, flatten
from tensorflow.contrib.framework import arg_scope
from cifar10 import *
import numpy as np

weight_decay = 0.0005
momentum = 0.9

init_learning_rate = 0.1

reduction_ratio = 4

batch_size = 128
iteration = 391
# 128 * 391 ~ 50,000 (CIFAR-10 default; recomputed from the real training-set size after prepare_data())

test_iteration = 10
# 10 * 1000 = 10,000 CIFAR-10 test images (also recomputed after prepare_data())

total_epochs = 100

def conv_layer(input, filter, kernel, stride=1, padding='SAME', layer_name="conv", activation=True):
    with tf.name_scope(layer_name):
        network = tf.layers.conv2d(inputs=input, use_bias=True, filters=filter, kernel_size=kernel, strides=stride, padding=padding)
        if activation :
            network = Relu(network)
        return network

def Fully_connected(x, units=class_num, layer_name='fully_connected') :
    with tf.name_scope(layer_name) :
        return tf.layers.dense(inputs=x, use_bias=True, units=units)

def Relu(x):
    return tf.nn.relu(x)

def Sigmoid(x):
    return tf.nn.sigmoid(x)

def Global_Average_Pooling(x):
    return global_avg_pool(x, name='Global_avg_pooling')

def Max_pooling(x, pool_size=[3,3], stride=2, padding='VALID') :
    return tf.layers.max_pooling2d(inputs=x, pool_size=pool_size, strides=stride, padding=padding)

def Batch_Normalization(x, training, scope):
    with arg_scope([batch_norm],
                   scope=scope,
                   updates_collections=None,
                   decay=0.9,
                   center=True,
                   scale=True,
                   zero_debias_moving_mean=True) :
        return tf.cond(training,
                       lambda : batch_norm(inputs=x, is_training=training, reuse=None),
                       lambda : batch_norm(inputs=x, is_training=training, reuse=True))

def Concatenation(layers) :
    return tf.concat(layers, axis=3)

def Dropout(x, rate, training) :
    return tf.layers.dropout(inputs=x, rate=rate, training=training)

def Evaluate(sess):
    test_acc = 0.0
    test_loss = 0.0
    test_pre_index = 0
    add = 1000

    for it in range(test_iteration):
        test_batch_x = test_x[test_pre_index: test_pre_index + add]
        test_batch_y = test_y[test_pre_index: test_pre_index + add]
        test_pre_index = test_pre_index + add

        test_feed_dict = {
            x: test_batch_x,
            label: test_batch_y,
            learning_rate: epoch_learning_rate,
            training_flag: False
        }

        loss_, acc_ = sess.run([cost, accuracy], feed_dict=test_feed_dict)

        test_loss += loss_
        test_acc += acc_

    test_loss /= test_iteration # average loss
    test_acc /= test_iteration # average accuracy

    summary = tf.Summary(value=[tf.Summary.Value(tag='test_loss', simple_value=test_loss),
                                tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])

    return test_acc, test_loss, summary

class SE_Inception_resnet_v2():
    def __init__(self, x, training):
        self.training = training
        self.model = self.Build_SEnet(x)

    def Stem(self, x, scope):
        with tf.name_scope(scope) :
            x = conv_layer(x, filter=32, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_conv1')
            x = conv_layer(x, filter=32, kernel=[3,3], padding='VALID', layer_name=scope+'_conv2')
            block_1 = conv_layer(x, filter=64, kernel=[3,3], layer_name=scope+'_conv3')

            split_max_x = Max_pooling(block_1)
            split_conv_x = conv_layer(block_1, filter=96, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')
            x = Concatenation([split_max_x,split_conv_x])

            split_conv_x1 = conv_layer(x, filter=64, kernel=[1,1], layer_name=scope+'_split_conv2')
            split_conv_x1 = conv_layer(split_conv_x1, filter=96, kernel=[3,3], padding='VALID', layer_name=scope+'_split_conv3')

            split_conv_x2 = conv_layer(x, filter=64, kernel=[1,1], layer_name=scope+'_split_conv4')
            split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[7,1], layer_name=scope+'_split_conv5')
            split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[1,7], layer_name=scope+'_split_conv6')
            split_conv_x2 = conv_layer(split_conv_x2, filter=96, kernel=[3,3], padding='VALID', layer_name=scope+'_split_conv7')

            x = Concatenation([split_conv_x1,split_conv_x2])

            split_conv_x = conv_layer(x, filter=192, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv8')
            split_max_x = Max_pooling(x)

            x = Concatenation([split_conv_x, split_max_x])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_A(self, x, scope):
        with tf.name_scope(scope) :
            init = x

            split_conv_x1 = conv_layer(x, filter=32, kernel=[1,1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=32, kernel=[1,1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=32, kernel=[3,3], layer_name=scope+'_split_conv3')

            split_conv_x3 = conv_layer(x, filter=32, kernel=[1,1], layer_name=scope+'_split_conv4')
            split_conv_x3 = conv_layer(split_conv_x3, filter=48, kernel=[3,3], layer_name=scope+'_split_conv5')
            split_conv_x3 = conv_layer(split_conv_x3, filter=64, kernel=[3,3], layer_name=scope+'_split_conv6')

            x = Concatenation([split_conv_x1,split_conv_x2,split_conv_x3])
            x = conv_layer(x, filter=384, kernel=[1,1], layer_name=scope+'_final_conv1', activation=False)

            x = x*0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_B(self, x, scope):
        with tf.name_scope(scope) :
            init = x

            split_conv_x1 = conv_layer(x, filter=192, kernel=[1,1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=128, kernel=[1,1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=160, kernel=[1,7], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=192, kernel=[7,1], layer_name=scope+'_split_conv4')

            x = Concatenation([split_conv_x1, split_conv_x2])
            x = conv_layer(x, filter=1152, kernel=[1,1], layer_name=scope+'_final_conv1', activation=False)
            # 1154
            x = x * 0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_C(self, x, scope):
        with tf.name_scope(scope) :
            init = x

            split_conv_x1 = conv_layer(x, filter=192, kernel=[1,1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope + '_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=224, kernel=[1, 3], layer_name=scope + '_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=256, kernel=[3, 1], layer_name=scope + '_split_conv4')

            x = Concatenation([split_conv_x1,split_conv_x2])
            x = conv_layer(x, filter=2144, kernel=[1,1], layer_name=scope+'_final_conv2', activation=False)
            # 2048
            x = x * 0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Reduction_A(self, x, scope):
        with tf.name_scope(scope) :
            k = 256
            l = 256
            m = 384
            n = 384

            split_max_x = Max_pooling(x)

            split_conv_x1 = conv_layer(x, filter=n, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=k, kernel=[1,1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=l, kernel=[3,3], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=m, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

            x = Concatenation([split_max_x, split_conv_x1, split_conv_x2])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Reduction_B(self, x, scope):
        with tf.name_scope(scope) :
            split_max_x = Max_pooling(x)

            split_conv_x1 = conv_layer(x, filter=256, kernel=[1,1], layer_name=scope+'_split_conv1')
            split_conv_x1 = conv_layer(split_conv_x1, filter=384, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv2')

            split_conv_x2 = conv_layer(x, filter=256, kernel=[1,1], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=288, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

            split_conv_x3 = conv_layer(x, filter=256, kernel=[1,1], layer_name=scope+'_split_conv5')
            split_conv_x3 = conv_layer(split_conv_x3, filter=288, kernel=[3,3], layer_name=scope+'_split_conv6')
            split_conv_x3 = conv_layer(split_conv_x3, filter=320, kernel=[3,3], stride=2, padding='VALID', layer_name=scope+'_split_conv7')

            x = Concatenation([split_max_x, split_conv_x1, split_conv_x2, split_conv_x3])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
        with tf.name_scope(layer_name) :


            squeeze = Global_Average_Pooling(input_x)

            # Integer division: the hidden layer width must be an int (out_dim / ratio is a float in Python 3)
            excitation = Fully_connected(squeeze, units=out_dim // ratio, layer_name=layer_name+'_fully_connected1')
            excitation = Relu(excitation)
            excitation = Fully_connected(excitation, units=out_dim, layer_name=layer_name+'_fully_connected2')
            excitation = Sigmoid(excitation)

            excitation = tf.reshape(excitation, [-1,1,1,out_dim])
            scale = input_x * excitation

            return scale

    def Build_SEnet(self, input_x):
        input_x = tf.pad(input_x, [[0, 0], [32, 32], [32, 32], [0, 0]])
        # size 32 -> 96
        print(np.shape(input_x))
        # only cifar10 architecture

        x = self.Stem(input_x, scope='stem')

        for i in range(5) :
            x = self.Inception_resnet_A(x, scope='Inception_A'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A'+str(i))

        x = self.Reduction_A(x, scope='Reduction_A')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A')

        for i in range(10)  :
            x = self.Inception_resnet_B(x, scope='Inception_B'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B'+str(i))

        x = self.Reduction_B(x, scope='Reduction_B')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B')

        for i in range(5) :
            x = self.Inception_resnet_C(x, scope='Inception_C'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C'+str(i))


        # channel = int(np.shape(x)[-1])
        # x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C')

        x = Global_Average_Pooling(x)
        x = Dropout(x, rate=0.2, training=self.training)
        x = flatten(x)

        x = Fully_connected(x, layer_name='final_fully_connected')
        return x


train_x, train_y, test_x, test_y = prepare_data()
train_x, test_x = color_preprocessing(train_x, test_x)

# Derive the batch counts from the actual dataset size instead of CIFAR-10's 50,000 / 10,000 images
iteration = int(np.ceil(len(train_x) / batch_size))
test_iteration = int(np.ceil(len(test_x) / 1000))  # Evaluate() reads the test set in chunks of 1000


# image_size = 32 and img_channels = 3 in cifar10; class_num must be 6 for the six trash classes
x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels])
label = tf.placeholder(tf.float32, shape=[None, class_num])

training_flag = tf.placeholder(tf.bool)


learning_rate = tf.placeholder(tf.float32, name='learning_rate')

logits = SE_Inception_resnet_v2(x, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True)
train = optimizer.minimize(cost + l2_loss * weight_decay)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver(tf.global_variables())

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter('./logs', sess.graph)

    epoch_learning_rate = init_learning_rate
    for epoch in range(1, total_epochs + 1):
        if epoch % 30 == 0 :
            epoch_learning_rate = epoch_learning_rate / 10

        pre_index = 0
        train_acc = 0.0
        train_loss = 0.0

        for step in range(1, iteration + 1):
            if pre_index + batch_size < len(train_x):
                batch_x = train_x[pre_index: pre_index + batch_size]
                batch_y = train_y[pre_index: pre_index + batch_size]
            else:
                batch_x = train_x[pre_index:]
                batch_y = train_y[pre_index:]

            batch_x = data_augmentation(batch_x)

            train_feed_dict = {
                x: batch_x,
                label: batch_y,
                learning_rate: epoch_learning_rate,
                training_flag: True
            }

            _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict)
            batch_acc = accuracy.eval(feed_dict=train_feed_dict)

            train_loss += batch_loss
            train_acc += batch_acc
            pre_index += batch_size


        train_loss /= iteration # average loss
        train_acc /= iteration # average accuracy

        train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss),
                                          tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)])

        test_acc, test_loss, test_summary = Evaluate(sess)

        summary_writer.add_summary(summary=train_summary, global_step=epoch)
        summary_writer.add_summary(summary=test_summary, global_step=epoch)
        summary_writer.flush()

        line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % (
            epoch, total_epochs, train_loss, train_acc, test_loss, test_acc)
        print(line)

        with open('logs.txt', 'a') as f:
            f.write(line)

        saver.save(sess=sess, save_path='./model/Inception_resnet_v2.ckpt')
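Build_SEnet pads each 32×32 image with 32 pixels on every side (the `tf.pad` call commented `size 32 -> 96`). One way to see why is to trace the spatial size through the layers that use `padding='VALID'`. The `'SAME'` stride-1 convolutions leave the size unchanged, so they can be skipped. This is a hand-made sketch:

- The layer list is my own reading of `Stem`, `Reduction_A` and `Reduction_B` above.
- The helper names `valid_out` and `trace_sizes` are hypothetical.

```python
# Layers that change the spatial size, as (description, kernel, stride); all use VALID padding.
# Read off Stem / Reduction_A / Reduction_B above; SAME stride-1 layers are omitted.
SIZE_CHANGING_LAYERS = [
    ("stem_conv1", 3, 2),
    ("stem_conv2", 3, 1),
    ("stem max-pool / stem_split_conv1", 3, 2),
    ("stem_split_conv3 / stem_split_conv7", 3, 1),
    ("stem_split_conv8 / max-pool", 3, 2),
    ("Reduction_A", 3, 2),
    ("Reduction_B", 3, 2),
]


def valid_out(n, kernel, stride):
    """Output size of a VALID conv/pool: floor((n - kernel) / stride) + 1."""
    return (n - kernel) // stride + 1


def trace_sizes(n):
    """Return [(layer, size), ...], stopping at the first layer whose output would be empty."""
    sizes = [("input", n)]
    for name, kernel, stride in SIZE_CHANGING_LAYERS:
        n = valid_out(n, kernel, stride)
        sizes.append((name, n))
        if n <= 0:  # the 3x3 window no longer fits, so this layer cannot be applied
            break
    return sizes


for size in (32, 96):
    print(size, trace_sizes(size))
```

With 96 the feature map enters Reduction_A at 9×9 and leaves Reduction_B at 1×1. With the raw 32 it has already shrunk to 1×1 before Reduction_A, whose 3×3 VALID window then has nothing left to slide over.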
import
?tensorflow?
as
?tf
2
from
?tflearn.layers.conv?
import
?global_avg_pool
3
from
?tensorflow.contrib.layers?
import
?batch_norm,?flatten
4
from
?tensorflow.contrib.framework?
import
?arg_scope
5
from
?cifar10?
import
?*
6
import
?numpy?
as
?np
7
8
weight_decay?=?
0.0005
9
momentum?=?
0.9
10
11
init_learning_rate?=?
0.1
12
13
reduction_ratio?=?
4
14
15
batch_size?=?
128
16
iteration?=?
391
17
#?128?*?391?~?50,000
18
19
test_iteration?=?
10
20
21
total_epochs?=?
100
22
23
def
?
conv_layer
(input,?filter,?kernel,?stride=
1
,?padding=
'SAME'
,?layer_name=
"conv"
,?activation=True)
:
24
????
with
?tf.name_scope(layer_name):
25
????????network?=?tf.layers.conv2d(inputs=input,?use_bias=
True
,?filters=filter,?kernel_size=kernel,?strides=stride,?padding=padding)
26
????????
if
?activation?:
27
????????????network?=?Relu(network)
28
????????
return
?network
29
30
def
?
Fully_connected
(x,?units=class_num,?layer_name=
'fully_connected'
)
?:
31
????
with
?tf.name_scope(layer_name)?:
32
????????
return
?tf.layers.dense(inputs=x,?use_bias=
True
,?units=units)
33
34
def
?
Relu
(x)
:
35
????
return
?tf.nn.relu(x)
36
37
def
?
Sigmoid
(x)
:
38
????
return
?tf.nn.sigmoid(x)
39
40
def
?
Global_Average_Pooling
(x)
:
41
????
return
?global_avg_pool(x,?name=
'Global_avg_pooling'
)
42
43
def
?
Max_pooling
(x,?pool_size=[
3
,
3
],?stride=
2
,?padding=
'VALID'
)
?:
44
????
return
?tf.layers.max_pooling2d(inputs=x,?pool_size=pool_size,?strides=stride,?padding=padding)
45
46
def
?
Batch_Normalization
(x,?training,?scope)
:
47
????
with
?arg_scope([batch_norm],
48
???????????????????scope=scope,
49
???????????????????updates_collections=
None
,
50
???????????????????decay=
0.9
,
51
???????????????????center=
True
,
52
???????????????????scale=
True
,
53
???????????????????zero_debias_moving_mean=
True
)?:
54
????????
return
?tf.cond(training,
55
???????????????????????
lambda
?:?batch_norm(inputs=x,?is_training=training,?reuse=
None
),
56
???????????????????????
lambda
?:?batch_norm(inputs=x,?is_training=training,?reuse=
True
))
57
58
def
?
Concatenation
(layers)
?:
59
????
return
?tf.concat(layers,?axis=
3
)
60
61
def
?
Dropout
(x,?rate,?training)
?:
62
????
return
?tf.layers.dropout(inputs=x,?rate=rate,?training=training)
63
64
def
?
Evaluate
(sess)
:
65
????test_acc?=?
0.0
66
????test_loss?=?
0.0
67
????test_pre_index?=?
0
68
????add?=?
1000
69
70
????
for
?it?
in
?range(test_iteration):
71
????????test_batch_x?=?test_x[test_pre_index:?test_pre_index?+?add]
72
????????test_batch_y?=?test_y[test_pre_index:?test_pre_index?+?add]
73
????????test_pre_index?=?test_pre_index?+?add
74
75
????????test_feed_dict?=?{
76
????????????x:?test_batch_x,
77
????????????label:?test_batch_y,
78
????????????learning_rate:?epoch_learning_rate,
79
????????????training_flag:?
False
80
????????}
81
82
????????loss_,?acc_?=?sess.run([cost,?accuracy],?feed_dict=test_feed_dict)
83
84
????????test_loss?+=?loss_
85
????????test_acc?+=?acc_
86
87
????test_loss?/=?test_iteration?
#?average?loss
88
????test_acc?/=?test_iteration?
#?average?accuracy
89
90
????summary?=?tf.Summary(value=[tf.Summary.Value(tag=
'test_loss'
,?simple_value=test_loss),
91
????????????????????????????????tf.Summary.Value(tag=
'test_accuracy'
,?simple_value=test_acc)])
92
93
????
return
?test_acc,?test_loss,?summary
94
95
class
?
SE_Inception_resnet_v2
()
:
96
????
def
?
__init__
(self,?x,?training)
:
97
????????self.training?=?training
98
????????self.model?=?self.Build_SEnet(x)
99
100
????
def
?
Stem
(self,?x,?scope)
:
101
????????
with
?tf.name_scope(scope)?:
102
????????????x?=?conv_layer(x,?filter=
32
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_conv1'
)
103
????????????x?=?conv_layer(x,?filter=
32
,?kernel=[
3
,
3
],?padding=
'VALID'
,?layer_name=scope+
'_conv2'
)
104
????????????block_1?=?conv_layer(x,?filter=
64
,?kernel=[
3
,
3
],?layer_name=scope+
'_conv3'
)
105
106
????????????split_max_x?=?Max_pooling(block_1)
107
????????????split_conv_x?=?conv_layer(block_1,?filter=
96
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv1'
)
108
????????????x?=?Concatenation([split_max_x,split_conv_x])
109
110
????????????split_conv_x1?=?conv_layer(x,?filter=
64
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv2'
)
111
????????????split_conv_x1?=?conv_layer(split_conv_x1,?filter=
96
,?kernel=[
3
,
3
],?padding=
'VALID'
,?layer_name=scope+
'_split_conv3'
)
112
113
????????????split_conv_x2?=?conv_layer(x,?filter=
64
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv4'
)
114
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
64
,?kernel=[
7
,
1
],?layer_name=scope+
'_split_conv5'
)
115
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
64
,?kernel=[
1
,
7
],?layer_name=scope+
'_split_conv6'
)
116
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
96
,?kernel=[
3
,
3
],?padding=
'VALID'
,?layer_name=scope+
'_split_conv7'
)
117
118
????????????x?=?Concatenation([split_conv_x1,split_conv_x2])
119
120
????????????split_conv_x?=?conv_layer(x,?filter=
192
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv8'
)
121
????????????split_max_x?=?Max_pooling(x)
122
123
????????????x?=?Concatenation([split_conv_x,?split_max_x])
124
125
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
126
????????????x?=?Relu(x)
127
128
????????????
return
?x
129
130
????
def
?
Inception_resnet_A
(self,?x,?scope)
:
131
????????
with
?tf.name_scope(scope)?:
132
????????????init?=?x
133
134
????????????split_conv_x1?=?conv_layer(x,?filter=
32
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv1'
)
135
136
????????????split_conv_x2?=?conv_layer(x,?filter=
32
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv2'
)
137
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
32
,?kernel=[
3
,
3
],?layer_name=scope+
'_split_conv3'
)
138
139
????????????split_conv_x3?=?conv_layer(x,?filter=
32
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv4'
)
140
????????????split_conv_x3?=?conv_layer(split_conv_x3,?filter=
48
,?kernel=[
3
,
3
],?layer_name=scope+
'_split_conv5'
)
141
????????????split_conv_x3?=?conv_layer(split_conv_x3,?filter=
64
,?kernel=[
3
,
3
],?layer_name=scope+
'_split_conv6'
)
142
143
????????????x?=?Concatenation([split_conv_x1,split_conv_x2,split_conv_x3])
144
????????????x?=?conv_layer(x,?filter=
384
,?kernel=[
1
,
1
],?layer_name=scope+
'_final_conv1'
,?activation=
False
)
145
146
????????????x?=?x*
0.1
147
????????????x?=?init?+?x
148
149
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
150
????????????x?=?Relu(x)
151
152
????????????
return
?x
153
154
????
def
?
Inception_resnet_B
(self,?x,?scope)
:
155
????????
with
?tf.name_scope(scope)?:
156
????????????init?=?x
157
158
????????????split_conv_x1?=?conv_layer(x,?filter=
192
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv1'
)
159
160
????????????split_conv_x2?=?conv_layer(x,?filter=
128
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv2'
)
161
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
160
,?kernel=[
1
,
7
],?layer_name=scope+
'_split_conv3'
)
162
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
192
,?kernel=[
7
,
1
],?layer_name=scope+
'_split_conv4'
)
163
164
????????????x?=?Concatenation([split_conv_x1,?split_conv_x2])
165
????????????x?=?conv_layer(x,?filter=
1152
,?kernel=[
1
,
1
],?layer_name=scope+
'_final_conv1'
,?activation=
False
)
166
????????????
#?1154
167
????????????x?=?x?*?
0.1
168
????????????x?=?init?+?x
169
170
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
171
????????????x?=?Relu(x)
172
173
????????????
return
?x
174
175
????
def
?
Inception_resnet_C
(self,?x,?scope)
:
176
????????
with
?tf.name_scope(scope)?:
177
????????????init?=?x
178
179
????????????split_conv_x1?=?conv_layer(x,?filter=
192
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv1'
)
180
181
????????????split_conv_x2?=?conv_layer(x,?filter=
192
,?kernel=[
1
,?
1
],?layer_name=scope?+?
'_split_conv2'
)
182
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
224
,?kernel=[
1
,?
3
],?layer_name=scope?+?
'_split_conv3'
)
183
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
256
,?kernel=[
3
,?
1
],?layer_name=scope?+?
'_split_conv4'
)
184
185
????????????x?=?Concatenation([split_conv_x1,split_conv_x2])
186
????????????x?=?conv_layer(x,?filter=
2144
,?kernel=[
1
,
1
],?layer_name=scope+
'_final_conv2'
,?activation=
False
)
187
????????????
#?2048
188
????????????x?=?x?*?
0.1
189
????????????x?=?init?+?x
190
191
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
192
????????????x?=?Relu(x)
193
194
????????????
return
?x
195
196
????
def
?
Reduction_A
(self,?x,?scope)
:
197
????????
with
?tf.name_scope(scope)?:
198
????????????k?=?
256
199
????????????l?=?
256
200
????????????m?=?
384
201
????????????n?=?
384
202
203
????????????split_max_x?=?Max_pooling(x)
204
205
????????????split_conv_x1?=?conv_layer(x,?filter=n,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv1'
)
206
207
????????????split_conv_x2?=?conv_layer(x,?filter=k,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv2'
)
208
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=l,?kernel=[
3
,
3
],?layer_name=scope+
'_split_conv3'
)
209
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=m,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv4'
)
210
211
????????????x?=?Concatenation([split_max_x,?split_conv_x1,?split_conv_x2])
212
213
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
214
????????????x?=?Relu(x)
215
216
????????????
return
?x
217
218
????
def
?
Reduction_B
(self,?x,?scope)
:
219
????????
with
?tf.name_scope(scope)?:
220
????????????split_max_x?=?Max_pooling(x)
221
222
????????????split_conv_x1?=?conv_layer(x,?filter=
256
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv1'
)
223
????????????split_conv_x1?=?conv_layer(split_conv_x1,?filter=
384
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv2'
)
224
225
????????????split_conv_x2?=?conv_layer(x,?filter=
256
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv3'
)
226
????????????split_conv_x2?=?conv_layer(split_conv_x2,?filter=
288
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv4'
)
227
228
????????????split_conv_x3?=?conv_layer(x,?filter=
256
,?kernel=[
1
,
1
],?layer_name=scope+
'_split_conv5'
)
229
????????????split_conv_x3?=?conv_layer(split_conv_x3,?filter=
288
,?kernel=[
3
,
3
],?layer_name=scope+
'_split_conv6'
)
230
????????????split_conv_x3?=?conv_layer(split_conv_x3,?filter=
320
,?kernel=[
3
,
3
],?stride=
2
,?padding=
'VALID'
,?layer_name=scope+
'_split_conv7'
)
231
232
????????????x?=?Concatenation([split_max_x,?split_conv_x1,?split_conv_x2,?split_conv_x3])
233
234
????????????x?=?Batch_Normalization(x,?training=self.training,?scope=scope+
'_batch1'
)
235
????????????x?=?Relu(x)
236
237
????????????
return
?x
238
    def Squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
        with tf.name_scope(layer_name):
            # Squeeze: global average pooling turns each H x W channel into a single value
            squeeze = Global_Average_Pooling(input_x)

            # Excitation: bottleneck FC (C -> C // ratio) + ReLU, then FC back to C + sigmoid,
            # giving one weight in (0, 1) per channel.
            # Integer division keeps the unit count an int under Python 3.
            excitation = Fully_connected(squeeze, units=out_dim // ratio, layer_name=layer_name+'_fully_connected1')
            excitation = Relu(excitation)
            excitation = Fully_connected(excitation, units=out_dim, layer_name=layer_name+'_fully_connected2')
            excitation = Sigmoid(excitation)

            # Scale (feature recalibration): reshape to [N, 1, 1, C] and multiply every channel of the input
            excitation = tf.reshape(excitation, [-1, 1, 1, out_dim])
            scale = input_x * excitation

            return scale

    def Build_SEnet(self, input_x):
        input_x = tf.pad(input_x, [[0, 0], [32, 32], [32, 32], [0, 0]])
        # size 32 -> 96
        print(np.shape(input_x))
        # only cifar10 architecture

        x = self.Stem(input_x, scope='stem')

        for i in range(5):
            x = self.Inception_resnet_A(x, scope='Inception_A'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A'+str(i))

        x = self.Reduction_A(x, scope='Reduction_A')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A')

        for i in range(10):
            x = self.Inception_resnet_B(x, scope='Inception_B'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B'+str(i))

        x = self.Reduction_B(x, scope='Reduction_B')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B')

        for i in range(5):
            x = self.Inception_resnet_C(x, scope='Inception_C'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C'+str(i))

        # channel = int(np.shape(x)[-1])
        # x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C')

        x = Global_Average_Pooling(x)
        x = Dropout(x, rate=0.2, training=self.training)
        x = flatten(x)

        x = Fully_connected(x, layer_name='final_fully_connected')
        return x


train_x, train_y, test_x, test_y = prepare_data()
train_x, test_x = color_preprocessing(train_x, test_x)

# image_size = 32, img_channels = 3, class_num = 10 in cifar10; the six trashnet labels need class_num = 6
x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels])
label = tf.placeholder(tf.float32, shape=[None, class_num])

training_flag = tf.placeholder(tf.bool)

learning_rate = tf.placeholder(tf.float32, name='learning_rate')

logits = SE_Inception_resnet_v2(x, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

# L2 weight decay over all trainable variables, added to the cross-entropy loss
l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True)
train = optimizer.minimize(cost + l2_loss * weight_decay)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver(tf.global_variables())

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter('./logs', sess.graph)

    epoch_learning_rate = init_learning_rate
    for epoch in range(1, total_epochs + 1):
        # Step decay: divide the learning rate by 10 every 30 epochs
        if epoch % 30 == 0:
            epoch_learning_rate = epoch_learning_rate / 10

        pre_index = 0
        train_acc = 0.0
        train_loss = 0.0

        # `iteration` (set at the top of the file) should match ceil(len(train_x) / batch_size)
        for step in range(1, iteration + 1):
            # The CIFAR-10 original hard-coded 50000 (its training-set size);
            # use the real number of training images so the smaller trashnet set works
            if pre_index + batch_size < len(train_x):
                batch_x = train_x[pre_index: pre_index + batch_size]
                batch_y = train_y[pre_index: pre_index + batch_size]
            else:
                batch_x = train_x[pre_index:]
                batch_y = train_y[pre_index:]

            batch_x = data_augmentation(batch_x)

            train_feed_dict = {
                x: batch_x,
                label: batch_y,
                learning_rate: epoch_learning_rate,
                training_flag: True
            }

            _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict)
            batch_acc = accuracy.eval(feed_dict=train_feed_dict)

            train_loss += batch_loss
            train_acc += batch_acc
            pre_index += batch_size

        train_loss /= iteration  # average loss
        train_acc /= iteration  # average accuracy

        train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss),
                                          tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)])

        test_acc, test_loss, test_summary = Evaluate(sess)

        summary_writer.add_summary(summary=train_summary, global_step=epoch)
        summary_writer.add_summary(summary=test_summary, global_step=epoch)
        summary_writer.flush()

        line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % (
            epoch, total_epochs, train_loss, train_acc, test_loss, test_acc)
        print(line)

        with open('logs.txt', 'a') as f:
            f.write(line)

        saver.save(sess=sess, save_path='./model/Inception_resnet_v2.ckpt')
Honestly, using SENet for garbage classification is overkill, but it still gives you a feel for how powerful the model is.
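As noted earlier, the trashnet set is small (about 2,500 images versus CIFAR-10's 50,000 training images), so the number of steps per epoch should come from the data rather than from CIFAR-10 constants. The pure-Python sketch below mirrors the batch slicing and the step learning-rate decay (divide by 10 every 30 epochs) used in the training loop; `batches`, `lr_for_epoch` and the example values (2,000 samples, batch size 128, initial rate 0.1) are hypothetical.

```python
import math


def lr_for_epoch(init_lr, epoch, step=30, factor=10.0):
    """Learning rate in effect during `epoch` (1-based) when the rate is divided by
    `factor` at every epoch that is a multiple of `step`, as `if epoch % 30 == 0` does."""
    return init_lr / factor ** (epoch // step)


def batches(num_samples, batch_size):
    """Yield (start, end) slice bounds that cover every sample exactly once per epoch;
    the last batch is simply shorter, like the `train_x[pre_index:]` branch."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


num_samples, batch_size, init_lr = 2000, 128, 0.1  # hypothetical values
iteration = math.ceil(num_samples / batch_size)  # steps per epoch derived from the data
bounds = list(batches(num_samples, batch_size))
print(iteration, bounds[-1])  # 16 (1920, 2000)
print(lr_for_epoch(init_lr, 29), lr_for_epoch(init_lr, 30))
```

With `iteration` computed this way and the slice bound taken from `len(train_x)` instead of the hard-coded 50000, no training step is fed an empty batch.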
Original article: https://blog.csdn.net/BEYONDMA/article/details/94888771
(*This article is reprinted by AI科技大本營 (AI Tech Base Camp); please contact the original author for permission to repost.)