diff --git a/notebooks/1_candidate_generation.ipynb b/notebooks/1_candidate_generation.ipynb index e14decb..72d2730 100644 --- a/notebooks/1_candidate_generation.ipynb +++ b/notebooks/1_candidate_generation.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "collapsed": false }, @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "collapsed": false }, @@ -42,7 +42,7 @@ "source": [ "from snorkel.models import Sentence\n", "\n", - "sentences = session.query(Sentence).all()" + "sentences = session.query(Sentence).limit(10000).all()" ] }, { @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "collapsed": false }, @@ -65,7 +65,7 @@ "source": [ "from snorkel.models import candidate_subclass\n", "\n", - "strom = candidate_subclass('strom', ['strom', 'strat_name'])" + "StromStrat = candidate_subclass('StromStrat', ['strom', 'stratname'])" ] }, { @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "collapsed": true }, @@ -96,7 +96,8 @@ "source": [ "from snorkel.candidates import Ngrams\n", "\n", - "ngrams = Ngrams(n_max=3)" + "ngram_strom = Ngrams(n_max=1)\n", + "ngram_strat = Ngrams(n_max=9)" ] }, { @@ -108,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "collapsed": false }, @@ -116,14 +117,14 @@ "source": [ "from snorkel.matchers import RegexMatchSpan\n", "\n", - "strom_matcher = RegexMatchSpan(rgx=\"stromatol|thrombol\")" + "strom_matcher = RegexMatchSpan(rgx=\"stromatolit|thrombolit\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { - "collapsed": true + "collapsed": false }, "outputs": [], "source": [ @@ -134,10 +135,30 @@ "request = urllib.urlopen('https://macrostrat.org/api/v2/defs/strat_names?all')\n", "data = 
json.loads(request.read())\n", "\n", - "strat_dict = { r['strat_name_long'] for r in data['success']['data'] }\n", + "#FULL STRAT NAME\n", + "strat_dict_long = { r['strat_name_long'] for r in data['success']['data'] }\n", "\n", + "#ABBREVIATED STRAT NAME - V1\n", + "strat_dict_abV1 = { r['strat_name'] + ' ' + r['rank'] for r in data['success']['data'] }\n", "\n", - "strat_matcher=DictionaryMatch(d=strat_dict,ignore_case=False,longest_match_only=True)" + "#ABBREVIATED STRAT NAME - V2\n", + "strat_dict_abV2 = { r['strat_name'] + ' ' + r['rank'] + '.' for r in data['success']['data'] }\n", + "\n", + "#LITHOLOGY STRAT NAMES\n", + "request = urllib.urlopen('https://macrostrat.org/api/v2/defs/lithologies?all')\n", + "lithologies = json.loads(request.read())\n", + "lithologies=[l['name'].capitalize() for l in lithologies['success']['data']]\n", + "\n", + "strat_dict_short = { r['strat_name'] for r in data['success']['data'] }\n", + "\n", + "strat_dict_lith=set()\n", + "for r in strat_dict_short:\n", + " if r.split(' ')[-1] in lithologies:\n", + " strat_dict_lith.add(r)\n", + " \n", + "strat_dict=set(list(strat_dict_long)+list(strat_dict_abV1)+list(strat_dict_abV2)+list(strat_dict_lith))\n", + " \n", + "strat_matcher=DictionaryMatch(d=strat_dict,ignore_case=False,longest_match_only=True)\n" ] }, { @@ -149,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false }, @@ -157,8 +178,8 @@ "source": [ "from snorkel.candidates import CandidateExtractor\n", "\n", - "ce = CandidateExtractor(Spouse, [ngrams, ngrams], [person_matcher, person_matcher],\n", - " symmetric_relations=False, nested_relations=False, self_relations=False)" + "ce = CandidateExtractor(StromStrat, [ngram_strom, ngram_strat], [strom_matcher, strat_matcher],\n", + " symmetric_relations=True, nested_relations=False, self_relations=False)" ] }, { @@ -177,14 +198,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, 
"metadata": { "collapsed": false, "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[========================================] 100% ] 0%\n", + "\n", + "CPU times: user 46.4 s, sys: 5.86 s, total: 52.3 s\n", + "Wall time: 1min 1s\n", + "Number of candidates: 36\n" + ] + } + ], "source": [ - "%time c = ce.extract(sentences, 'News Training Candidates', session)\n", + "%time c = ce.extract(sentences, 'Candidate Set', session)\n", "print \"Number of candidates:\", len(c)" ] }, @@ -197,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "collapsed": false }, @@ -216,15 +249,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StromStrat(Span(\"stromatolites\", parent=12144, chars=[117,129], words=[18,18]), Span(\"Rae Group\", parent=12144, chars=[138,146], words=[21,22]))\n" + ] + } + ], "source": [ "from snorkel.models import CandidateSet\n", - "c = session.query(CandidateSet).filter(CandidateSet.name == 'News Training Candidates').one()\n", - "c" + "c = session.query(CandidateSet).filter(CandidateSet.name == 'Candidate Set').one()\n", + "c=c.candidates[-5]\n", + "print c" ] }, { @@ -242,12 +284,345 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": { "collapsed": false, "scrolled": true }, - "outputs": [], + "outputs": [ + { + "data": { + "application/javascript": [ + "require.undef('viewer');\n", + "\n", + "// NOTE: all elements should be selected using this.$el.find to avoid collisions with other Viewers\n", + "\n", + "define('viewer', [\"jupyter-js-widgets\"], function(widgets) {\n", + " var ViewerView = widgets.DOMWidgetView.extend({\n", + " render: function() {\n", + " this.cids = this.model.get('cids');\n", + " this.nPages = 
this.cids.length;\n", + " this.pid = 0;\n", + " this.cxid = 0;\n", + " this.cid = 0;\n", + "\n", + " // Insert the html payload\n", + " this.$el.append(this.model.get('html'));\n", + "\n", + " // Initialize all labels from previous sessions\n", + " this.labels = this.deserializeDict(this.model.get('_labels_serialized'));\n", + " for (var i=0; i < this.nPages; i++) {\n", + " this.pid = i;\n", + " for (var j=0; j < this.cids[i].length; j++) {\n", + " this.cxid = j;\n", + " for (var k=0; k < this.cids[i][j].length; k++) {\n", + " this.cid = k;\n", + " if (this.cids[i][j][k] in this.labels) {\n", + " this.markCurrentCandidate(false);\n", + " }\n", + " }\n", + " }\n", + " }\n", + " this.pid = 0;\n", + " this.cxid = 0;\n", + " this.cid = 0;\n", + "\n", + " // Enable button functionality for navigation\n", + " var that = this;\n", + " this.$el.find(\"#next-cand\").click(function() {\n", + " that.switchCandidate(1);\n", + " });\n", + " this.$el.find(\"#prev-cand\").click(function() {\n", + " that.switchCandidate(-1);\n", + " });\n", + " this.$el.find(\"#next-context\").click(function() {\n", + " that.switchContext(1);\n", + " });\n", + " this.$el.find(\"#prev-context\").click(function() {\n", + " that.switchContext(-1);\n", + " });\n", + " this.$el.find(\"#next-page\").click(function() {\n", + " that.switchPage(1);\n", + " });\n", + " this.$el.find(\"#prev-page\").click(function() {\n", + " that.switchPage(-1);\n", + " });\n", + " this.$el.find(\"#label-true\").click(function() {\n", + " that.labelCandidate(true, true);\n", + " });\n", + " this.$el.find(\"#label-false\").click(function() {\n", + " that.labelCandidate(false, true);\n", + " });\n", + "\n", + " // Arrow key functionality\n", + " this.$el.keydown(function(e) {\n", + " switch(e.which) {\n", + " case 74: // j\n", + " that.switchCandidate(-1);\n", + " break;\n", + "\n", + " case 73: // i\n", + " that.switchPage(-1);\n", + " break;\n", + "\n", + " case 76: // l\n", + " that.switchCandidate(1);\n", + " break;\n", + 
"\n", + " case 75: // k\n", + " that.switchPage(1);\n", + " break;\n", + "\n", + " case 84: // t\n", + " that.labelCandidate(true, true);\n", + " break;\n", + "\n", + " case 70: // f\n", + " that.labelCandidate(false, true);\n", + " break;\n", + " }\n", + " });\n", + "\n", + " // Show the first page and highlight the first candidate\n", + " this.$el.find(\"#viewer-page-0\").show();\n", + " this.switchCandidate(0);\n", + " },\n", + "\n", + " // Get candidate selector for currently selected candidate, escaping id properly\n", + " getCandidate: function() {\n", + " return this.$el.find(\".\"+this.cids[this.pid][this.cxid][this.cid]);\n", + " }, \n", + "\n", + " // Color the candidate correctly according to registered label, as well as set highlighting\n", + " markCurrentCandidate: function(highlight) {\n", + " var cid = this.cids[this.pid][this.cxid][this.cid];\n", + " var tags = this.$el.find(\".\"+cid);\n", + "\n", + " // Clear color classes\n", + " tags.removeClass(\"candidate-h\");\n", + " tags.removeClass(\"true-candidate\");\n", + " tags.removeClass(\"true-candidate-h\");\n", + " tags.removeClass(\"false-candidate\");\n", + " tags.removeClass(\"false-candidate-h\");\n", + " tags.removeClass(\"highlighted\");\n", + "\n", + " if (highlight) {\n", + " if (cid in this.labels) {\n", + " tags.addClass(String(this.labels[cid]) + \"-candidate-h\");\n", + " } else {\n", + " tags.addClass(\"candidate-h\");\n", + " }\n", + " \n", + " // If un-highlighting, leave with first non-null coloring\n", + " } else {\n", + " var that = this;\n", + " tags.each(function() {\n", + " var cids = $(this).attr('class').split(/\\s+/).map(function(item) {\n", + " return parseInt(item);\n", + " });\n", + " cids.sort();\n", + " for (var i in cids) {\n", + " if (cids[i] in that.labels) {\n", + " var label = that.labels[cids[i]];\n", + " $(this).addClass(String(label) + \"-candidate\");\n", + " $(this).removeClass(String(!label) + \"-candidate\");\n", + " break;\n", + " }\n", + " }\n", + " 
});\n", + " }\n", + "\n", + " // Extra highlighting css\n", + " if (highlight) {\n", + " tags.addClass(\"highlighted\");\n", + " }\n", + "\n", + " // Classes for showing direction of relation\n", + " if (highlight) {\n", + " this.$el.find(\".\"+cid+\"-0\").addClass(\"left-candidate\");\n", + " this.$el.find(\".\"+cid+\"-1\").addClass(\"right-candidate\");\n", + " } else {\n", + " this.$el.find(\".\"+cid+\"-0\").removeClass(\"left-candidate\");\n", + " this.$el.find(\".\"+cid+\"-1\").removeClass(\"right-candidate\");\n", + " }\n", + " },\n", + "\n", + " // Cycle through candidates and highlight, by increment inc\n", + " switchCandidate: function(inc) {\n", + " var N = this.cids[this.pid].length\n", + " var M = this.cids[this.pid][this.cxid].length;\n", + " if (N == 0 || M == 0) { return false; }\n", + "\n", + " // Clear highlighting from previous candidate\n", + " if (inc != 0) {\n", + " this.markCurrentCandidate(false);\n", + "\n", + " // Increment the cid counter\n", + "\n", + " // Move to next context\n", + " if (this.cid + inc >= M) {\n", + " while (this.cid + inc >= M) {\n", + " \n", + " // At last context on page, halt\n", + " if (this.cxid == N - 1) {\n", + " this.cid = M - 1;\n", + " inc = 0;\n", + " break;\n", + " \n", + " // Increment to next context\n", + " } else {\n", + " inc -= M - this.cid;\n", + " this.cxid += 1;\n", + " M = this.cids[this.pid][this.cxid].length;\n", + " this.cid = 0;\n", + " }\n", + " }\n", + "\n", + " // Move to previous context\n", + " } else if (this.cid + inc < 0) {\n", + " while (this.cid + inc < 0) {\n", + " \n", + " // At first context on page, halt\n", + " if (this.cxid == 0) {\n", + " this.cid = 0;\n", + " inc = 0;\n", + " break;\n", + " \n", + " // Increment to previous context\n", + " } else {\n", + " inc += this.cid + 1;\n", + " this.cxid -= 1;\n", + " M = this.cids[this.pid][this.cxid].length;\n", + " this.cid = M - 1;\n", + " }\n", + " }\n", + " }\n", + "\n", + " // Move within current context\n", + " this.cid += 
inc;\n", + " }\n", + " this.markCurrentCandidate(true);\n", + "\n", + " // Push this new cid to the model\n", + " this.model.set('_selected_cid', this.cids[this.pid][this.cxid][this.cid]);\n", + " this.touch();\n", + " },\n", + "\n", + " // Switch through contexts\n", + " switchContext: function(inc) {\n", + " this.markCurrentCandidate(false);\n", + "\n", + " // Iterate context on this page\n", + " var M = this.cids[this.pid].length;\n", + " if (this.cxid + inc < 0) {\n", + " this.cxid = 0;\n", + " } else if (this.cxid + inc >= M) {\n", + " this.cxid = M - 1;\n", + " } else {\n", + " this.cxid += inc;\n", + " }\n", + "\n", + " // Reset cid and set to first candidate\n", + " this.cid = 0;\n", + " this.switchCandidate(0);\n", + " },\n", + "\n", + " // Switch through pages\n", + " switchPage: function(inc) {\n", + " this.markCurrentCandidate(false);\n", + " this.$el.find(\".viewer-page\").hide();\n", + " if (this.pid + inc < 0) {\n", + " this.pid = 0;\n", + " } else if (this.pid + inc > this.nPages - 1) {\n", + " this.pid = this.nPages - 1;\n", + " } else {\n", + " this.pid += inc;\n", + " }\n", + " this.$el.find(\"#viewer-page-\"+this.pid).show();\n", + "\n", + " // Show pagination\n", + " this.$el.find(\"#page\").html(this.pid);\n", + "\n", + " // Reset cid and set to first candidate\n", + " this.cid = 0;\n", + " this.cxid = 0;\n", + " this.switchCandidate(0);\n", + " },\n", + "\n", + " // Label currently-selected candidate\n", + " labelCandidate: function(label, highlighted) {\n", + " var c = this.getCandidate();\n", + " var cid = this.cids[this.pid][this.cxid][this.cid];\n", + " var cl = String(label) + \"-candidate\";\n", + " var clh = String(label) + \"-candidate-h\";\n", + " var cln = String(!label) + \"-candidate\";\n", + " var clnh = String(!label) + \"-candidate-h\";\n", + "\n", + " // Toggle label highlighting\n", + " if (c.hasClass(cl) || c.hasClass(clh)) {\n", + " c.removeClass(cl);\n", + " c.removeClass(clh);\n", + " if (highlighted) {\n", + " 
c.addClass(\"candidate-h\");\n", + " }\n", + " this.labels[cid] = null;\n", + " this.send({event: 'delete_label', cid: cid});\n", + " } else {\n", + " c.removeClass(cln);\n", + " c.removeClass(clnh);\n", + " if (highlighted) {\n", + " c.addClass(clh);\n", + " } else {\n", + " c.addClass(cl);\n", + " }\n", + " this.labels[cid] = label;\n", + " this.send({event: 'set_label', cid: cid, value: label});\n", + " }\n", + "\n", + " // Set the label and pass back to the model\n", + " this.model.set('_labels_serialized', this.serializeDict(this.labels));\n", + " this.touch();\n", + " },\n", + "\n", + " // Serialization of hash maps, because traitlets Dict doesn't seem to work...\n", + " serializeDict: function(d) {\n", + " var s = [];\n", + " for (var key in d) {\n", + " s.push(key+\"~~\"+d[key]);\n", + " }\n", + " return s.join();\n", + " },\n", + "\n", + " // Deserialization of hash maps\n", + " deserializeDict: function(s) {\n", + " var d = {};\n", + " var entries = s.split(/,/);\n", + " var kv;\n", + " for (var i in entries) {\n", + " kv = entries[i].split(/~~/);\n", + " if (kv[1] == \"true\") {\n", + " d[kv[0]] = true;\n", + " } else if (kv[1] == \"false\") {\n", + " d[kv[0]] = false;\n", + " }\n", + " }\n", + " return d;\n", + " },\n", + " });\n", + "\n", + " return {\n", + " ViewerView: ViewerView\n", + " };\n", + "});\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from snorkel.viewer import SentenceNgramViewer\n", "\n", @@ -255,7 +630,7 @@ "# You should ignore this!\n", "import os\n", "if 'CI' not in os.environ:\n", - " sv = SentenceNgramViewer(c[:300], session, annotator_name=\"Tutorial Part 2 User\")\n", + " sv = SentenceNgramViewer(c, session, annotator_name=\"Tutorial Part 2 User\")\n", "else:\n", " sv = None" ] @@ -269,7 +644,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": { "collapsed": false }, @@ -287,14 +662,24 @@ }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 30, "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StromStrat(Span(\"stromatolitic\", parent=11957, chars=[16,28], words=[3,3]), Span(\"Glenelg Formation\", parent=11957, chars=[68,84], words=[12,13]))\n", + "11957\n" + ] + } + ], "source": [ "if 'CI' not in os.environ:\n", - " print unicode(sv.get_selected())" + " print unicode(sv.get_selected())\n", + " print sv.get_selected()[0].parent.id" ] }, { @@ -332,12 +717,21 @@ "source": [ "Next, in Part 3, we will annotate some candidates with labels so that we can evaluate performance." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python [default]", + "display_name": "Python 2", "language": "python", "name": "python2" }, @@ -351,7 +745,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.12" + "version": "2.7.11" } }, "nbformat": 4, diff --git a/notebooks/2_labeling_functions.ipynb b/notebooks/2_labeling_functions.ipynb new file mode 100755 index 0000000..96b4dd2 --- /dev/null +++ b/notebooks/2_labeling_functions.ipynb @@ -0,0 +1,798 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Intro. to Snorkel: Extracting Spouse Relations from the News" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part IV: Training a Model with Data Programming" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this part of the tutorial, we will train a statistical model to differentiate between true and false `Spouse` mentions.\n", + "\n", + "We will train this model using _data programming_, and we will **ignore** the training labels provided with the training data. 
This is a more realistic scenario; in the wild, hand-labeled training data is rare and expensive. Data programming enables us to train a model using only a modest amount of hand-labeled data for validation and testing. For more information on data programming, see the [NIPS 2016 paper](https://arxiv.org/abs/1605.07723)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "#%autoreload 2\n", + "%matplotlib inline\n", + "\n", + "import os\n", + "os.environ['SNORKELDB']=\"postgres://jhusson@localhost:5432/snorkel_strom\"\n", + "\n", + "import numpy as np\n", + "from snorkel import SnorkelSession\n", + "session = SnorkelSession()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.models import candidate_subclass\n", + "\n", + "StromStrat = candidate_subclass('StromStrat', ['strom', 'stratname'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We repeat our definition of the `Spouse` `Candidate` subclass from Parts II and III." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading `CandidateSet` objects\n", + "\n", + "We reload the training and development `CandidateSet` objects from the previous parts of the tutorial." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.models import CandidateSet\n", + "\n", + "train = session.query(CandidateSet).filter(CandidateSet.name == 'News Training Candidates').one()\n", + "dev = session.query(CandidateSet).filter(CandidateSet.name == 'News Development Candidates').one()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Automatically Creating Features\n", + "Recall that our goal is to distinguish between true and false mentions of spouse relations. To train a model for this task, we first embed our `Spouse` candidates in a feature space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.annotations import FeatureManager\n", + "\n", + "feature_manager = FeatureManager()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can create a new feature set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [], + "source": [ + "%time F_train = feature_manager.create(session, train, 'Train Features')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**OR** if we've already created one, we can simply load as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%time F_train = feature_manager.load(session, train, 'Train Features')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the returned matrix is a special subclass of the `scipy.sparse.csr_matrix` class, with some special features which we demonstrate below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + 
"F_train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "F_train.get_candidate(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "F_train.get_key(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating Labeling Functions\n", + "Labeling functions are a core tool of data programming. They are heuristic functions that aim to classify candidates correctly. Their outputs will be automatically combined and denoised to estimate the probabilities of training labels for the training data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n", + "from snorkel.lf_helpers import get_left_tokens, get_right_tokens, get_between_tokens, get_text_between" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Applying Labeling Functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First we construct a `LabelManager`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.annotations import LabelManager\n", + "\n", + "label_manager = LabelManager()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next we run the `LabelManager` to apply the labeling functions to the training `CandidateSet`. 
We'll start with some of our labeling functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "spouses = {'wife', 'husband', 'ex-wife', 'ex-husband'}\n", + "family = {'father', 'mother', 'sister', 'brother', 'son', 'daughter',\n", + " 'grandfather', 'grandmother', 'uncle', 'aunt', 'cousin'}\n", + "family = family | {f + '-in-law' for f in family}\n", + "other = {'boyfriend', 'girlfriend' 'boss', 'employee', 'secretary', 'co-worker'}\n", + "\n", + "def LF_too_far_apart(c):\n", + " return -1 if len(get_between_tokens(c)) > 10 else 0\n", + "\n", + "def LF_third_wheel(c):\n", + " return -1 if 'PERSON' in get_between_tokens(c, attrib='ner_tags', case_sensitive=True) else 0\n", + " \n", + "def LF_husband_wife(c):\n", + " return 1 if len(spouses.intersection(set(get_between_tokens(c)))) > 0 else 0\n", + "\n", + "def LF_husband_wife_left_window(c):\n", + " if len(spouses.intersection(set(get_left_tokens(c[0], window=2)))) > 0:\n", + " return 1\n", + " elif len(spouses.intersection(set(get_left_tokens(c[1], window=2)))) > 0:\n", + " return 1\n", + " else:\n", + " return 0\n", + "\n", + "def LF_no_spouse_in_sentence(c):\n", + " return -1 if len(spouses.intersection(set(c[0].parent.words))) == 0 else 0\n", + "\n", + "def LF_and_married(c):\n", + " return 1 if 'and' in get_between_tokens(c) and 'married' in get_right_tokens(c) else 0\n", + " \n", + "def LF_familial_relationship(c):\n", + " return -1 if len(set(family).intersection(set(get_between_tokens(c)))) > 0 else 0\n", + "\n", + "def LF_family_left_window(c):\n", + " if len(family.intersection(set(get_left_tokens(c[0], window=2)))) > 0:\n", + " return -1\n", + " elif len(family.intersection(set(get_left_tokens(c[1], window=2)))) > 0:\n", + " return -1\n", + " else:\n", + " return 0\n", + "\n", + "def LF_other_relationship(c):\n", + " coworker = ['boss', 'employee', 'secretary', 'co-worker']\n", + " return -1 if 
len(set(coworker).intersection(set(get_between_tokens(c)))) > 0 else 0" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StromStrat(Span(\"stromatolites\", parent=12144, chars=[117,129], words=[18,18]), Span(\"Rae Group\", parent=12144, chars=[138,146], words=[21,22]))\n" + ] + } + ], + "source": [ + "from snorkel.models import CandidateSet\n", + "c = session.query(CandidateSet).filter(CandidateSet.name == 'Candidate Set').one()\n", + "c=c.candidates[-5]\n", + "print c" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "0\n", + "0\n", + "0\n", + "0\n", + "0\n", + "stromatolites\n", + "Rae Group\n", + " t\n", + "StromStrat(Span(\"stromatolites\", parent=12144, chars=[117,129], words=[18,18]), Span(\"Rae Group\", parent=12144, chars=[138,146], words=[21,22]))\n" + ] + } + ], + "source": [ + "import yaml, psycopg2\n", + "from snorkel.models import Span\n", + "\n", + "\n", + "# Connect to Postgres\n", + "with open('../credentials', 'r') as credential_yaml:\n", + " credentials = yaml.load(credential_yaml)\n", + "\n", + "with open('../config', 'r') as config_yaml:\n", + " config = yaml.load(config_yaml)\n", + "\n", + "# Connect to Postgres\n", + "connection = psycopg2.connect(\n", + " dbname=credentials['snorkel_postgres']['database'],\n", + " user=credentials['snorkel_postgres']['user'],\n", + " password=credentials['snorkel_postgres']['password'],\n", + " host=credentials['snorkel_postgres']['host'],\n", + " port=credentials['snorkel_postgres']['port'])\n", + "cursor = connection.cursor()\n", + "\n", + "\n", + "def LF_num_stratphrase(c):\n", + " cursor.execute(\"\"\"\n", + " SELECT distinct span.id from span \n", + " JOIN strom_strat on span.id=strom_strat.stratname_id \n", + " WHERE 
span.parent_id=%(parent_id)s;\"\"\",\n", + " {\"parent_id\": c[0].parent.id\n", + " })\n", + " tmp_span=cursor.fetchall()\n", + "\n", + " tmp_strat = session.query(Span).filter(Span.id.in_(tmp_span)).all()\n", + " num_strat = len({a.get_span() for a in tmp_strat})\n", + "\n", + " return -1 if num_strat > 1 else 0\n", + "\n", + "test=LF_num_stratphrase(c)\n", + "print test\n", + "\n", + "def LF_wordsep_fifty(c):\n", + " return -1 if len(get_between_tokens(c)) > 50 else 0\n", + "\n", + "test=LF_wordsep_fifty(c)\n", + "print test\n", + "\n", + "def LF_wordsep_forty(c):\n", + " return -1 if len(get_between_tokens(c)) > 40 else 0\n", + "\n", + "test=LF_wordsep_forty(c)\n", + "print test\n", + "\n", + "def LF_wordsep_thirty(c):\n", + " return -1 if len(get_between_tokens(c)) > 30 else 0\n", + "\n", + "test=LF_wordsep_thirty(c)\n", + "print test\n", + "\n", + "def LF_wordsep_twenty(c):\n", + " return -1 if len(get_between_tokens(c)) > 20 else 0\n", + "\n", + "test=LF_wordsep_twenty(c)\n", + "print test\n", + "\n", + "def LF_wordsep_ten(c):\n", + " return -1 if len(get_between_tokens(c)) > 10 else 0\n", + "\n", + "test=LF_wordsep_ten(c)\n", + "print test\n", + "\n", + "#print c[0].get_attrib_span(a='dep_parents')\n", + "print c[0].get_attrib_span(a='words')\n", + "#print c[1].get_attrib_span(a='dep_parents')\n", + "print c[1].get_attrib_span(a='words')\n", + "print c[1].get_attrib_span(a='text')\n", + "print c" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "LFs = [LF_too_far_apart, LF_third_wheel, LF_husband_wife, LF_husband_wife_left_window,\n", + " LF_and_married, LF_familial_relationship, LF_other_relationship]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [], + "source": [ + "%time L_train = label_manager.create(session, train, 'LF Labels', f=LFs)\n", + "L_train" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "**OR** load if we've already created:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%time L_train = label_manager.load(session, train, 'LF Labels')\n", + "L_train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also add or rerun a single labeling function (or more!) with the below command. Note that we set the argument `expand_key_set` to `True` to indicate that the set of matrix columns should be allowed to expand:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "L_train = label_manager.update(session, train, 'LF Labels', True, f=[LF_no_spouse_in_sentence])\n", + "L_train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can view statistics about the resulting label matrix:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "L_train.lf_stats()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting the Generative Model\n", + "We estimate the accuracies of the labeling functions without supervision. Specifically, we estimate the parameters of a `NaiveBayes` generative model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.learning import NaiveBayes\n", + "\n", + "gen_model = NaiveBayes()\n", + "gen_model.train(L_train, n_iter=1000, rate=1e-5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "gen_model.save(session, 'Generative Params')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now apply the generative model to the training candidates." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "train_marginals = gen_model.marginals(L_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "gen_model.w" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the Discriminative Model\n", + "We use the estimated probabilities to train a discriminative model that classifies each `Candidate` as a true or false mention. We'll use a random hyperparameter search, evaluated on the development set labels, to find the best hyperparameters for our model. To run a hyperparameter search, we need labels for a development set. If they aren't already available, we can manually create labels using the Viewer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.learning import LogReg\n", + "from snorkel.learning_utils import RandomSearch, ListParameter, RangeParameter\n", + "\n", + "iter_param = ListParameter('n_iter', [250, 500, 1000, 2000])\n", + "rate_param = RangeParameter('rate', 1e-4, 1e-2, step=0.75, log_base=10)\n", + "reg_param = RangeParameter('mu', 1e-8, 1e-2, step=1, log_base=10)\n", + "\n", + "disc_model = LogReg()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we create features for the development set candidates.\n", + "\n", + "Note that we use the training feature set, because those are the only features for which we have learned parameters. Features that were not encountered during training, e.g., a token that does not appear in the training set, are ignored, because we do not have any information about them.\n", + "\n", + "To do so with the `FeatureManager`, we call update with the new `CandidateSet`, the name of the training `AnnotationKeySet`, and the value `False` for the parameter `extend_key_set` to indicate that the `AnnotationKeySet` should not be expanded with new `Feature` keys encountered during processing." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%time F_dev = feature_manager.update(session, dev, 'Train Features', False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**OR** if we've already created one, we can simply load as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%time F_dev = feature_manager.load(session, dev, 'Train Features')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we load the development set labels and gold candidates we made in Part III." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "L_gold_dev = label_manager.load(session, dev, \"News Gold Labels\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "gold_dev_set = session.query(CandidateSet).filter(CandidateSet.name == 'News Development Candidates').one()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we set up and run the hyperparameter search, training our model with different hyperparameters and picking the best model configuration to keep. We'll set the random seed to maintain reproducibility.\n", + "\n", + "Note that we are fitting our model's parameters to the training set generated by our labeling functions, while we are picking hyperparameters with respect to score over the development set labels which we created by hand."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "searcher = RandomSearch(disc_model, F_train, train_marginals, 10, iter_param, rate_param, reg_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "np.random.seed(1701)\n", + "searcher.fit(F_dev, L_gold_dev, gold_dev_set)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Note that to train a model without tuning any hyperparameters--at your own risk!--just use the `train` method of the discriminative model. For instance, to train with 500 iterations and a learning rate of 0.001, you could run:_\n", + "```\n", + "disc_model.train(F_train, train_marginals, n_iter=500, rate=0.001)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "disc_model.w.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%time disc_model.save(session, \"Discriminative Params\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "tp, fp, tn, fn = disc_model.score(F_dev, L_gold_dev, gold_dev_set)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Viewing Examples\n", + "After evaluating on the development `CandidateSet`, the labeling functions can be modified. Try changing the labeling functions to improve performance. You can view the true positives, false positives, true negatives, and false negatives using the `Viewer`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.viewer import SentenceNgramViewer\n", + "\n", + "# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook\n", + "# You should ignore this!\n", + "import os\n", + "if 'CI' not in os.environ:\n", + " sv = SentenceNgramViewer(fn, session, annotator_name=\"Tutorial Part IV User\")\n", + "else:\n", + " sv = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, in Part V, we will test our model on the test `CandidateSet`." + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}