ロジスティック回帰
ロジスティック回帰の概要
ロジスティック回帰とは、「\(0\) か \(1\) か」や「Yes か No か」のような二値分類を行うための統計モデルです。
名前に「回帰」とありますが、実際は分類問題でよく使われます。
例えば、次のような判定ができます。
- メールがスパムかスパムでないか
- 試験に合格するか合格しないか
- ある商品を買うか買わないか
アルゴリズム(ロジスティック回帰)
次のような2値分類のデータセットを考える。
\[
\{(\boldsymbol{x}_i,y_i) \mid \boldsymbol{x}_i\in\mathbb{R}^d,~y_i\in\{0,1\}\}_{i=1}^N
\]
-
ハイパーパラメータである、エポック数 \(E\) と学習率 \(\eta\) を設定する。
-
パラメータ \(\boldsymbol{w}\) を \(E\) 回更新する。
\[
\boldsymbol{w}\gets\boldsymbol{w}-\eta\cdot\frac{1}{N}\sum_{i=1}^N \left\{ \sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) - y_i \right\}\boldsymbol{x}_i
\]
ただし、\(\sigma(t)=\dfrac{1}{1+e^{-t}}\) である。
-
学習された分類器 \(f(\boldsymbol{x};\boldsymbol{w})\) にデータを代入する。
\[
f(\boldsymbol{x}_i;\boldsymbol{w})=\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) \quad (i=1,\cdots,N)
\]
-
分類結果を出力する。
\[
\hat{y}_i=
\begin{cases}
0 & (f(\boldsymbol{x}_i;\boldsymbol{w}) \lt 0.5) \\
1 & (f(\boldsymbol{x}_i;\boldsymbol{w}) \ge 0.5)
\end{cases}
\]
ロジスティック回帰の理論
次のようなデータセットが与えられたとします。
\[
\{(\boldsymbol{x}_i,y_i) \mid \boldsymbol{x}_i\in\mathbb{R}^d,~y_i\in\{0,1\}\}_{i=1}^N
\]
説明変数 \(\boldsymbol{x}_i\) から、目的変数 \(y_i\) がクラス \(0,1\) のどちらに属するかを予測する、分類モデルを考えます。
分類方法は、クラス \(1\) に属する確率 \(p\) とするとき
\[
\hat{y}_i=
\begin{cases}
0 & (p \lt 0.5) \\
1 & (p \ge 0.5)
\end{cases}
\]
説明変数 \(\boldsymbol{x}_i\in\mathbb{R}^d\) とパラメータ(重み) \(\boldsymbol{w}\in\mathbb{R}^d\) を
\[
\boldsymbol{x}_i=
\begin{bmatrix}
x_{i1} \\ x_{i2} \\ \vdots \\ x_{id}
\end{bmatrix}
,\quad
\boldsymbol{w}=
\begin{bmatrix}
w_1 \\ w_2 \\ \vdots \\ w_d
\end{bmatrix}
\]
\[
\boldsymbol{w}^\top\boldsymbol{x}_i=w_1x_{i1}+w_2x_{i2}+\cdots+w_dx_{id}
\]
を考えます。
しかし、このままだと値は実数全体 \((-\infty,\infty)\) をとるので、この値域が \([0,1]\) になるように変換する必要があります。
そこで、次の関数を考えます。
\[
\sigma(t)=\frac{1}{1+e^{-t}}
\]
この関数はシグモイド関数と呼ばれ、出力は必ず \(0\) から \(1\) の範囲に収まります。
シグモイド関数に \(\boldsymbol{w}^\top\boldsymbol{x}_i\) を代入したものを、クラス \(1\) に属する確率とします。
つまり分類器は
\[
f(\boldsymbol{x};\boldsymbol{w})=\sigma(\boldsymbol{w}^\top\boldsymbol{x})
\]
確率値による分類なので、交差エントロピー損失
\[
\ell(y,\hat{y}) = -y\log\hat{y}-(1-y)\log(1-\hat{y})
\]
を用います。
\[
\ell(y,f(\boldsymbol{x};\boldsymbol{w})) = -y\log\sigma(\boldsymbol{w}^\top\boldsymbol{x})-(1-y)\log(1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}))
\]
よって、損失関数は
\[
L(\boldsymbol{w}) = -\frac{1}{N}\sum_{i=1}^N \left\{ y_i\log\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) + (1-y_i)\log(1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)) \right\}
\]
となり、最適化問題
\[
\min_{\boldsymbol{w}}L(\boldsymbol{w})
\]
を考えます。
\[
\begin{align}
\frac{\partial L(\boldsymbol{w})}{\partial\boldsymbol{w}}
&=-\frac{1}{N}\frac{\partial}{\partial\boldsymbol{w}}\sum_{i=1}^N \left[ y_i\log \{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}+(1-y_i)\log\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\} \right]\\
&=-\frac{1}{N}\sum_{i=1}^N \frac{\partial}{\partial\boldsymbol{w}} \left[ y_i\log \{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}+(1-y_i)\log\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\} \right]\\
&=-\frac{1}{N}\sum_{i=1}^N \left\{ y_i \frac{\sigma_{\boldsymbol{w}}(\boldsymbol{w}^\top\boldsymbol{x}_i)}{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)}+(1-y_i)\frac{-\sigma_{\boldsymbol{w}}(\boldsymbol{w}^\top\boldsymbol{x}_i)}{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)} \right\}\\
&=-\frac{1}{N}\sum_{i=1}^N \left\{ y_i \frac{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}\boldsymbol{x}_i}{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)}-(1-y_i)\frac{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}\boldsymbol{x}_i}{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)} \right\}\\
&=-\frac{1}{N}\sum_{i=1}^N \left\{ y_i \{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}\boldsymbol{x}_i-(1-y_i)\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\boldsymbol{x}_i \right\}\\
&=-\frac{1}{N}\sum_{i=1}^N \left\{ y_i-y_i\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)+y_i\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) \right\}\boldsymbol{x}_i\\
&=-\frac{1}{N}\sum_{i=1}^N \left\{ y_i-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) \right\}\boldsymbol{x}_i\\
&=\frac{1}{N}\sum_{i=1}^N \left\{ \sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) - y_i \right\}\boldsymbol{x}_i
\end{align}
\]
よって、パラメータの更新式は次のようになります。
\[
\boldsymbol{w}\gets\boldsymbol{w}-\eta\cdot\frac{1}{N}\sum_{i=1}^N \left\{ \sigma(\boldsymbol{w}^\top\boldsymbol{x}_i) - y_i \right\}\boldsymbol{x}_i
\]
また、データ行列 \(X\) と出力ベクトル \(\boldsymbol{y}\) を
\[
X=
\begin{bmatrix}
\boldsymbol{x}_1^\top \\
\boldsymbol{x}_2^\top \\
\vdots \\
\boldsymbol{x}_N^\top
\end{bmatrix}
\in\mathbb{R}^{N\times d}
,\quad
\boldsymbol{y}=
\begin{bmatrix}
y_1 \\ y_2 \\ \vdots \\ y_N
\end{bmatrix}
\in\mathbb{R}^{N}
\]
とすると、次のようにも書けます。
\[
\boldsymbol{w}\gets\boldsymbol{w}-\eta\cdot\frac{1}{N}X^\top(\sigma(X\boldsymbol{w}) - \boldsymbol{y})
\]
損失関数
\[
\{y_i\,|\,\boldsymbol{x}_i,\boldsymbol{w}\}\sim \mathrm{Bernoulli}(\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i))
\]
なので、尤度は
\[
\begin{align}
P(\mathcal{Y}\,|\,\mathcal{X},\boldsymbol{w})
&=\prod_{i=1}^N P(y_i\,|\,\boldsymbol{x}_i,\boldsymbol{w})\\
&=\prod_{i=1}^N \{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}^{y_i}\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}^{1-y_i}
\end{align}
\]
負の対数尤度を考えると
\[
\begin{align}
&-\log P(\mathcal{Y}\,|\,\mathcal{X},\boldsymbol{w})\\
&=-\sum_{i=1}^N \log \left[ \{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}^{y_i}\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}^{1-y_i} \right]\\
&=-\sum_{i=1}^N \left[ y_i\log \{\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\}+(1-y_i)\log\{1-\sigma(\boldsymbol{w}^\top\boldsymbol{x}_i)\} \right]
\end{align}
\]