久しぶりのDeepLearning関連の記事です。
最近、昔の記事を引用してくれることが増えたのですが、すごい汚いコードを参考にさせてしまって本当に申し訳ないです。もはや恥ずかしささえも感じる・・・。時間があれば昔の記事も更新していきたいです。
挨拶はこの辺にして早速FCNの紹介からやっていきます。
ちなみに、この記事の全コードはこちらのgithubにあげてるので、気になる部分がありましたらこちらを参照お願いします。 github.com
FCNとは
FCNはFully Convolutional Networksの頭をとって名付けられたもので、画像から物体をpixel-wise(ピクセル単位)で予測をする"segmentation"に用いる技術です。こちらに参考元の論文があります。
僕がこの論文に最初に目を通したのは半年くらい前なのですが、当時は「segmentationのためにFCNを使って、後ろにこんな層をくっつけるともっと精度良くなるぜ」的な論文が多くあったような気がします。つまりいい感じに流行っていた記憶があります。
今はどうなってるか知りません。(論文全然読めていないので・・・。)
実は時代遅れになっていて、もしかしたらハロウィンを楽しんでいたJKに「今更FCNとかwwwwウケるwwww」と一蹴されてしまうかも。
そういえば、いきなり"segmentation"という言葉を使って全く説明しないのもどうかと思うので先に"segmentation"に少し触れます。
segmentationの軽い説明。
さっきの言葉だけではsegmentationのイメージがつかないかもしれないので、論文より画像を引用します。
一番右の画像(Image)に対して、その一つ隣の画像(Ground Truth)を学習させることにより左側2列のような画像を出力できるようにしたい。というのが最終目標。ちなみにFCN-8sが今回記事で紹介している方法で、SDSはまた別の手法です。論文の中で比べてるだけです。紛らわしくてすいません。
もう少し、segmentationの位置付けを明確にしてみましょう
画像系DeepLearningのネタとしてよく取り上げられる「人間の顔を、誰なのか判別する」などは"classification"と呼ばれ、画像からお前は何者だ?、という"what"のみを出力します。
その次にレベルが一段階上がったのが"detection"と呼ばれるものです。これは画像の中から判別したい対象を、場所とともに認識します。以前書いたアニメ顔認識のようなものは、場所も出力されているのでdetectionのように思われる方もいるかもしれませんが、あれは顔の部分だけ切り取って判別しているのでdetectionではなくcalssificationです。
画像で見るとこんな感じです。
つまり画像から"what"と"where"を認識するわけですね。
それらに対して"segmentation"は、detectionのような矩形での認識ではなくピクセル単位での認識を行います。他の言い方をすると「ピクセルごとに"classification"をしている」とか言えるかも。
対比して無理やり英語でまとめるならば、"what"と"where"に"how"を足したようなものでしょうか。
すぐに想像つくと思いますが、自動運転や医療の場で活躍できる技術ですね。
以上で、segmentationの説明を終わります。これで最初にあげたsegmentationの画像の意味が最初より理解できたと思います。
もういっかい、FCNとは
それでは、もう少し踏み込んだFCNの説明に移ります。
FCNはその名の通り、全てがConvolution層で構成されています。構成が複雑ではなく、計算コストは畳み込みが主なため、end-to-endな学習ができ予測計算にかかる時間も少ないです。また畳み込みなので、Inputする画像の大きさに縛りがありません。(今回は簡単のために224*224に大きさを固定しています。)
さすがに予測精度はまぁまぁといったところですが、 構成が簡単な割に、十分な働きをする素晴らしい技術だと思います。これ以上の詳しい内容は論文を参照してください。
モデルの設計
まずモデルを可視化したのがこちらです。可視化の方法は以前書いた記事を参考にしてください。(色々ミスったりして一番みられたくない記事ですが。w)
長いけど、内容は全然大したことないです。つーか画質わるい。申し訳ないですが確認したかったらDLして拡大してください・・・。
モデルを作っている部分はこんな感じ。
kerasのSequentialを使うと層の分岐ができないので、Modelを使って構築しています。
モデルを組むこと自体は問題なかったのですが、ちょっとDeconvolution2Dの部分でハマりました。
畳み込みをしまくった後、今度は畳み込みによって得られたスコアを元の大きさになるまで広げていく必要があります。ここでDeconvolution2D層を使うのですが、この層が出力するoutputのテンソルがどうもオカシイ。
kerasのドキュメントには「outputのshapeはこうやって出してるよ〜」という計算式があり、パディングを表す"p"が計算式の中にあるくせにpaddingを引数で設定できません。
padding周辺を牛耳ってるのはborder_modeっていう引数なんですが"same"と"valid"しか設定できないんですよね。"same"はinputと同じ大きさにするようにpaddingするもので、Deconvでは使うことなさそうな引数。それに対し"valid"はpaddingに関して何にもしません。つまりこの二つのどっちかではpadding込みのテンソルが出力できない。
ん〜。何か見落としてるんだろうなぁ・・・。誰かわかる方いらっしゃいましたら教えてください。お願いします。
ですのでCropping2D層をDeconvolution2D層の後ろにつけることによって半ば強引にこの問題を解消しました。ちなみにCropping2Dはfeature mapや画像の端っこを切り落としてくれる層です。理論的にこれで問題はないはずなのですが、理論的に問題があるとお気付きのかたがいらっしゃいましたらコメントまでよろしくお願いします。
がくしゅ〜
今回、モデルを学習するにあたってPASCAL VOC 2012のデータセットをお借りしました。先ほども少し触れましたが教師画像の大きさは224*224に揃えてから学習させています。
以下が学習に用いた画像です。左が教師データで、右がtargetになります。物体の分類は色の違いで区別されています。
また画像はメモリをたくさん食うので学習の際にfit_generator関数を使っています。初めて使ったけど使い勝手いい感じです。一応そこの部分のコードをのっけます。
まずgenerator部分。binarylab関数はtargetの画像の中で色のついてる部分に1を立てる関数です。
そんでこっちが実際に動かすtrain関数。
こうすることで画像をcpuで展開して、計算はgpuで並列して行うことができるそうです。
学習はAmazon EC2 インスタンスを用いて行ないました。Amazon Linuxはよくわからなかったので、Bitfusion Boost Ubuntu 14 Torch 7を借りています。こっちならubuntuだし最初からcuDNNも入っているので楽チンです。こちらに紹介が載っています。
当たり前ですが、classificationなどに比べて収束は格段に遅いです。
学習データ1450枚に対して、epochは100で回しました。
学習結果
とりあえず100epoch終えた後のスクリーンショット。
トータルで約285秒 * 100エポックなので28500秒。だいたい8時間いかないくらいです。cpuだと1エポックで2時間弱なのでGPUは必須です。
まずは学習で使ったデータから。
暗い色で設定しちゃってるクラスが見辛いので、背景は真っ白にしちゃってます。
左から、オリジナルの写真。学習させたデータ。予測結果。です。
バスいい感じ!!
顔面が・・・。でも割とちゃんとできてます。
うまくいってないのも載せておきます。
次に学習で使っていない未知データに対する予測。ここが重要。
猫。すごいうまくいってる!!感動!!
リア充爆発した。
え、結構いけてるじゃん。学習に使った犬より全然いい感じ。
以上です。学習の際にクラスごとのweightを設定してないので、その辺もっと厳密にやれば絶対に精度上がります。
一応学習後のパラメータはここに置いておきますね。
もしも手元で試したかったら、このパラメータを使ってpredict.pyで遊んでみてください。精度は大して良くないし、224*224の画像を用意しなきゃいけないので使う人はほぼいないと思いますがw
まとめ
- segmentationについてまとめました。
- FCNを軽く説明しました。
- 学習結果を紹介しました。
あたかも論文からそのまま実装したかのように紹介してましたが、参考サイトがなければ絶対に詰んでました。もっと実装力を上げていかないと。
次回の予定は、AlphaGoの理論で作るそこそこ強いオセロ。とか、KerasによるFater-RCNNの実装。とかを予定しています。前者は学習がうまくいけばそろそろアップできるかもですが、後者は全くやってませんw
あとは今回実装したFCNを使って、もっと精度のいいsegmentationとかやってみたいですね。研究との兼ね合いで更新速度は変わりますが、よかったらそちらも見てください。
そういえば英語を勉強したいといってる友人がいるので、半ば強引に「俺の記事英訳してくんね?」とお願いしたら快く受け入れてくれました。ですので、英語版もアップするかもです。海外ではkerasがすごい流行ってますからね。(露骨なpv稼ぎ。)
今回の記事を書くにあたって参考にしたのは以下のサイトです。
memo: Fully Convolutional Networks 〜 Chainerによる実装 〜
大変参考になりました。ありがとうございます。
それでは。