diff --git a/EM/EM.ipynb b/EM/EM.ipynb new file mode 100644 index 0000000..59a70e7 --- /dev/null +++ b/EM/EM.ipynb @@ -0,0 +1,324 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 91, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import copy\n", + "import math\n", + "import matplotlib.pyplot as plt\n", + "# mathplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 创建数据集" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "一个样本中有两个高斯分布,均值不同,方差相同;\n", + "实例:从一个学校里随机抽取学生,抽取n个,其中包含a个女学生,b个男学生。" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# n 样本数量 ,k 几个高斯分布,sigma:方差,real_mu 包含两个分布的真实期望\n", + "\n", + "def create_data(n,k,real_mu,sigma):\n", + " X = np.zeros((n,1))\n", + " for i in range(0,n):\n", + "# 随机抽取 男女生\n", + " if np.random.random()>0.5:\n", + " X[i][0] = np.random.normal()*sigma + real_mu[0]\n", + " else:\n", + " X[i][0] = np.random.normal()*sigma + real_mu[1]\n", + " return X" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 高斯分布" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "\n", + "def gaussian_dist(x,mu,sigma):\n", + " exponent = np.exp(-np.power((x-mu),2)/(2*np.power(sigma,2)))\n", + " result = 1/np.power(2*np.pi*sigma**2 , 1/2) *exponent\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Expectation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "固定$\\theta$ ,优化条件期望" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "简单而言:已知 $\\theta$ ,求条件期望,即在某个样本结果出现的条件下,是男生或者女生的概率" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def e_step(X,mu,sigma,expectations):\n", + " length = X.shape[0]\n", + " k = mu.shape[0]\n", + "\n", + " for i in range(0,length):\n", + " denominator = 0;\n", + " for j in range(0,k):\n", + " denominator += gaussian_dist(X[i][0],mu[j],sigma)\n", + " \n", + " for j in range(0,k): \n", + " numberator = gaussian_dist(X[i][0],mu[j],sigma) \n", + " expectations[i][j]= numberator / denominator\n", + " \n", + " return expectations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Maximization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "已知条件期望,求解$\\theta$在此次迭代中的最大值" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 这种方法更直观,\n", + "# def m_step(X,mu,sigma,expectations):\n", + "# n = X.shape[0]\n", + "# k = mu.shape[0] \n", + "# for i in range(0,k): \n", + "# numberator = 0\n", + "# denominator = 0\n", + "# for j in range(0,n):\n", + "# numberator += expectations[j,i]*X[j,0]\n", + "# denominator += expectations[j,i]\n", + " \n", + "# mu[i] = numberator/denominator\n", + " \n", + "# return mu" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 向量法,简单高效\n", + "def m_step(X,mu,sigma,expectations):\n", + " numbers = np.transpose(X).dot(expectations)\n", + " denominator= np.sum(expectations,axis=0)\n", + "\n", + " result = numbers/denominator\n", + "# result是1*2的矩阵,\n", + "# 我们需要返回一个一维向量,包含两个分量而已。\n", + " mu[0] = result[0][0]\n", + " mu[1] = result[0][1]\n", + " return mu" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 训练" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def train(X,sigma,expectations,iteration,epsilon):\n", + " mu = np.random.random(2)\n", + " for i in range(0,iteration):\n", + " # deepcopy 使得两个值相互不影响\n", + " old_mu = copy.deepcopy(mu)\n", + " # E步:求出条件期望\n", + " expectations = e_step(X,mu,sigma,expectations)\n", + " \n", + " print(i,mu)\n", + " # M步,利用E步的期望,求最大的theta()\n", + " mu = m_step(X,mu,sigma,expectations)\n", + " # 如果两次迭代的差值小于阈值,则训练结束,也可以认为找到局部最优值,\n", + " if np.abs(np.sum(old_mu) - np.sum(mu)) < epsilon:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 [0.71085143 0.97453594]\n", + "1 [168.60550916 169.62356611]\n", + "2 [167.44608098 171.32745846]\n", + "3 [163.31669663 175.570302 ]\n", + "4 [160.18984499 179.48883559]\n", + "5 [160.04240354 180.32256705]\n", + "6 [160.14223833 180.53002579]\n", + "7 [160.19102044 180.60482685]\n", + "8 [160.21080304 180.63380363]\n", + "9 [160.21861444 180.64513455]\n", + "10 [160.22168114 180.64957042]\n", + "11 [160.22288306 180.65130723]\n", + "12 [160.22335384 180.65198726]\n", + "13 [160.22353819 180.65225352]\n", + "14 [160.22361038 180.65235778]\n", + "15 [160.22363865 180.6523986 ]\n", + "16 [160.22364971 180.65241458]\n", + "17 [160.22365405 180.65242084]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAADYhJREFUeJzt3X+MZXdZx/H3hy5FKyCtO8WGdtxK\noEljIpSRVBHQFgSKofgD0yYgCmYjkdo2IilBgcR/KiKKidEstFItKQg0gopKrVRiAgvdpdCWpZQf\nBbYtLYQI/EXBPv5xz2aHzcx2554zs3Mf3q9kcs895+zc59nvvZ8599wf31QVkqReHna8C5AkTc9w\nl6SGDHdJashwl6SGDHdJashwl6SGDHdJashwl6SGDHdJamjHVt7Yzp07a9euXVt5k5K08Pbt2/f1\nqlrayL/Z0nDftWsXN99881bepCQtvCRf2ui/8bSMJDVkuEtSQ4a7JDVkuEtSQ4a7JDX0kOGe5Ook\n9ye5bdW6U5LckOTO4fLkzS1TkrQRx3Lk/nbguUesuwK4saqeANw4XJckbRMPGe5V9WHgG0esvhC4\nZli+BnjhxHVJkkaY95z7Y6vqXoDh8tTpSpIkjbXpn1BNshvYDbC8vLzZN6cFt+uKf11z/V1XPn+L\nK5EW27xH7vclOQ1guLx/vR2rak9VrVTVytLShr4aQZI0p3nD/f3AS4fllwLvm6YcSdIUjuWtkNcB\nHwHOSnIwycuBK4FnJ7kTePZwXZK0TTzkOfequnidTedPXIskaSJ+QlWSGjLcJakhw12SGjLcJakh\nw12SGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGtr0OVS1eJzH\ndHubanwc5948cpekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJek\nhgx3SWrIcJekhgx3SWrIcJekhgx3SWpoVLgnuTzJ7UluS3Jdkh+aqjBJ0vzmDvckjwN+H1ipqp8C\nTgAumqowSdL8xp6W2QH8cJIdwEnAPeNLkiSNNXe4V9XdwJuALwP3At+sqg9OVZgkaX5zT5Cd5GTg\nQuBM4H+Bdyd5cVVde8R+u4HdAMvLyyNK1UPpPOGxk0JLGzPmtMyzgC9W1deq6rvA9cDPHblTVe2p\nqpWqWllaWhpxc5KkYzUm3L8MnJvkpCQBzgcOTFOWJGmMMefc9wLvAfYDtw6/a89EdUmSRpj7nDtA\nVb0eeP1EtUiSJuInVCWpIcNdkhoy3CWpIcNdkhoy3CWpIcNdkhoy3CWpIcNdkhoy3CWpIcNdkhoy\n3CWpIcNdkhoy3CWpIcNdkhoy3CWpIcNdkhoaNVmHFsN6k0Ifz9t2Qur+jna/c/w3n0fuktSQ4S5J\nDRnuktSQ4S5JDRnuktSQ4S5JDRnuktSQ4S5JDRnuktSQ4S5JDRnuktSQ4S5JDRnuktSQ4S5JDRnu\nktSQ4S5JDY0K9ySPSfKeJJ9JciDJz05VmCRpfmNnYnoL8O9V9etJTgROmqAmSdJIc4d7kkcDzwB+\nC6CqHgAemKYsSdIYY07L/CTwNeDvknwiyduS/MhEdUmSRhhzWmYHcA5wSVXtTfIW4Argj1fvlGQ3\nsBtgeXl5xM394DmeE1uvZcoJj7dbb1uha89d+1p0Y47cDwIHq2rvcP09zML++1TVnqpaqaqVpaWl\nETcnSTpWc4d7VX0V+EqSs4ZV5wOfnqQqSdIoY98tcwnwjuGdMl8Afnt8SZKksUaFe1XdAqxMVIsk\naSJ+QlWSGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGjLcJakhw12SGjLcJakh\nw12SGjLcJakhw12SGho7WYcEOI/mdrZIY7NerRudo1ceuUtSS4a7JDVkuEtSQ4a7JDVkuEtSQ4a7\nJDVkuEtSQ4a7JDVkuEtSQ4a7JDVkuEtSQ4a7JDVkuEtSQ4a7JDVkuEtSQ4a7JDU0OtyTnJDkE0n+\nZYqCJEnjTXHkfilwYILfI0mayKhwT3I68HzgbdOUI0mawtgj978EXg08OEEtkqSJzD1BdpJfBu6v\nqn1JfuEo++0GdgMsLy/Pe3OtLdIExtuNEyofttn3o0W6n3q/GHfk/jTgBUnuAt4JnJfk2iN3qqo9\nVbVSVStLS0sjbk6SdKzmDveqek1VnV5Vu4CLgP+qqhdPVpkkaW6+z12SGpr7nPtqVXUTcNMUv0uS\nNJ5H7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z\n7pLUkOEuSQ0Z7pLUkOEuSQ1NMlmHvt8iTSTc1UbHYKMTKjvG4/j/t/k8cpekhgx3SWrIcJekhgx3\nSWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWrIcJekhgx3SWpo\n7nBPckaSDyU5kOT2JJdOWZgkaX5jZmL6HvAHVbU/yaOAfUluqKpPT1SbJGlOcx+5V9W9VbV/WP42\ncAB43FSFSZLmN8kcqkl2AU8G9q6xbTewG2B5eXmKm5O0ibbj/KYbneNWE7ygmuSRwHuBy6rqW0du\nr6o9VbVSVStLS0tjb06SdAxGhXuShzML9ndU1fXTlCRJGmvMu2UCXAUcqKo3T1eSJGmsMUfuTwNe\nApyX5Jbh54KJ6pIkjTD3C6pV9T9AJqxFkjQRP6EqSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEu\nSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ0Z7pLUkOEuSQ1NMkH2VjheE+QebbJgJ+eVtNp2\nmsjbI3dJashwl6SGDHdJashwl6SGDHdJashwl6SGDHdJashwl6SGDHdJashwl6SGDHdJashwl6SG\nDHdJashwl6SGDHdJashwl6SGRoV7kucmuSPJ55JcMVVRkqRx5g73JCcAfw08DzgbuDjJ2VMVJkma\n35gj96cCn6uqL1TVA8A7gQunKUuSNMaYcH8c8JVV1w8O6yRJx1mqar5/mLwIeE5V/c5w/SXAU6vq\nkiP22w3sHq6eBdyxavNO4OtzFbD9de2ta1/Qt7eufcEPTm8/UVVLG/nHO0bc8EHgjFXXTwfuOXKn\nqtoD7FnrFyS5uapWRtSwbXXtrWtf0Le3rn2BvR3NmNMyHweekOTMJCcCFwHvH/H7JEkTmfvIvaq+\nl+SVwH8AJwBXV9Xtk1UmSZrbmNMyVNUHgA+M+BVrnq5pomtvXfuCvr117QvsbV1zv6AqSdq+/PoB\nSWpoU8M9ydVJ7k9y2xrbXpWkkuwcrifJXw1fZfCpJOdsZm1jrNVXkjckuTvJLcPPBau2vWbo644k\nzzk+VR+b9cYsySVD/bcneeOq9QvR2zpj9q5V43VXkltWbVuIvmDd3p6U5KNDbzcneeqwfmEeZ7Bu\nbz+d5CNJbk3yz0kevWrbQoxbkjOSfCjJgeExdemw/pQkNyS5c7g8eVi/8XGrqk37AZ4BnAPcdsT6\nM5i9EPslYOew7gLg34AA5wJ7N7O2qfsC3gC8ao19zwY+CTwCOBP4PHDC8e5hg739IvCfwCOG66cu\nWm/r3RdXbf9z4HWL1tdRxuyDwPOG5QuAm1YtL8Tj7Ci9fRx45rD8MuBPFm3cgNOAc4blRwGfHep/\nI3DFsP4K4E/nHbdNPXKvqg8D31hj018ArwZWn/C/EPj7mvko8Jgkp21mffM6Sl9ruRB4Z1V9p6q+\nCHyO2Vc3bEvr9PYK4Mqq+s6wz/3D+oXp7WhjliTAbwDXDasWpi9Yt7cCDh3R/iiHP4OyMI8zWLe3\ns4APD8s3AL82LC/MuFXVvVW1f1j+NnCA2Sf8LwSuGXa7BnjhsLzhcdvyc+5JXgDcXVWfPGJTh68z\neOXwlOnqQ0+n6NHXE4GnJ9mb5L+T/MywvkNvAE8H7quqO4frHfq6DPizJF8B3gS8ZljfobfbgBcM\nyy/i8IcpF7K3JLuAJwN7gcdW1b0w+wMAnDrstuHetjTck5wEvBZ43Vqb11i3SG/l+Rvg8cCTgHuZ\nPc2Hxe8LZm+ZPZnZ08E/BP5xONrt0BvAxRw+aocefb0CuLyqzgAuB64a1nfo7WXA7yXZx+yUxgPD\n+oXrLckjgfcCl1XVt4626xrrjtrbVh+5P57ZubBPJrmL2VcW7E/y4xzj1xlsV1V1X1X9X1U9CLyV\nw08HF7qvwUHg+uEp4ceAB5l978XC95ZkB/CrwLtWrV74voCXAtcPy++m0f2xqj5TVb9UVU9h9kf5\n88OmheotycOZBfs7qurQWN136HTLcHnoFOiGe9vScK+qW6vq1KraVVW7mBV8TlV9ldlXF/zm8Krw\nucA3Dz09WQRHnP/6FWZPHWHW10VJHpHkTOAJwMe2ur6R/gk4DyDJE4ETmX2hUYfengV8pqoOrlrX\noa97gGcOy+cBh045LfTjDCDJqcPlw4A/Av522LQw4zY8870KOFBVb1616f3M/jAzXL5v1fqNjdsm\nvyJ8HbNTFN9lFuQvP2L7XRx+t0yYTf7xeeBWYGWrX8Ee0xfwD0PdnxoG4rRV+7926OsOhncwbNef\ndXo7EbiW2R+s/cB5i9bbevdF4O3A766x/0L0dZQx+3lgH7N3j+wFnjLsuzCPs6P0dimzd5d8FriS\n4cOYizRuw/jUkBe3DD8XAD8G3Mjsj/GNwCnzjpufUJWkhvyEqiQ1ZLhLUkOGuyQ1ZLhLUkOGuyQ1\nZLhLUkOGuyQ1ZLhLUkP/DzazO+/xdYuZAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# 200个抽样样本,(男女生集合)\n", + "n=200\n", + "k=2\n", + "# 真实期望\n", + "real_mu = np.array([180,160])\n", + "sigma=6\n", + "iteration=1000\n", + "epsilon=0.00001\n", + "expectations = np.zeros((n,k))\n", + "X = create_data(n,k,real_mu,sigma)\n", + "train(X,sigma,expectations,iteration, epsilon)\n", + "plt.hist(X,50)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + }, + "toc": { + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "249px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}