Chainerで交差検証する方法
スポンサーリンク
Chainerで10分割交差検証を行いたかったのですが、
インターネットで調べても split_dataset_random を使う方法、つまり cross validationを行わない方法しかありませんでした。
なので自分で色々やってみて、get_cross_validation_datasets を使えるようになったので自分の備忘録も兼ねて解説します。
Chainerを使っている方必見です。
Chainerでn分割交差検証する方法
Chainerで交差検証するときは、get_cross_validation_datasets か get_cross_validation_datasets_random を使います。
実際に使うソースコードは以下のようになります。
import chainer import chainer.links as L import chainer.functions as F from chainer.datasets import TupleDataset from chainer.datasets import get_cross_validation_datasets dataset = TupleDataset(x,y) cross = get_cross_validation_datasets(dataset,10) for data in cross: print(data)
xとyは適当に用意してください。
そのあと、いつも通りTupleDatasetでデータセットを作ります。
get_cross_validation_datasetsは引数にデータセットと n を取ります。
nの値を変えることで、分割幅を変えることができます。
上記コードを出力すると以下のような出力が得られます。
(<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32898>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32860>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32978>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32828>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32940>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a329b0>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32780>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a328d0>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32748>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32710>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a326d8>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a326a0>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32668>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32630>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a325f8>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a325c0>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32588>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32550>) (<chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a32518>, <chainer.datasets.sub_dataset.SubDataset object at 0x7f2085a324e0>)
10個のタプルになったデータが返ってきます。
順に訓練データ、テストデータとなっています。
なので、実際に訓練するときは以下のような感じです。上記のcrossを使います。
for datasets in cross: train,test = datasets train_iter = iterators.SerialIterator(train,batch) test_iter = iterators.SerialIterator(test,batch,False,False) #あとはいつも通り
こんな感じで、forでループさせて中にネットワークの訓練を書けば10回トレーニングします。
ネットワークの構築部分は外に書いたほうが良いです。
forループで毎回定義していると遅くなります。
あと、交差検証にした場合は、accuracyをリストに保持しておくと管理が簡単です。
その後pandasでデータフレームに入れたりもできますからね。
まとめ
今回は簡単にChainerで交差検証する方法をまとめました。
調べても日本語記事がなかったので誰かの役に立ったらいいなあ的な感じです。
今後も少しずつプログラミング記事も増やしていきたいなあと思っています!