/*先把標題給寫了、這樣就能經常提醒自己*/
1. 感知機模型
我們先來定義一下什么是感知機。所謂感知機,就是二類分類的線性分類模型,其輸入為樣本的特征向量,輸出為樣本的類別,取+1和-1二值,即通過某樣本的特征,就可以準確判斷該樣本屬于哪一類。顧名思義,感知機能夠解決的問題首先要求特征空間是線性可分的,再者是二類分類,即將樣本分為{+1, -1}兩類。從比較學術的層面來說,由輸入空間到輸出空間的函數:
???????????????????????????????????????????????????????????????????????????????????????????????????????? (1)
稱為感知機,w和b為感知機參數,w為權值(weight),b為偏置(bias)。sign為符號函數:
???????????????????????????????????????????????????????????????????????????????????????????????????????? (2)
感知機模型的假設空間是定義在特征空間中的所有線性分類模型,即函數集合{f|f(x) = w·x + b}。在感知機的定義中,線性方程w·x + b = 0對應于問題空間中的一個超平面S,位于這個超平面兩側的樣本分別被歸為兩類,例如下圖,紅色作為一類,藍色作為另一類,它們的特征很簡單,就是它們的坐標
圖1
作為監(jiān)督學習的一種方法,感知機學習由訓練集求得感知機模型,即求得模型參數w,b,這里x和y分別是特征向量和類別(也稱為目標)。基于此,感知機模型可以對新的輸入樣本進行分類。
前面半抄書半自說自話把感知機的定義以及是用來干嘛的簡單記錄了一下,作為早期的機器學習方法(1957年由Frank Rosenblatt提出),它是最簡單的前饋神經網絡,對之后的機器學習方法如神經網絡起一個基礎的作用,下一節(jié)將詳細介紹感知機學習策略。
2. 感知機學習策略
上節(jié)說到,感知機是一個簡單的二類分類的線性分類模型,要求我們的樣本是線性可分的,什么樣的樣本是線性可分的呢?舉例來說,在二維平面中,可以用一條直線將+1類和-1類完美分開,那么這個樣本空間就是線性可分的。如圖1就是線性可分的,圖2中的樣本就是線性不可分的,感知機就不能處理這種情況。因此,在本章中的所有問題都基于一個前提,就是問題空間線性可分。
圖2
為方便說明問題,我們假設數據集中所有的的實例i,有;對的實例有。
這里先給出輸入空間中任一點到超平面S的距離:
?????????????????????????????????????????????????????????????????????????????????????????????????????????????? (3)
這里||w||是w的范數。
對于誤分類的數據,根據我們之前的假設,有
????????????????????????????????????????????????????????????????????????????????????????????????????????? (4)
因此誤分類點到超平面S的距離可以寫作:
??????????????????????????????????????????????????????????????????????????????????????????????????????????? (5)
假設超平面S的誤分類點集合為M,那么所有誤分類點到超平面S的總距離為
??????????????????????????????????????????????????????????????????????????????????????????????????? (6)
這里的||w||值是固定的,不必考慮,這樣就得到了感知機學習的損失函數。根據我們的定義,這個損失函數自然是越小越好,因為這樣就代表著誤分類點越少、誤分類點距離超平面S的距離越近,即我們的分類面越正確。顯然,這個損失函數是非負的,若所有的樣本都分類正確,那么我們的損失函數值為0。一個特定的樣本集T的損失函數:在誤分類時是參數w,b的線性函數。也就是說,為求得正確的參數w,b,我們的目標函數為
????????????????????????????????????????????????????????????????????????????????????????? (7)
而它是連續(xù)可導的,這就使得我們可以比較容易的求得其最小值。
感知機學習的策略是在假設空間中選取使我們的損失函數(7)最小的模型參數w,b,即感知機模型。
根據感知機定義以及我們的假設,得到了感知機的模型,即目標函數(7),將其最小化的本質就是使得分類面S盡可能的正確,下一節(jié)介紹將其最小化的方法——隨機梯度下降。
3. 感知機學習算法
根據感知機學習的策略,我們已經將尋找超平面S的問題轉化為求解式(7)的最優(yōu)化問題,最優(yōu)化的方法是隨機梯度下降法,書中介紹了兩種形式:原始形式和對偶形式,并證明了在訓練集線性可分時算法的收斂性。
3.1 原始形式
所謂原始形式,就是我們用梯度下降的方法,對參數w和b進行不斷的迭代更新。具體來說,就是先任意選取一個超平面,對應的參數分別為和,當然現在是可以任意賦值的,比如說選取為全為0的向量,的值為0。然后用梯度下降不斷地極小化損失函數(7)。由于隨機梯度下降(stochastic?????gradient descent)的效率要高于批量梯度下降(batch gradient descent)(詳情可參考Andrew Ng教授的 講義 ,在Part 1的LMS algorithm部分),所以這里采用隨機梯度下降的方法,每次隨機選取一個誤分類點對w和b進行更新。
設誤分類點集合M是固定的,為求式(7)的最小值,我們需要知道往哪個方向下降速率最快,這是可由對損失函數L(w, b)求梯度得到,L(w, b)的梯度為
接下來隨機選取一個誤分類點對w,b進行更新
????????????????????????????????????????????????????????????????????????????????????????????????????????????????(8)
??????????????????????????????????????????????????????????????????????????????????????????????????????????????????????(9)
其中為步長,也稱為學習速率(learning rate),一般在0到1之間取值,步長越大,我們梯度下降的速度越快,也就能更快接近極小點。如果步長過大,就有直接跨過極小點導致函數發(fā)散的問題;如果步長過小,可能會耗費比較長的時間才能達到極小點。通過這樣的迭代,我們的損失函數就不斷減小,直到為0。綜上所述,得到如下算法:
算法1 (感知機學習算法的原始形式)
輸入:訓練數據集,其中,,i = 1,2,…,N;學習率
輸出:w,b;感知機模型
(1)選取初始值,
(2)在訓練集中選取數據
(3)如果(從公式(3)變換而來)
(4)轉至(2),直至訓練集中沒有誤分類點
這種學習算法直觀上有如下解釋:當一個樣本被誤分類時,就調整w和b的值,使超平面S向誤分類點的一側移動,以減少該誤分類點到超平面的距離,直至超平面越過該點使之被正確分類。
書上還給出了一個例題,這是我推崇這本書的原因之一,凡是只講理論不給例子的行為都是耍流氓!
例1? 如圖3所示的訓練數據集,其正實例點是,,負實例點是,試用感知機學習算法的原始形式求感知機模型,即求出w和b。這里,
圖3
這里我們取初值,取。具體問題解釋不寫了,求解的方法就是 算法1 。下面給出這道題的Java代碼(終于有一段是自己純原創(chuàng)的了)。
package org.juefan.perceptron; import java.util.ArrayList; import org.juefan.basic.FileIO; public class PrimevalPerceptron { public static ArrayList<Integer> w = new ArrayList<> (); public static int b ; /* 初始化參數 */ public PrimevalPerceptron(){ w.add( 5 ); w.add( -2 ); b = 3 ; } /** * 判斷是否分類正確 * @param data 待判斷數據 * @return 返回判斷正確與否 */ public static boolean getValue(Data data){ int state = 0 ; for ( int i = 0; i < data.x.size(); i++ ){ state += w.get(i) * data.x.get(i); } state += b; return state * data.y > 0? true : false ; } // 此算法基于數據是線性可分的,如果線性不可分,則會進入死循環(huán) public static boolean isStop(ArrayList<Data> datas){ boolean isStop = true ; for (Data data: datas){ isStop = isStop && getValue(data); } return isStop; } public static void main(String[] args) { PrimevalPerceptron model = new PrimevalPerceptron(); ArrayList <Data> datas = new ArrayList<> (); FileIO fileIO = new FileIO(); fileIO.setFileName( ".//file//perceptron.txt" ); fileIO.FileRead(); for (String data: fileIO.fileList){ datas.add( new Data(data)); } /** * 如果全部數據都分類正確則結束迭代 */ while (! isStop(datas)){ for ( int i = 0; i < datas.size(); i++ ){ if (!getValue(datas.get(i))){ // 這里面可以理解為是一個簡單的梯度下降法 for ( int j = 0; j < datas.get(i).x.size(); j++ ) w.set(j, w.get(j) + datas.get(i).y * datas.get(i).x.get(j)); b += datas.get(i).y; System.out.println(w + "\t" + b); } } } System.out.println(w + "\t" + b); // 輸出最終的結果 } }
最后解得(這里應該是寫錯了,最終結果b=-3)。不過,如果選取的初值不同,或者選取的誤分類點不同,我們得到的超平面S也不盡相同,畢竟感知機模型的解是一組符合條件的超平面的集合,而不是某一個最優(yōu)超平面。
3.2 算法的收斂性
一節(jié)純數學的東西,用了兩整頁證明了Novikoff定理,看到這里才知道智商真的是個硬傷,反復看了兩遍,又把證明的過程自己推導了一遍,算是看懂了,如果憑空證明的話,自己的功力還差得遠。
Novikoff于1962年證明了感知機算法的收斂性,作為一個懶人,由于這一節(jié)涉及大量公式,即使有l(wèi)atex插件也是個麻煩的工作,具體什么情況我就不談了,同時,哥倫比亞大學有這樣的一篇叫《
Convergence Proof for the Perceptron Algorithm
》的筆記,講解了這個定理的證明過程,也給了我一個偷懶的理由
。
3.3 感知機學習算法的對偶形式
書上說對偶形式的基本想法是,將w和b表示為實例和的線性組合形式,通過求解其系數而求得w和b。這個想法及下面的算法描述很容易看懂,只不過為什么要這么做?將w和b用x和y來表示有什么好處呢?看起來也不怎么直觀,計算量上似乎并沒有減少。如果說支持向量機的求最優(yōu)過程使用對偶形式是為了方便引入核函數,那這里的對偶形式是用來做什么的呢?暫且認為是對后面支持向量機的一個鋪墊吧,或者是求解這種優(yōu)化問題的一個普遍解法。
繼續(xù)正題,為了方便推導,可將初始值和都設為0,據上文,我們對誤分類點通過
來更新w,b,假設我們通過誤分類點更新參數的次數為次,那么w,b關于的增量為和,為方便,可將用來表示,很容易可以得到
????????????????????????????????????????????????????????????????????????????????????????????????????????? (10)
??????????????????????????????????????????????????????????????????????????????????????????????????????????????? (11)
這里i = 1,2,…,N。當時,表示第i個樣本由于被誤分類而進行更新的次數。某樣本更新次數越多,表示它距離超平面S越近,也就越難正確分類。換句話說,這樣的樣本對學習結果影響最大。
算法2 (感知機學習算法的對偶形式)
輸入:訓練數據集,其中,,i = 1,2,…,N;學習率
輸出:,b;感知機模型
其中
(1),
(2)在訓練集中選取樣本
(3)如果
(4)轉至(2)直到沒有誤分類樣本出現
由于訓練實例僅以內積的形式出現,為方便,可預先將訓練集中實例間的內積計算出來并以矩陣形式存儲(就是那個的部分),這就是所謂的Gram矩陣(線性代數學的不好的飄過To T):
又到例題時間!再說一遍,這本書最大的好處就是有實例,凡是只講理論不給例子的行為都是耍流氓!
例2 ?同 例1 ,只不過是用對偶形式來求解。
同樣過程不再分析,給出我的求解代碼:
?
package org.juefan.perceptron; import java.util.ArrayList; import org.juefan.basic.FileIO; public class GramPerceptrom { public static ArrayList<Integer> a = new ArrayList<> (); public static int b ; /* 初始化參數 */ public GramPerceptrom( int num){ for ( int i = 0; i < num; i++ ) a.add( 0 ); b = 0 ; } /** Gram矩陣 */ public static ArrayList<ArrayList<Integer>> gram = new ArrayList<> (); public void setGram(ArrayList<Data> datas){ for ( int i = 0; i < datas.size(); i++ ){ ArrayList <Integer> rowGram = new ArrayList<> (); for ( int j = 0; j < datas.size(); j++ ){ rowGram.add(Data.getInner(datas.get(i), datas.get(j))); } gram.add(rowGram); } } /** 是否正確分類 */ public static boolean isCorrect( int i, ArrayList<Data> datas){ int value = 0 ; for ( int j = 0; j < datas.size(); j++ ) value += a.get(j)*datas.get(j).y * gram.get(j).get(i); value = datas.get(i).y * (value + b); return value > 0 ? true : false ; } // 此算法基于數據是線性可分的,如果線性不可分,則會進入死循環(huán) public static boolean isStop(ArrayList<Data> datas){ boolean isStop = true ; for ( int i = 0; i < datas.size(); i++ ){ isStop = isStop && isCorrect(i, datas); } return isStop; } public static void main(String[] args) { ArrayList <Data> datas = new ArrayList<> (); FileIO fileIO = new FileIO(); fileIO.setFileName( ".//file//perceptron.txt" ); fileIO.FileRead(); for (String data: fileIO.fileList){ datas.add( new Data(data)); } GramPerceptrom gram = new GramPerceptrom(datas.size()); gram.setGram(datas); System.out.println(datas.size()); while (! isStop(datas)){ for ( int i = 0; i < datas.size(); i++ ) if (! isCorrect(i, datas)){ a.set(i, a.get(i) + 1 ); b += datas.get(i).y; System.out.println(a + "\t" + b); } } } }
?
4. 小結
終于寫完一章的內容了,用了不少功夫,果然是說起來容易做起來難呢。不過通過這樣記錄的方式(雖然大部分是抄書),自己對相關算法的理論及過程就又學了一遍,覺得這個時間還是花的值得的。
本章介紹了統(tǒng)計學習中最簡單的一種算法——感知機,對現在的機器學習理論來說,這個算法的確是太簡單了,但這樣簡單的東西卻是很多現在流行算法的基礎,比如神經網絡,比如支持向量機,Deep Learning還沒了解過,不知道有多大聯系。當然,實際應用價值不能說沒有,可以用于一些簡單的線性分類的情況,也可以由簡單的二類分類擴展到多類分類(詳見 此PPT ),可以用于自然語義處理等領域。
再將思路整理一下,盡量掌握感知機算法,再好好看看 維基百科鏈接 中的有關文獻。
?
PS:本文的內容轉載自?http://www.cnblogs.com/OldPanda/archive/2013/04/12/3017100.html
原博主是Python寫的代碼,我這邊改成Java了
對代碼有興趣的可以上本人的GitHub查看: https://github.com/JueFan/StatisticsLearningMethod/
2014-06-29:倆種算法的代碼都完成了,更新完畢
更多文章、技術交流、商務合作、聯系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號聯系: 360901061
您的支持是博主寫作最大的動力,如果您喜歡我的文章,感覺我的文章對您有幫助,請用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點擊下面給點支持吧,站長非常感激您!手機微信長按不能支付解決辦法:請將微信支付二維碼保存到相冊,切換到微信,然后點擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對您有幫助就好】元
