Partial Label Maskingでマルチラベル分類問題のデータ不均衡に対応する
まとめ
マルチラベル分類問題におけるデータ不均衡に対応する手法として、Partial Label Masking (PLM) が利用できる。同手法の概要は次の通り。
- サンプルごとに各クラスに対する損失関数を確率的にマスクすることで、アンダーサンプリングに似た効果を期待する
- 任意の損失関数に対して適用できる
- マスクする確率は、分類器がそのクラスをどれだけ過剰に/過小に予測しているかに応じて決める
論文
K. Duarte, Y. Rawat and M. Shah, "PLM: Partial Label Masking for Imbalanced Multi-Label Classification," CVPR2021 Workshops, pp. 2739-2748.
Partial Label Masking (PLM)
いま、 番目のサンプルについて、 を目的変数(真値)、ネットワークの推論結果 と書くとき、すべてのクラス(個ある)に対する損失関数は次のように書ける。
Partial Label Masking (PLM) では、特定のクラスに対する損失関数をマスクしてしまう。具体的には、2値のマスク]を前述の損失関数に適用する。すなわち損失関数は以下のようになる。
式から明らかなように、PLMは任意の損失関数に対して適用できる。
適切にマスクを設定することで、損失関数に寄与する、それぞれのクラスに対するサンプル数を等しくできる。これはアンダーサンプリングの考え方と似ている。多クラス分類 (Multi-class classification) 問題では、アンダーサンプリングにより学習データ内のラベルの偏りを軽減することができる。しかしながら、マルチラベル分類 (Multi-label classification) では、同一のサンプルにおいて頻出なラベルと希少なラベルが共起する場合があり、アンダーサンプリングによりこのようなサンプルを取り除いても不均衡の解消にはつながらない。PLMを使えば、このようなラベルの共起に対応しつつ、学習データ内のラベル不均衡を解消できる。
マスクの生成
マスクの生成では、正例の負例に対する比率 を考える。この比率は、学習データに含まれる正例の数 を負例の数 で割ることで算出する。また、クラスの過剰あるいは過小な予測を最小化する、理想的な比率 を仮定する。すると、マスクは確率的な関数として次のように書ける。
ここで、]は確率で1に、確率で0になる関数である。上式はすなわち、正例が過剰に予測されている(ネットワークが入力を正例だと推論する傾向にある)なら、正例のうち だけを損失関数の計算に使うことを表している。反対に、正例が過小に予測されるなら(負例が過剰に予測されるなら)、負例のうち だけを損失関数の計算に使う。これにより、それぞれのクラスについて、正例と負例の比率を指定しながら分類器(ネットワーク)を訓練できる。
比率の適応的選定
では、理想的な比率 をどのように選べば良いだろうか。簡単な方法としては、学習データにて各クラスの比率を計算し、それらを平均することが考えられる。しかしながらこの方法は、著者らによると、うまく機能しない場合(データセット)があるそうだ。そこで著者らは、分類器の予測結果の確率分布から適応的にこの比率を推定する手法を提案している。
あるクラス の正例と負例について、分類器が出力する確率の分布をそれぞれ とする。分布は 個のビンに分け、離散的に扱う。また、学習データ(真値)の分布を とする。ただし、学習データでは確率が分布しているわけではなく、正例あるいは負例であることが確定しているので、確率1または0を含むビンにすべてのデータが入っている。
分類器が出力する確率の分布と、真値の分布とのずれをKullback-Leibler divergence(KLダイバージェンス、KL情報量とも)で表す:
さらに、これら標準化して および を得る。 が正のとき、分類器はクラス を過小に予測している傾向にある(反対に、負のとき、過剰に予測している)。この逆が に成り立つ。
と がバランスするように を更新すればよい。エポック における を、ひとつ前のエポックでの と の差 を用いて以下のように更新する:
ここで、 は比率を更新する割合を定めるハイパーパラメータである。
モデル訓練のループ
モデルを訓練するループは以下のようになる。
- を使ってマスク を生成する
- すべてのミニバッチに対して、マスクした損失関数を使ってパラメータを更新する(一般的なニューラルネットと同様に、順伝搬→逆伝搬を使う)。このとき、順伝搬した結果(推論結果である確率分布) を保存しておく
- 順伝搬した結果 と真値 から、 を算出する
- を用いて を更新する
実験結果
素朴なBinary cross entropy, Focal loss, サンプル数の逆数での重み付けといった手法と性能を比較している。分量が多いため、詳細は元論文を参照されたい。