ナイーブベイズ分類器とクロスバリデーションの実装 [Python]
ナイーブベイズ分類による機械学習と、それをクロスバリデーション(10-fold CV)により分類精度を調べるプログラムをPythonで書きました。
大学院の講義の課題で出されたもので、丸一日のたうち回りながら何とか完成させました。
ナイーブベイズ分類器について
ナイーブベイズ分類とは、一言で言うとベイズの法則を使ってデータの分類を行うことです。
先に参考にしたサイトを紹介しておきます。↓
ナイーブベイズを用いたテキスト分類 - 人工知能に関する断創録
ナイーブベイズの概要はここに詳しく書いてあります。こちらのブログには度々お世話になっていますm(__)m(信号処理だとか、ゲーム作ったりだとか)
ナイーブベイズを理解するにあたっては、条件付き確率を完璧に理解している必要があります。確立が苦手だとまずここで苦労します。。。
分類対象データ
分類対象のデータはこちらのサイトから「Iris」のデータを拝借しました。このサイトには多種多様なサンプルデータがあるので、機械学習の勉強をするには便利だと思います。
このデータは、アイリスという花を「"がく"の長さと幅」、「花弁の長さと幅」という4つのパラメータから3つの種類に分類したものです。(正直後々になって選ぶデータミスったなって思いました…笑)
Irisのデータを示しておきます。列は","で区切ってあって、左から、がくの長さ、がくの幅、花弁の長さ、花弁の幅、アイリスの種類を示しています。データの総数は150です。
5.2,3.5,1.5,0.2,Iris-setosa 5.2,3.4,1.4,0.2,Iris-setosa 4.7,3.2,1.6,0.2,Iris-setosa 4.8,3.1,1.6,0.2,Iris-setosa 5.4,3.4,1.5,0.4,Iris-setosa 7.0,3.2,4.7,1.4,Iris-versicolor 6.4,3.2,4.5,1.5,Iris-versicolor 6.9,3.1,4.9,1.5,Iris-versicolor 5.5,2.3,4.0,1.3,Iris-versicolor 6.5,2.8,4.6,1.5,Iris-versicolor 6.3,3.3,6.0,2.5,Iris-virginica 5.8,2.7,5.1,1.9,Iris-virginica 7.1,3.0,5.9,2.1,Iris-virginica 6.3,2.9,5.6,1.8,Iris-virginica 6.5,3.0,5.8,2.2,Iris-virginica …
こんなデータを分類できて何が嬉しいんだと思うかもしれませんが、このような機械学習は様々な分野に応用できるわけです。(詳しくは最初に紹介したブログを参照)
ナイーブベイズ分類器の実装
ナイーブベイズの中心となる式はベイズの法則より下の式で表せます。
確立P(iris|size)はアイリスの大きさ(size)が与えられた時にアイリスの種類(iris)が与えられる時の確率です。P(iris)はデータ全体の中でそのアイリスの種類が出現する確立で、P(size)がデータ全体の中でそのアイリスの大きさが出現する確立です。また、P(size)はどのアイリスの種類にも共通なので無視できます。
そして、最も重要なのがP(size|iris)で、アイリスの種類(iris)が与えられた時に、アイリスの大きさ(size)が生成される確立を表しています。今回は、sizeを4つのパラメータの組み合わせとして考えました。つまり、
P(size|iris)=あるiris中で出現するあるsizeの組み合わせが出現する確立 / 全体データの中であるirisが出現する確立
で求められます。(わかりにくくてすみません)
そして最終的に分類されるアイリスiris_mapは
で表されます。argmacf(x)はf(x)が最大になるxを返します。P(size|iris)は非常に小さい値のため、アンダーフローを防ぐために、対数をとって掛け算を足し算化します。
また、「ゼロ頻度問題」を解決するために、ラプラススムージングを施しています。(詳しくは最初のリンク参照)
クロスバリデーションの実装
クロスバリデーションとは、統計学において標本データを分割し、その一部をまず解析して、残る部分でその解析のテストを行い、解析自身の妥当性を検証・確認する手法のことを言います。
引用:
モデルの精度を推定する
http://musashi.sourceforge.jp/tutorial/mining/xtclassify/accuracy.html
今回は10-fold Cross-validationということで、学習データ:テストデータ = 135:15 = 9:1 としました。
つまり、10回のクロスバリーデションでナイーブ分類の精度を計測し、その平均値と標準偏差を求めます。また、精度の判断基準は「分類されたアイリスの種類がと正しいアイリスの種類が一致するか否か」としました。
ソースコード(Python)
以上のことをPythonで実装したプログラムを示します。急いでたのもありますが、正直途中で書いててイヤになるほど汚いコードに仕上がって、人様に見せるのもはばかられます(笑)決して真似をしないで参考程度にしてください。
それでもなんとか動くものができたので良かったです。
#-*- coding: utf-8 -*- import numpy as np import sys from collections import defaultdict import math import random class NaiveBays: global MAX global ONE global CV MAX = 150 ONE = 15 CV = 10 def __init__(self): self.categories = set() #アイリスの集合 self.feature = set() #サイズの集合 self.catcount = {} #catcount[cat] アイリスの出現回数 self.sizecount = {} #sizecount[cat]][x] アイリスでのサイズの出現回数 self.denominator = {} #denominator[cat] P(size|iris)の分母の値 def train(self,data): #ナイーブベイズ分類器の訓練 #初期化 for d in data: cat = d[0] self.categories.add(cat) for cat in self.categories: self.sizecount[cat] = defaultdict(int) self.catcount[cat] = 0 #アイリスとサイズをカウント for d in data: cat,feat = d[0],d[1:] self.catcount[cat] += 1 for x in feat: self.feature.add(x) self.sizecount[cat][x] += 1 #P(word|cat)の分母の値を計算 for cat in self.categories: self.denominator[cat] = sum(self.sizecount[cat].values()) def sizeProb(self, size, cat): #サイズの条件付き確率P(size|iris)を求める return float(self.sizecount[cat][size] + 1) / float(self.denominator[cat]) def score(self, test, cat): #log(P(iris|size))を求める total = sum(self.catcount.values()) score = math.log(float(self.catcount[cat]) / float(total)) score += math.log(self.sizeProb(test, cat)) return score def classify(self, test): #log(P(iris|size))が最大となるアイリスを返す best = None max = -sys.maxint for cat in self.catcount.keys(): p = self.score(test, cat) if p > max: max = p best = cat return best def __str__(self): total = sum(self.catcount.values()) return "data: %d, patterns: %d, categories: %d" % (total, len(self.feature), len(self.categories)) if __name__ == "__main__": nb = NaiveBays() size = [0]*MAX iris = [0]*MAX argvs = sys.argv iris_kind = {"Iris-setosa","Iris-versicolor","Iris-virginica"} print "Traindata: %s" % argvs[1] f = open(argvs[1]) line = f.readline() data = [] rate_list =[] #データの読み込み for i in range(MAX): a = line.split(',') size[i] = "%s,%s,%s,%s" % (int(round(float(a[0]),0)),int(round(float(a[1]),0)),int(round(float(a[2]),0)),int(round(float(a[3]),0))) iris[i] = a[4] list = [iris[i],size[i]] data.append(list) line = f.readline() random.shuffle(data) k = 0 #10-fold クロスバリデーション for j in range(CV): one = data[k:k+15] if j == 0: nine = data[15:150] elif j == 9: nine = data[135:150] else: nine == data[0:k] nine == data[k+15:150] k += ONE print "Try : %d" % (j+1) print "Now Training..." nb.train(nine) print nb print "Training finished!" print "Now Testing..." yes_count = 0 for i in range(ONE): print "%s : %s" % (one[i][1],nb.classify(one[i][1]).replace('\n','')), if nb.classify(one[i][1]) == one[i][0]: print " : YES" yes_count += 1 else: print " : NO" rate = float(yes_count) / float(ONE) print "Testing finished!" print "Rate : %f\n" % rate rate_list.append(rate) print "Average Rate : %f" % np.average(rate_list) print "Standard Devison : %f" % np.std(rate_list) f.close()
実行結果
上記のプログラムの実行結果を示します。分類したアイリスの種類と正しいアイリスの種類が一致した場合は「YES」、一致しない場合は「NO」と表示しています。10回それぞれクロスバリデーションの精度を求めて、最後にその平均値を出力しています。
ちなみに精度を高めるためにアイリスの大きさは少数第一位を四捨五入して整数に加工しています。
プログラム名は「nbc.py」でコマンドライン引数で分類対象データ「iris.data」を渡しています。
$ python nbc.py iris.data Traindata: iris.data Try : 1 Now Training... data: 135, patterns: 34, categories: 3 Training finished! Now Testing... 6,4,2,0 : Iris-setosa : YES 7,3,5,2 : Iris-virginica : NO 6,3,5,1 : Iris-versicolor : YES 4,3,1,0 : Iris-setosa : YES 6,3,4,1 : Iris-versicolor : YES 5,3,2,0 : Iris-setosa : YES 5,4,1,0 : Iris-setosa : YES 6,3,4,1 : Iris-versicolor : YES 6,3,4,1 : Iris-versicolor : YES 5,3,2,0 : Iris-setosa : YES 5,3,2,0 : Iris-setosa : YES 6,3,4,1 : Iris-versicolor : YES 6,3,4,1 : Iris-versicolor : YES 7,3,5,2 : Iris-virginica : YES 5,3,2,0 : Iris-setosa : YES Testing finished! Rate : 0.933333 Try : 2 Now Training... data: 135, patterns: 34, categories: 3 Training finished! Now Testing... 6,3,4,1 : Iris-versicolor : YES 6,3,5,2 : Iris-virginica : YES 7,3,5,2 : Iris-virginica : YES 6,3,5,2 : Iris-virginica : NO 5,3,2,0 : Iris-setosa : YES 6,3,5,2 : Iris-virginica : YES 6,3,5,2 : Iris-virginica : YES 7,3,6,2 : Iris-virginica : YES 6,3,5,2 : Iris-virginica : YES 6,3,5,2 : Iris-virginica : NO 7,3,5,2 : Iris-virginica : YES 8,3,6,2 : Iris-virginica : YES 5,3,2,0 : Iris-setosa : YES 6,3,5,2 : Iris-virginica : YES 5,3,1,0 : Iris-setosa : YES Testing finished! Rate : 0.866667 … Try : 10 Now Training... data: 15, patterns: 34, categories: 3 Training finished! Now Testing... 6,2,4,1 : Iris-versicolor : YES 6,3,6,2 : Iris-virginica : YES 5,4,2,0 : Iris-setosa : YES 5,3,1,0 : Iris-setosa : YES 5,4,2,0 : Iris-setosa : YES 6,3,5,2 : Iris-versicolor : YES 5,3,2,0 : Iris-setosa : YES 5,4,1,0 : Iris-setosa : YES 7,3,6,2 : Iris-virginica : YES 6,3,4,1 : Iris-versicolor : YES 5,3,1,0 : Iris-setosa : YES 6,3,4,1 : Iris-versicolor : YES 6,3,4,1 : Iris-versicolor : YES 5,3,1,0 : Iris-setosa : YES 8,3,7,2 : Iris-virginica : YES Testing finished! Rate : 1.000000 Average Rate : 0.913333 Standard Devison : 0.084591
クロスバリデーションの結果、このナイーブベイズ分類器の精度は平均が0.913333、標準偏差が0.084591だということが分かりました。果たしてこの数字は大きいのか小さいのか…。中には精度が100%だった試行もあったようですね。
データがある程度整列済みだったので、一度シャッフルしてまんべんなくテストデータを選ぶようにしています。そのため、精度が実行する度に変化しますが、いずれにせよ0.9を±0.02くらいで行ったり来たりする結果となりました。