From 6a33a4fc1e35888f897e9ffe42fba90a5757f4f4 Mon Sep 17 00:00:00 2001 From: Alexander Ratner Date: Sun, 18 Dec 2016 23:49:40 -0800 Subject: [PATCH] Runs end to end --- notebooks/1_candidate_generation.ipynb | 557 ++++--------------------- notebooks/2_labeling_functions.ipynb | 554 ++++-------------------- run.sh | 6 + set_env.sh | 9 + 4 files changed, 189 insertions(+), 937 deletions(-) create mode 100755 run.sh create mode 100755 set_env.sh diff --git a/notebooks/1_candidate_generation.ipynb b/notebooks/1_candidate_generation.ipynb index 4913599..58a0e68 100644 --- a/notebooks/1_candidate_generation.ipynb +++ b/notebooks/1_candidate_generation.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "collapsed": false }, @@ -19,7 +19,7 @@ "%autoreload 2\n", "\n", "import os\n", - "os.environ['SNORKELDB']=\"postgres://jhusson@localhost:5432/snorkel_strom\"\n", + "os.environ['SNORKELDB']=\"postgres:///stromatolite\"\n", "\n", "from snorkel import SnorkelSession\n", "session = SnorkelSession()" @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "collapsed": false }, @@ -42,7 +42,8 @@ "source": [ "from snorkel.models import Sentence\n", "\n", - "sentences = session.query(Sentence).limit(200000).all()" + "sentences = session.query(Sentence).all()\n", + "len(sentences)" ] }, { @@ -57,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "collapsed": false }, @@ -88,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "collapsed": true }, @@ -109,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, @@ -122,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "collapsed": false }, @@ -170,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false }, @@ -198,24 +199,12 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false, "scrolled": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "CPU times: user 9min 18s, sys: 12.2 s, total: 9min 30s\n", - "Wall time: 11min 17s\n", - "Number of candidates: 139\n" - ] - } - ], + "outputs": [], "source": [ "%time c = ce.extract(sentences, 'Candidate Set', session)\n", "print \"Number of candidates:\", len(c)" @@ -230,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "collapsed": false }, @@ -244,481 +233,136 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Reloading the candidates" + "### Splitting into train / test sets now...\n", + "\n", + "Splitting by _document_; first, let's see the distribution of candidates by document:" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "from snorkel.models import CandidateSet\n", - "c = session.query(CandidateSet).filter(CandidateSet.name == 'Candidate Set').one()\n" + "from collections import defaultdict\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "candidates_by_doc = defaultdict(set)\n", + "for cand in c:\n", + " candidates_by_doc[cand[0].parent.document.id].add(cand)\n", + "\n", + "plt.hist(map(len, candidates_by_doc.values()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Using the `Viewer` to inspect candidates\n", - "\n", - "Next, we'll use the `Viewer` class--here, specifically, the `SentenceNgramViewer`--to inspect the data.\n", - "\n", - "It is important to note, our goal here is to **maximize the recall of true candidates** extracted, **not** to extract _only_ the correct candidates. Learning to distinguish true candidates from false candidates is covered in Tutorial 4.\n", - "\n", - "First, we instantiate the `Viewer` object, which groups the input `Candidate` objects by `Sentence`:" + "And total number of documents:" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { - "collapsed": false, - "scrolled": true + "collapsed": false }, - "outputs": [ - { - "data": { - "application/javascript": [ - "require.undef('viewer');\n", - "\n", - "// NOTE: all elements should be selected using this.$el.find to avoid collisions with other Viewers\n", - "\n", - "define('viewer', [\"jupyter-js-widgets\"], function(widgets) {\n", - " var ViewerView = widgets.DOMWidgetView.extend({\n", - " render: function() {\n", - " this.cids = this.model.get('cids');\n", - " this.nPages = this.cids.length;\n", - " this.pid = 0;\n", - " this.cxid = 0;\n", - " this.cid = 0;\n", - "\n", - " // Insert the html payload\n", - " this.$el.append(this.model.get('html'));\n", - "\n", - " // Initialize all labels from previous sessions\n", - " this.labels = this.deserializeDict(this.model.get('_labels_serialized'));\n", - " for (var i=0; i < this.nPages; i++) {\n", - " this.pid = i;\n", - " for (var j=0; j < this.cids[i].length; j++) {\n", - " this.cxid = j;\n", - " for (var k=0; k < this.cids[i][j].length; k++) {\n", - " this.cid = k;\n", - " if (this.cids[i][j][k] in this.labels) {\n", - " this.markCurrentCandidate(false);\n", - " }\n", - " }\n", - " }\n", - " }\n", - " this.pid = 0;\n", - " this.cxid = 0;\n", - " this.cid = 0;\n", - "\n", - " // Enable button functionality for navigation\n", - " var that = this;\n", - " this.$el.find(\"#next-cand\").click(function() {\n", - " that.switchCandidate(1);\n", - " });\n", - " this.$el.find(\"#prev-cand\").click(function() {\n", - " that.switchCandidate(-1);\n", - " });\n", - " this.$el.find(\"#next-context\").click(function() {\n", - " that.switchContext(1);\n", - " });\n", - " this.$el.find(\"#prev-context\").click(function() {\n", - " that.switchContext(-1);\n", - " });\n", - " this.$el.find(\"#next-page\").click(function() {\n", - " that.switchPage(1);\n", - " });\n", - " this.$el.find(\"#prev-page\").click(function() {\n", - " that.switchPage(-1);\n", - " });\n", - " this.$el.find(\"#label-true\").click(function() {\n", - " that.labelCandidate(true, true);\n", - " });\n", - " this.$el.find(\"#label-false\").click(function() {\n", - " that.labelCandidate(false, true);\n", - " });\n", - "\n", - " // Arrow key functionality\n", - " this.$el.keydown(function(e) {\n", - " switch(e.which) {\n", - " case 74: // j\n", - " that.switchCandidate(-1);\n", - " break;\n", - "\n", - " case 73: // i\n", - " that.switchPage(-1);\n", - " break;\n", - "\n", - " case 76: // l\n", - " that.switchCandidate(1);\n", - " break;\n", - "\n", - " case 75: // k\n", - " that.switchPage(1);\n", - " break;\n", - "\n", - " case 84: // t\n", - " that.labelCandidate(true, true);\n", - " break;\n", - "\n", - " case 70: // f\n", - " that.labelCandidate(false, true);\n", - " break;\n", - " }\n", - " });\n", - "\n", - " // Show the first page and highlight the first candidate\n", - " this.$el.find(\"#viewer-page-0\").show();\n", - " this.switchCandidate(0);\n", - " },\n", - "\n", - " // Get candidate selector for currently selected candidate, escaping id properly\n", - " getCandidate: function() {\n", - " return this.$el.find(\".\"+this.cids[this.pid][this.cxid][this.cid]);\n", - " }, \n", - "\n", - " // Color the candidate correctly according to registered label, as well as set highlighting\n", - " markCurrentCandidate: function(highlight) {\n", - " var cid = this.cids[this.pid][this.cxid][this.cid];\n", - " var tags = this.$el.find(\".\"+cid);\n", - "\n", - " // Clear color classes\n", - " tags.removeClass(\"candidate-h\");\n", - " tags.removeClass(\"true-candidate\");\n", - " tags.removeClass(\"true-candidate-h\");\n", - " tags.removeClass(\"false-candidate\");\n", - " tags.removeClass(\"false-candidate-h\");\n", - " tags.removeClass(\"highlighted\");\n", - "\n", - " if (highlight) {\n", - " if (cid in this.labels) {\n", - " tags.addClass(String(this.labels[cid]) + \"-candidate-h\");\n", - " } else {\n", - " tags.addClass(\"candidate-h\");\n", - " }\n", - " \n", - " // If un-highlighting, leave with first non-null coloring\n", - " } else {\n", - " var that = this;\n", - " tags.each(function() {\n", - " var cids = $(this).attr('class').split(/\\s+/).map(function(item) {\n", - " return parseInt(item);\n", - " });\n", - " cids.sort();\n", - " for (var i in cids) {\n", - " if (cids[i] in that.labels) {\n", - " var label = that.labels[cids[i]];\n", - " $(this).addClass(String(label) + \"-candidate\");\n", - " $(this).removeClass(String(!label) + \"-candidate\");\n", - " break;\n", - " }\n", - " }\n", - " });\n", - " }\n", - "\n", - " // Extra highlighting css\n", - " if (highlight) {\n", - " tags.addClass(\"highlighted\");\n", - " }\n", - "\n", - " // Classes for showing direction of relation\n", - " if (highlight) {\n", - " this.$el.find(\".\"+cid+\"-0\").addClass(\"left-candidate\");\n", - " this.$el.find(\".\"+cid+\"-1\").addClass(\"right-candidate\");\n", - " } else {\n", - " this.$el.find(\".\"+cid+\"-0\").removeClass(\"left-candidate\");\n", - " this.$el.find(\".\"+cid+\"-1\").removeClass(\"right-candidate\");\n", - " }\n", - " },\n", - "\n", - " // Cycle through candidates and highlight, by increment inc\n", - " switchCandidate: function(inc) {\n", - " var N = this.cids[this.pid].length\n", - " var M = this.cids[this.pid][this.cxid].length;\n", - " if (N == 0 || M == 0) { return false; }\n", - "\n", - " // Clear highlighting from previous candidate\n", - " if (inc != 0) {\n", - " this.markCurrentCandidate(false);\n", - "\n", - " // Increment the cid counter\n", - "\n", - " // Move to next context\n", - " if (this.cid + inc >= M) {\n", - " while (this.cid + inc >= M) {\n", - " \n", - " // At last context on page, halt\n", - " if (this.cxid == N - 1) {\n", - " this.cid = M - 1;\n", - " inc = 0;\n", - " break;\n", - " \n", - " // Increment to next context\n", - " } else {\n", - " inc -= M - this.cid;\n", - " this.cxid += 1;\n", - " M = this.cids[this.pid][this.cxid].length;\n", - " this.cid = 0;\n", - " }\n", - " }\n", - "\n", - " // Move to previous context\n", - " } else if (this.cid + inc < 0) {\n", - " while (this.cid + inc < 0) {\n", - " \n", - " // At first context on page, halt\n", - " if (this.cxid == 0) {\n", - " this.cid = 0;\n", - " inc = 0;\n", - " break;\n", - " \n", - " // Increment to previous context\n", - " } else {\n", - " inc += this.cid + 1;\n", - " this.cxid -= 1;\n", - " M = this.cids[this.pid][this.cxid].length;\n", - " this.cid = M - 1;\n", - " }\n", - " }\n", - " }\n", - "\n", - " // Move within current context\n", - " this.cid += inc;\n", - " }\n", - " this.markCurrentCandidate(true);\n", - "\n", - " // Push this new cid to the model\n", - " this.model.set('_selected_cid', this.cids[this.pid][this.cxid][this.cid]);\n", - " this.touch();\n", - " },\n", - "\n", - " // Switch through contexts\n", - " switchContext: function(inc) {\n", - " this.markCurrentCandidate(false);\n", - "\n", - " // Iterate context on this page\n", - " var M = this.cids[this.pid].length;\n", - " if (this.cxid + inc < 0) {\n", - " this.cxid = 0;\n", - " } else if (this.cxid + inc >= M) {\n", - " this.cxid = M - 1;\n", - " } else {\n", - " this.cxid += inc;\n", - " }\n", - "\n", - " // Reset cid and set to first candidate\n", - " this.cid = 0;\n", - " this.switchCandidate(0);\n", - " },\n", - "\n", - " // Switch through pages\n", - " switchPage: function(inc) {\n", - " this.markCurrentCandidate(false);\n", - " this.$el.find(\".viewer-page\").hide();\n", - " if (this.pid + inc < 0) {\n", - " this.pid = 0;\n", - " } else if (this.pid + inc > this.nPages - 1) {\n", - " this.pid = this.nPages - 1;\n", - " } else {\n", - " this.pid += inc;\n", - " }\n", - " this.$el.find(\"#viewer-page-\"+this.pid).show();\n", - "\n", - " // Show pagination\n", - " this.$el.find(\"#page\").html(this.pid);\n", - "\n", - " // Reset cid and set to first candidate\n", - " this.cid = 0;\n", - " this.cxid = 0;\n", - " this.switchCandidate(0);\n", - " },\n", - "\n", - " // Label currently-selected candidate\n", - " labelCandidate: function(label, highlighted) {\n", - " var c = this.getCandidate();\n", - " var cid = this.cids[this.pid][this.cxid][this.cid];\n", - " var cl = String(label) + \"-candidate\";\n", - " var clh = String(label) + \"-candidate-h\";\n", - " var cln = String(!label) + \"-candidate\";\n", - " var clnh = String(!label) + \"-candidate-h\";\n", - "\n", - " // Toggle label highlighting\n", - " if (c.hasClass(cl) || c.hasClass(clh)) {\n", - " c.removeClass(cl);\n", - " c.removeClass(clh);\n", - " if (highlighted) {\n", - " c.addClass(\"candidate-h\");\n", - " }\n", - " this.labels[cid] = null;\n", - " this.send({event: 'delete_label', cid: cid});\n", - " } else {\n", - " c.removeClass(cln);\n", - " c.removeClass(clnh);\n", - " if (highlighted) {\n", - " c.addClass(clh);\n", - " } else {\n", - " c.addClass(cl);\n", - " }\n", - " this.labels[cid] = label;\n", - " this.send({event: 'set_label', cid: cid, value: label});\n", - " }\n", - "\n", - " // Set the label and pass back to the model\n", - " this.model.set('_labels_serialized', this.serializeDict(this.labels));\n", - " this.touch();\n", - " },\n", - "\n", - " // Serialization of hash maps, because traitlets Dict doesn't seem to work...\n", - " serializeDict: function(d) {\n", - " var s = [];\n", - " for (var key in d) {\n", - " s.push(key+\"~~\"+d[key]);\n", - " }\n", - " return s.join();\n", - " },\n", - "\n", - " // Deserialization of hash maps\n", - " deserializeDict: function(s) {\n", - " var d = {};\n", - " var entries = s.split(/,/);\n", - " var kv;\n", - " for (var i in entries) {\n", - " kv = entries[i].split(/~~/);\n", - " if (kv[1] == \"true\") {\n", - " d[kv[0]] = true;\n", - " } else if (kv[1] == \"false\") {\n", - " d[kv[0]] = false;\n", - " }\n", - " }\n", - " return d;\n", - " },\n", - " });\n", - "\n", - " return {\n", - " ViewerView: ViewerView\n", - " };\n", - "});\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "from snorkel.viewer import SentenceNgramViewer\n", - "\n", - "# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook\n", - "# You should ignore this!\n", - "import os\n", - "if 'CI' not in os.environ:\n", - " sv = SentenceNgramViewer(c, session, annotator_name=\"Tutorial Part 2 User\")\n", - "else:\n", - " sv = None" + "len(candidates_by_doc.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we render the `Viewer." + "Now, split the candidates into train / test:" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { - "collapsed": false + "collapsed": true }, "outputs": [], "source": [ - "sv" + "from random import shuffle\n", + "\n", + "doc_ids = list(candidates_by_doc.keys())\n", + "shuffle(doc_ids)\n", + "split = int(0.66 * len(doc_ids))\n", + "\n", + "train = CandidateSet(name='Training Candidates')\n", + "session.add(train)\n", + "for doc_id in doc_ids[:split]:\n", + " for cand in candidates_by_doc[doc_id]:\n", + " train.append(cand)\n", + "print len(train)\n", + "\n", + "test = CandidateSet(name='Test Candidates')\n", + "session.add(test)\n", + "for doc_id in doc_ids[split:]:\n", + " for cand in candidates_by_doc[doc_id]:\n", + " test.append(cand)\n", + "print len(test_candidates)\n", + "\n", + "session.commit()" ] }, { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "StromStrat(Span(\"stromatolitic\", parent=12007, chars=[41,53], words=[7,7]), Span(\"Reynolds Point Formation\", parent=12007, chars=[78,101], words=[12,14]))" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "cell_type": "markdown", + "metadata": {}, "source": [ - "sv.get_selected()" + "### Reloading the candidates" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sentence(Document 54b43244e138239d8684934a, 47, u'Additional work , carried out in the summer of 1975 , has shown that elongate stromatolites are also present in the overlying Reynolds Point Formation .')\n" - ] - } - ], + "outputs": [], "source": [ - "\n" + "from snorkel.models import CandidateSet\n", + "\n", + "train = session.query(CandidateSet).filter(CandidateSet.name == 'Training Candidates').one()\n", + "print len(train)\n", + "\n", + "test = session.query(CandidateSet).filter(CandidateSet.name == 'Test Candidates').one()\n", + "print len(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that we can **navigate using the provided buttons**, or **using the keyboard (hover over buttons to see controls)**, highlight candidates (even if they overlap), and also **apply binary labels** (more on where to use this later!). In particular, note that **the Viewer is synced dynamically with the notebook**, so that we can for example get the `Candidate` that is currently selected. Try it out!" + "## Using the `Viewer` to inspect candidates\n", + "\n", + "Next, we'll use the `Viewer` class--here, specifically, the `SentenceNgramViewer`--to inspect the data.\n", + "\n", + "It is important to note, our goal here is to **maximize the recall of true candidates** extracted, **not** to extract _only_ the correct candidates. Learning to distinguish true candidates from false candidates is covered in Tutorial 4.\n", + "\n", + "First, we instantiate the `Viewer` object, which groups the input `Candidate` objects by `Sentence`:" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": { - "collapsed": false + "collapsed": false, + "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "StromStrat(Span(\"stromatolitic\", parent=11957, chars=[16,28], words=[3,3]), Span(\"Glenelg Formation\", parent=11957, chars=[68,84], words=[12,13]))\n", - "11957\n" - ] - } - ], - "source": [ - "if 'CI' not in os.environ:\n", - " print unicode(sv.get_selected())\n", - " print sv.get_selected()[0].parent.id" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, + "outputs": [], "source": [ - "### Repeating for development and test corpora\n", - "We will rerun the same operations for the other two news corpora: development and test. All we do for each is load in the `Corpus` object, collect the `Sentence` objects, and run them through the `CandidateExtractor`." + "from snorkel.viewer import SentenceNgramViewer\n", + "\n", + "sv = SentenceNgramViewer(train, session)\n", + "sv" ] }, { @@ -729,34 +373,15 @@ }, "outputs": [], "source": [ - "for corpus_name in ['News Development', 'News Test']:\n", - " corpus = session.query(Corpus).filter(Corpus.name == corpus_name).one()\n", - " sentences = set()\n", - " for document in corpus:\n", - " for sentence in document.sentences:\n", - " if number_of_people(sentence) < 5:\n", - " sentences.add(sentence)\n", - " \n", - " %time c = ce.extract(sentences, corpus_name + ' Candidates', session)\n", - " session.add(c)\n", - "session.commit()" + "sv.get_selected()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next, in Part 3, we will annotate some candidates with labels so that we can evaluate performance." + "Note that we can **navigate using the provided buttons**, or **using the keyboard (hover over buttons to see controls)**, highlight candidates (even if they overlap), and also **apply binary labels** (more on where to use this later!). In particular, note that **the Viewer is synced dynamically with the notebook**, so that we can for example get the `Candidate` that is currently selected. Try it out!" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] } ], "metadata": { @@ -776,19 +401,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.11" - }, - "widgets": { - "state": { - "805e1e8d6ca24bf982841aa75c864a05": { - "views": [ - { - "cell_index": 24 - } - ] - } - }, - "version": "1.2.0" + "version": "2.7.6" } }, "nbformat": 4, diff --git a/notebooks/2_labeling_functions.ipynb b/notebooks/2_labeling_functions.ipynb index 871dd31..b371ade 100755 --- a/notebooks/2_labeling_functions.ipynb +++ b/notebooks/2_labeling_functions.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "collapsed": false }, @@ -36,7 +36,7 @@ "%matplotlib inline\n", "\n", "import os,sys\n", - "os.environ['SNORKELDB']=\"postgres://jhusson@localhost:5432/snorkel_strom\"\n", + "os.environ['SNORKELDB']=\"postgres:///stromatolite\"\n", "\n", "import numpy as np\n", "from snorkel import SnorkelSession\n", @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "collapsed": false }, @@ -81,9 +81,10 @@ "outputs": [], "source": [ "from snorkel.models import CandidateSet\n", - "\n", - "train = session.query(CandidateSet).filter(CandidateSet.name == 'News Training Candidates').one()\n", - "dev = session.query(CandidateSet).filter(CandidateSet.name == 'News Development Candidates').one()" + "train_candidates = session.query(CandidateSet).filter(CandidateSet.name == 'Training Candidates').one()\n", + "print len(train)\n", + "test_candidates = session.query(CandidateSet).filter(CandidateSet.name == 'Test Candidates').one()\n", + "print len(test)" ] }, { @@ -96,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "collapsed": false }, @@ -111,31 +112,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can create a new feature set:" + "We can create a new feature set- note that we _create_ a set of features based on the training candidates, and then featurize the test set using this set of features (using _update_)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "collapsed": false, "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[========================================] 100%===== ] 13%\n", - "\n", - "Loading sparse Feature matrix...\n", - "CPU times: user 2.07 s, sys: 154 ms, total: 2.23 s\n", - "Wall time: 3.2 s\n" - ] - } - ], + "outputs": [], + "source": [ + "%time F_train = feature_manager.create(session, train, 'Training Features')\n", + "F_train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], "source": [ - "%time F_train = feature_manager.create(session, train, 'Train Features')" + "%time F_test = feature_manager.update(session, test, 'Training Features', False)\n", + "F_test" ] }, { @@ -153,7 +155,20 @@ }, "outputs": [], "source": [ - "%time F_train = feature_manager.load(session, train, 'Train Features')" + "F_train = feature_manager.load(session, train_candidates, 'Training Features')\n", + "F_train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "F_test = feature_manager.load(session, test_candidates, 'Training Features')\n", + "F_test" ] }, { @@ -206,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": { "collapsed": false }, @@ -232,14 +247,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from snorkel.annotations import LabelManager\n", - "\n", "label_manager = LabelManager()" ] }, @@ -257,67 +271,6 @@ "collapsed": false }, "outputs": [], - "source": [ - "spouses = {'wife', 'husband', 'ex-wife', 'ex-husband'}\n", - "family = {'father', 'mother', 'sister', 'brother', 'son', 'daughter',\n", - " 'grandfather', 'grandmother', 'uncle', 'aunt', 'cousin'}\n", - "family = family | {f + '-in-law' for f in family}\n", - "other = {'boyfriend', 'girlfriend' 'boss', 'employee', 'secretary', 'co-worker'}\n", - "\n", - "def LF_too_far_apart(c):\n", - " return -1 if len(get_between_tokens(c)) > 10 else 0\n", - "\n", - "def LF_third_wheel(c):\n", - " return -1 if 'PERSON' in get_between_tokens(c, attrib='ner_tags', case_sensitive=True) else 0\n", - " \n", - "def LF_husband_wife(c):\n", - " return 1 if len(spouses.intersection(set(get_between_tokens(c)))) > 0 else 0\n", - "\n", - "def LF_husband_wife_left_window(c):\n", - " if len(spouses.intersection(set(get_left_tokens(c[0], window=2)))) > 0:\n", - " return 1\n", - " elif len(spouses.intersection(set(get_left_tokens(c[1], window=2)))) > 0:\n", - " return 1\n", - " else:\n", - " return 0\n", - "\n", - "def LF_no_spouse_in_sentence(c):\n", - " return -1 if len(spouses.intersection(set(c[0].parent.words))) == 0 else 0\n", - "\n", - "def LF_and_married(c):\n", - " return 1 if 'and' in get_between_tokens(c) and 'married' in get_right_tokens(c) else 0\n", - " \n", - "def LF_familial_relationship(c):\n", - " return -1 if len(set(family).intersection(set(get_between_tokens(c)))) > 0 else 0\n", - "\n", - "def LF_family_left_window(c):\n", - " if len(family.intersection(set(get_left_tokens(c[0], window=2)))) > 0:\n", - " return -1\n", - " elif len(family.intersection(set(get_left_tokens(c[1], window=2)))) > 0:\n", - " return -1\n", - " else:\n", - " return 0\n", - "\n", - "def LF_other_relationship(c):\n", - " coworker = ['boss', 'employee', 'secretary', 'co-worker']\n", - " return -1 if len(set(coworker).intersection(set(get_between_tokens(c)))) > 0 else 0" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "StromStrat(Span(\"stromatolites\", parent=193044, chars=[285,297], words=[47,47]), Span(\"Smyrna Bed\", parent=193044, chars=[0,9], words=[0,1]))\n" - ] - } - ], "source": [ "from snorkel.models import CandidateSet\n", "all_c = session.query(CandidateSet).filter(CandidateSet.name == 'Candidate Set').one()\n", @@ -329,46 +282,33 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1\n", - "0\n", - "0\n", - "0\n", - "0\n", - "0\n", - "0\n", - "1\n" - ] - } - ], + "outputs": [], "source": [ "import yaml, psycopg2\n", "from snorkel.models import Span\n", "\n", - "good_words={'strom':{'present','found'},'strat':{'contain','contains','include','includes'}}\n", + "good_words={'strom':{'present','found','abundant'},'strat':{'contain','contains','include','includes'}}\n", "\n", "# Connect to Postgres\n", + "\"\"\"\n", "with open('../credentials', 'r') as credential_yaml:\n", " credentials = yaml.load(credential_yaml)\n", - "\n", "with open('../config', 'r') as config_yaml:\n", " config = yaml.load(config_yaml)\n", + "\"\"\"\n", "\n", "# Connect to Postgres\n", "connection = psycopg2.connect(\n", - " dbname=credentials['snorkel_postgres']['database'],\n", - " user=credentials['snorkel_postgres']['user'],\n", - " password=credentials['snorkel_postgres']['password'],\n", - " host=credentials['snorkel_postgres']['host'],\n", - " port=credentials['snorkel_postgres']['port'])\n", + " dbname= 'stromatolite' #credentials['snorkel_postgres']['database'],\n", + " #user=credentials['snorkel_postgres']['user'],\n", + " #password=credentials['snorkel_postgres']['password'],\n", + " #host=credentials['snorkel_postgres']['host'],\n", + " #port=credentials['snorkel_postgres']['port'])\n", + " )\n", "cursor = connection.cursor()\n", "\n", "\n", @@ -389,32 +329,24 @@ "test=LF_num_stratphrase(c)\n", "print test\n", "\n", - "def LF_wordsep_fifty(c):\n", - " return -1 if len(get_between_tokens(c)) > 50 else 0\n", - "\n", - "test=LF_wordsep_fifty(c)\n", - "print test\n", - "\n", "def LF_wordsep_forty(c):\n", - " return -1 if len(get_between_tokens(c)) > 40 else 0\n", + " ws = len(get_between_tokens(c))\n", + " return -1 if ws > 40 else 0\n", "\n", "test=LF_wordsep_forty(c)\n", "print test\n", "\n", - "def LF_wordsep_thirty(c):\n", - " return -1 if len(get_between_tokens(c)) > 30 else 0\n", - "\n", - "test=LF_wordsep_thirty(c)\n", - "print test\n", "\n", "def LF_wordsep_twenty(c):\n", - " return -1 if len(get_between_tokens(c)) > 20 else 0\n", + " ws = len(get_between_tokens(c))\n", + " return -1 if ws > 20 and ws <= 40 else 0\n", "\n", "test=LF_wordsep_twenty(c)\n", "print test\n", "\n", "def LF_wordsep_ten(c):\n", - " return -1 if len(get_between_tokens(c)) > 10 else 0\n", + " ws = len(get_between_tokens(c))\n", + " return -1 if ws > 10 and ws <= 20 else 0\n", "\n", "test=LF_wordsep_ten(c)\n", "print test\n", @@ -451,67 +383,32 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "LFs = [LF_num_stratphrase,LF_wordsep_fifty,LF_wordsep_forty,LF_wordsep_thirty,LF_wordsep_twenty,LF_wordsep_ten,LF_nlp_parent,LF_goodwords]" + "LFs = [\n", + " #LF_num_stratphrase,\n", + " LF_wordsep_forty,\n", + " LF_wordsep_twenty,\n", + " LF_wordsep_ten,\n", + " LF_nlp_parent,\n", + " LF_goodwords\n", + "]" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "collapsed": false, "scrolled": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[========================================] 100%===== ] 13%\n", - "\n", - "Loading sparse Label matrix...\n", - "CPU times: user 381 ms, sys: 34.8 ms, total: 416 ms\n", - "Wall time: 613 ms\n" - ] - }, - { - "data": { - "text/plain": [ - "<15x7 sparse matrix of type ''\n", - "\twith 26 stored elements in Compressed Sparse Row format>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%time L_train = label_manager.create(session, train, 'LF Labels', f=LFs)\n", - "L_train" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**OR** load if we've already created:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, "outputs": [], "source": [ - "%time L_train = label_manager.load(session, train, 'LF Labels')\n", + "%time L_train = label_manager.create(session, train_candidates, 'Training LF Labels', f=LFs)\n", "L_train" ] }, @@ -519,7 +416,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can also add or rerun a single labeling function (or more!) with the below command. Note that we set the argument `expand_key_set` to `True` to indicate that the set of matrix columns should be allowed to expand:" + "**OR** load if we've already created:" ] }, { @@ -530,7 +427,7 @@ }, "outputs": [], "source": [ - "L_train = label_manager.update(session, train, 'LF Labels', True, f=[LF_no_spouse_in_sentence])\n", + "%time L_train = label_manager.load(session, train_candidates, 'LF Labels')\n", "L_train" ] }, @@ -543,95 +440,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
conflictscoveragejoverlaps
LF_num_stratphrase00.40000000.333333
LF_wordsep_fifty00.00000010.000000
LF_wordsep_forty00.00000020.000000
LF_wordsep_thirty00.00000030.000000
LF_wordsep_twenty00.13333340.133333
LF_wordsep_ten00.33333350.333333
LF_nlp_parent00.86666760.533333
\n", - "
" - ], - "text/plain": [ - " conflicts coverage j overlaps\n", - "LF_num_stratphrase 0 0.400000 0 0.333333\n", - "LF_wordsep_fifty 0 0.000000 1 0.000000\n", - "LF_wordsep_forty 0 0.000000 2 0.000000\n", - "LF_wordsep_thirty 0 0.000000 3 0.000000\n", - "LF_wordsep_twenty 0 0.133333 4 0.133333\n", - "LF_wordsep_ten 0 0.333333 5 0.333333\n", - "LF_nlp_parent 0 0.866667 6 0.533333" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "L_train.lf_stats()" ] @@ -646,56 +459,16 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "//anaconda/lib/python2.7/site-packages/matplotlib/__init__.py:1318: UserWarning: This call to matplotlib.use() has no effect\n", - "because the backend has already been chosen;\n", - "matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n", - "or matplotlib.backends is imported for the first time.\n", - "\n", - " warnings.warn(_use_error_msg)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "================================================================================\n", - "Training marginals (!= 0.5):\t15\n", - "Features:\t\t\t7\n", - "================================================================================\n", - "Begin training for rate=1e-05, mu=1e-06\n", - "\tLearning epoch = 0\tGradient mag. = 0.526535\n", - "\tLearning epoch = 250\tGradient mag. = 0.554394\n", - "\tLearning epoch = 500\tGradient mag. = 0.554541\n", - "\tLearning epoch = 750\tGradient mag. = 0.554688\n", - "Final gradient magnitude for rate=1e-05, mu=1e-06: 0.555\n" - ] - } - ], + "outputs": [], "source": [ "from snorkel.learning import NaiveBayes\n", "\n", "gen_model = NaiveBayes()\n", - "gen_model.train(L_train, n_iter=1000, rate=1e-5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "gen_model.save(session, 'Generative Params')" + "gen_model.train(L_train, n_iter=10000, rate=1e-4)" ] }, { @@ -707,7 +480,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "collapsed": false }, @@ -718,23 +491,11 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 1.00078652, 0.99850038, 0.99850038, 0.99850038, 1.00003983,\n", - " 1.00180991, 1.00199087])" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gen_model.w" ] @@ -747,65 +508,6 @@ "We use the estimated probabilites to train a discriminative model that classifies each `Candidate` as a true or false mention. We'll use a random hyperparameter search, evaluated on the development set labels, to find the best hyperparameters for our model. To run a hyperparameter search, we need labels for a development set. If they aren't already available, we can manually create labels using the Viewer." ] }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "from snorkel.learning import LogReg\n", - "from snorkel.learning_utils import RandomSearch, ListParameter, RangeParameter\n", - "\n", - "iter_param = ListParameter('n_iter', [250, 500, 1000, 2000])\n", - "rate_param = RangeParameter('rate', 1e-4, 1e-2, step=0.75, log_base=10)\n", - "reg_param = RangeParameter('mu', 1e-8, 1e-2, step=1, log_base=10)\n", - "\n", - "disc_model = LogReg()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we create features for the development set candidates.\n", - "\n", - "Note that we use the training feature set, because those are the only features for which we have learned parameters. Features that were not encountered during training, e.g., a token that does not appear in the training set, are ignored, because we do not have any information about them.\n", - "\n", - "To do so with the `FeatureManager`, we call update with the new `CandidateSet`, the name of the training `AnnotationKeySet`, and the value `False` for the parameter `extend_key_set` to indicate that the `AnnotationKeySet` should not be expanded with new `Feature` keys encountered during processing." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[========================================] 100%=== ] 9%\n", - "\n", - "Loading sparse Feature matrix...\n", - "CPU times: user 1.68 s, sys: 73.9 ms, total: 1.76 s\n", - "Wall time: 2.13 s\n" - ] - } - ], - "source": [ - "%time F_dev = feature_manager.update(session, dev, 'Train Features', False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**OR** if we've already created one, we can simply load as follows:" - ] - }, { "cell_type": "code", "execution_count": null, @@ -814,56 +516,15 @@ }, "outputs": [], "source": [ - "%time F_dev = feature_manager.load(session, dev, 'Train Features')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we load the development set labels and gold candidates we made in Part III." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "L_gold_dev = label_manager.load(session, dev, \"News Gold Labels\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "gold_dev_set = session.query(CandidateSet).filter(CandidateSet.name == 'News Development Candidates').one()" + "from snorkel.learning import LogReg\n", + "disc_model = LogReg(bias_term=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now we set up and run the hyperparameter search, training our model with different hyperparamters and picking the best model configuration to keep. We'll set the random seed to maintain reproducibility.\n", - "\n", - "Note that we are fitting our model's parameters to the training set generated by our labeling functions, while we are picking hyperparamters with respect to score over the development set labels which we created by hand." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "searcher = RandomSearch(disc_model, F_train, train_marginals, 10, iter_param, rate_param, reg_param)" + "**Note: Here, we're training our model with hand-tuned hyperparameters... another option (the better one at some point) is to use some of our ground-truth-labeled candidates to serve as a \"dev set\" to automatically tune the model hyperparameters. See the tutorial for this**" ] }, { @@ -874,18 +535,14 @@ }, "outputs": [], "source": [ - "np.random.seed(1701)\n", - "searcher.fit(F_dev, L_gold_dev, gold_dev_set)" + "disc_model.train(F_train, train_marginals, n_iter=1000, rate=0.01, mu=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "_Note that to train a model without tuning any hyperparameters--at your own risk!--just use the `train` method of the discriminative model. For instance, to train with 500 iterations and a learning rate of 0.001, you could run:_\n", - "```\n", - "disc_model.train(F_train, train_marginals, n_iter=500, rate=0.001)\n", - "```" + "### Scoring against the test set" ] }, { @@ -896,7 +553,8 @@ }, "outputs": [], "source": [ - "disc_model.w.shape" + "L_gold_test = label_manager.load(session, test_candidates, 'iross')\n", + "L_gold_test" ] }, { @@ -907,18 +565,7 @@ }, "outputs": [], "source": [ - "%time disc_model.save(session, \"Discriminative Params\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "tp, fp, tn, fn = disc_model.score(F_dev, L_gold_dev, gold_dev_set)" + "tp, fp, tn, fn = disc_model.score(F_test, L_gold_test, set_unlabeled_as_neg=False)" ] }, { @@ -939,32 +586,9 @@ "source": [ "from snorkel.viewer import SentenceNgramViewer\n", "\n", - "# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook\n", - "# You should ignore this!\n", - "import os\n", - "if 'CI' not in os.environ:\n", - " sv = SentenceNgramViewer(fn, session, annotator_name=\"Tutorial Part IV User\")\n", - "else:\n", - " sv = None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ + "sv = SentenceNgramViewer(fn, session)\n", "sv" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, in Part V, we will test our model on the test `CandidateSet`." - ] } ], "metadata": { @@ -984,7 +608,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.11" + "version": "2.7.6" } }, "nbformat": 4, diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..951c9dd --- /dev/null +++ b/run.sh @@ -0,0 +1,6 @@ +# Set & move to home directory +source set_env.sh + +# Launch jupyter notebook! +echo "Launching Jupyter Notebook..." +jupyter notebook diff --git a/set_env.sh b/set_env.sh new file mode 100755 index 0000000..955ac37 --- /dev/null +++ b/set_env.sh @@ -0,0 +1,9 @@ +export APPHOME="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && cd .. && pwd )" +export SNORKELHOME="$APPHOME/snorkel" +export DDBIOLIBHOME="$APPHOME/ddbiolib" +echo "Application home directory: $APPHOME" +echo "Snorkel home directory: $SNORKELHOME" +echo "ddbiolib home directory: $DDBIOLIBHOME" +export PYTHONPATH="$PYTHONPATH:$APPHOME:$DDBIOLIBHOME:$SNORKELHOME:$SNORKELHOME/treedlib" +export PATH="$PATH:$SNORKELHOME:$DDBIOLIBHOME:$SNORKELHOME/treedlib" +echo "Environment variables set!"