【Python】各種カーネル関数を使ってサポートベクターマシンを実装する【irisデータセット】
この記事のポイント
  • ガウスカーネル(RBFカーネル),多項式カーネル,シグモイドカーネルを試す.
  • irisのデータセットを使用する.
  • プログラムの公開(任意でハイパーパラメータや使用するirisデータを変更できるようにしている).

こんにちは.けんゆー(@kenyu0501_)です.
この記事では,サポートベクターマシン2値問題を解くことを行います.

サポートベクターマシンが全く分からないという方は以下をさらっと見ておくことをお勧めします.
(参考:サポートベクターマシン(SVM)とは?〜基本からpython実装まで〜)

サポートベクターマシンでは,クラス分けをする際に,直線(もしくは超平面)での線形分離が不可能な場合があります.
その時は,特徴量を高次元へと写像して上手く分離ができる状態まで持っていきます.

しかし,入力する特徴量がもともと持っている次元数が,爆発的に増えるため,計算量も爆発的に増えます.
そのため,カーネルトリックという工夫を凝らして,マージン最大化の最適化問題を解くということをやるのでしたね!
(マージン最大化とは,各クラスのサポートベクトル(点)と分離超平面との距離を最大化すること)

今回は,irisデータセットを使って,各種カーネル関数がどのような分離超平面を形成するのか,について見ていきます.

使うデータ(irisデータセット)の確認

irisデータセットが具体的に何か分からないという方は,こちらの記事をご覧ください.
(参考:iris(アヤメ)のデータセットをpandasとseabornを使って可視化する.

今回は,irisデータのsepal_width とsepal_lengthのデータを使います.がく片の幅長さですね!
(プログラム内では変更できるようになっています)

サポートベクターマシンで分離をしますが,線形分離可能なデータと線形分離可能なデータを扱います.

<線形分離可能なデータ>

  • setosaversicolor を使う.

<線形分離可能なデータ>

  • versicolorvirginica を使う.

狙いとしては,線形分離不可能なデータを分離する際に,カーネル関数の種類によってどのような分離曲線が書かれるのか!ということを確認していきたいのです.
もちろん,ハイパーパラメータによって,分離線は違うと思いますが,ざっくりと知見を得ることは可能です.

試すカーネル関数について

上記の4つのパターンを比較します.

  1. カーネル関数なし
  2. ガウスカーネル $$K(x_i, x_j) = \exp (-\gamma||x_i – x_j||^2)$$
  3. 多項式カーネル $$K(x_i, x_j) = (x_i^T x_j+c)^d$$
  4. シグモイドカーネル $$K(x_i, x_j) = \tanh (bx_i^T x_j+c)$$

結果は!?

線形分離可能な場合

setosaversicolor の結果です.

setosaversicolor は各50サンプルずつあるのですが,トレーニングデータ(教師)に80%を使って,残りの20%をテストデータに使っています.
とりあえず,この線形分離可能なデータの分類は,テストデータで汎化性能を見た結果,正解率は100%になりました.
(当たり前ですね)

交差検証などはやっておらず,ハイパーパラメーターは以下です.
(γ=1,c=1, d=3,コスト関数C=1(ソフトマージン))

もちろん,どのカーネルを見てもあまり違いがわかりにくいです.
(直線に近い)

ペンのすけ

ほとんど直線みたいな結果だね!

線形分離不可能な場合

versicolorvirginica の結果です.

ハイパーパラメータなどはそのままで,扱うデータを変えたものです.
カーネル関数による形状を確認したかったので,交差検証などはやっていないのです.
正解率で一番大きいのは,何も手を加えていない直線っていうもの皮肉な話です.

そこらへんは,ハイパーパラメータを調整したりして検討しなければいけません.

ペンのすけ

ガウスカーネルの形状は,データが密集している箇所に局所的に領域があるみたいだね!


ペンのすけ

多項式カーネルの形状は,ある程度分布全体のばらつきを考慮して境界が歪んでいる感じだ!


ペンのすけ

シグモイドカーネルの形状は,写像後の特徴空間で急激な変動があったのか,境界面がいくつもできているね!


自分でプログラムを触ってみて,自分で扱うデータ点を変更してみてねー!

プログラム

作成したプログラムも載せておきます.

ペンのすけ

そのまま回したら,線形分離可能な場合の4つのグラフが出てくると思うよ!


プログラムの仕組みを丁寧に説明していきます.

どの特徴量を使うのか

今回,iris_datasetには,4つの特徴量があると思いますが,SVMで使用するのは,2つなので選択できるようにしました.

最終的に,X=np.hstack((X_sl,X_sw))とすると,Sepal lengthSepal widthのみを取り扱うことになります.

また,2値問題を解くので,ターゲットも2つで良いです.
ターゲットは0,1,2の3種類(それぞれ50個ずつのデータ)があるので,1つのターゲット群を除外します.
上のプログラムの例では,virginicaを除外するので,X=X[y!=2]としています.

最終的に,0と1のデータに成形された後に,標準化(StandardScaler)を行います.

標準化をするとデータの群が以下のようになります.

  • 平均→0
  • 分散(標準偏差)→1

全てのデータを平均値で引いて,引いたデータに対して,標準偏差で割るためです.

また,グラフ描写用の関数のターゲットの値も揃えてください.
ここでは今回,ターゲットのyが0と1を使用しているので,y==0とy==1です.
(ターゲットが変わるとyの値も変えてください)

ペンのすけ

これで,必要な設定はほとんど済んだと思います

svmのハイパーパラメータの設定

カーネル関数は,以下4つの場合で試しています.
(1つは線形)

これは,全てデフォルトの値を使用しています.
sklearnのライブラリを使用してますね!

ペンのすけ

線形分離不可能なデータに対して,SVMをとりあえず実装してみたい人はどうぞ!

ガウス(rbf)カーネルについては,以下の記事でチューニングする方法を書いています.
(参考:rbfカーネルのハイパーパラメータをグリッドサーチとベイズ最適化で探す)