diff --git a/notebooks/1_candidate_generation.ipynb b/notebooks/1_candidate_generation.ipynb new file mode 100644 index 0000000..d4915d5 --- /dev/null +++ b/notebooks/1_candidate_generation.ipynb @@ -0,0 +1,726 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1: Candidate Extraction" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from snorkel import SnorkelSession\n", + "session = SnorkelSession()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the `Corpus`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Corpus (News Training)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from snorkel.models import Corpus\n", + "\n", + "corpus = session.query(Corpus).filter(Corpus.name == 'News Training').one()\n", + "corpus" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining a `Candidate` schema\n", + "We now define the schema of the relation mention we want to extract (which is also the schema of the candidates). This must be a subclass of `Candidate`, and we define it using a helper function.\n", + "\n", + "Here we'll define a binary _spouse relation mention_ which connects two `Span` objects of text. 
Note that this function will create the table in the database backend if it does not exist:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.models import candidate_subclass\n", + "\n", + "Spouse = candidate_subclass('Spouse', ['person1', 'person2'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing a basic `CandidateExtractor`\n", + "\n", + "Next, we'll write a basic function to extract **candidate spouse relation mentions** from the corpus. The `SentenceParser` we used in Part I is built on [CoreNLP](http://stanfordnlp.github.io/CoreNLP/), which performs _named entity recognition_ for us.\n", + "\n", + "We will extract `Candidate` objects of the `Spouse` type by identifying, for each `Sentence`, all pairs of ngrams (up to trigrams) that were tagged as people." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we define a child context space for our sentences." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from snorkel.candidates import Ngrams\n", + "\n", + "ngrams = Ngrams(n_max=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we use a `PersonMatcher` to enforce that candidate relations are composed of pairs of spans that were tagged as people by the `SentenceParser`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.matchers import PersonMatcher\n", + "\n", + "person_matcher = PersonMatcher(longest_match_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we combine the candidate class, child context space, and matcher into an extractor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from snorkel.candidates import CandidateExtractor\n", + "\n", + "ce = CandidateExtractor(Spouse, [ngrams, ngrams], [person_matcher, person_matcher],\n", + " symmetric_relations=False, nested_relations=False, self_relations=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running the `CandidateExtractor`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We run the `CandidateExtractor` by calling `extract` with the contexts to extract from, a name for the `CandidateSet` that will contain the results, and the current session." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[========================================] 100%\n", + "\n", + "CPU times: user 1.53 s, sys: 197 ms, total: 1.73 s\n", + "Wall time: 1.59 s\n", + "Number of candidates: 65\n" + ] + } + ], + "source": [ + "%time c = ce.extract(sentences, 'News Training Candidates', session)\n", + "print \"Number of candidates:\", len(c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Saving the extracted candidates" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "session.add(c)\n", + "session.commit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reloading the candidates" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Candidate Set (News Training Candidates)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from snorkel.models import CandidateSet\n", 
+ "c = session.query(CandidateSet).filter(CandidateSet.name == 'News Training Candidates').one()\n", + "c" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the `Viewer` to inspect candidates\n", + "\n", + "Next, we'll use the `Viewer` class--here, specifically, the `SentenceNgramViewer`--to inspect the data.\n", + "\n", + "Note that our goal here is to **maximize the recall of true candidates** extracted, **not** to extract _only_ the correct candidates. Learning to distinguish true candidates from false candidates is covered in Tutorial 4.\n", + "\n", + "First, we instantiate the `Viewer` object, which groups the input `Candidate` objects by `Sentence`:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/javascript": [ + "require.undef('viewer');\n", + "\n", + "// NOTE: all elements should be selected using this.$el.find to avoid collisions with other Viewers\n", + "\n", + "define('viewer', [\"jupyter-js-widgets\"], function(widgets) {\n", + " var ViewerView = widgets.DOMWidgetView.extend({\n", + " render: function() {\n", + " this.cids = this.model.get('cids');\n", + " this.nPages = this.cids.length;\n", + " this.pid = 0;\n", + " this.cxid = 0;\n", + " this.cid = 0;\n", + "\n", + " // Insert the html payload\n", + " this.$el.append(this.model.get('html'));\n", + "\n", + " // Initialize all labels from previous sessions\n", + " this.labels = this.deserializeDict(this.model.get('_labels_serialized'));\n", + " for (var i=0; i < this.nPages; i++) {\n", + " this.pid = i;\n", + " for (var j=0; j < this.cids[i].length; j++) {\n", + " this.cxid = j;\n", + " for (var k=0; k < this.cids[i][j].length; k++) {\n", + " this.cid = k;\n", + " if (this.cids[i][j][k] in this.labels) {\n", + " this.markCurrentCandidate(false);\n", + " }\n", + " }\n", + " }\n", + " }\n", + " this.pid = 0;\n", + " this.cxid = 0;\n", + 
" this.cid = 0;\n", + "\n", + " // Enable button functionality for navigation\n", + " var that = this;\n", + " this.$el.find(\"#next-cand\").click(function() {\n", + " that.switchCandidate(1);\n", + " });\n", + " this.$el.find(\"#prev-cand\").click(function() {\n", + " that.switchCandidate(-1);\n", + " });\n", + " this.$el.find(\"#next-context\").click(function() {\n", + " that.switchContext(1);\n", + " });\n", + " this.$el.find(\"#prev-context\").click(function() {\n", + " that.switchContext(-1);\n", + " });\n", + " this.$el.find(\"#next-page\").click(function() {\n", + " that.switchPage(1);\n", + " });\n", + " this.$el.find(\"#prev-page\").click(function() {\n", + " that.switchPage(-1);\n", + " });\n", + " this.$el.find(\"#label-true\").click(function() {\n", + " that.labelCandidate(true, true);\n", + " });\n", + " this.$el.find(\"#label-false\").click(function() {\n", + " that.labelCandidate(false, true);\n", + " });\n", + "\n", + " // Arrow key functionality\n", + " this.$el.keydown(function(e) {\n", + " switch(e.which) {\n", + " case 74: // j\n", + " that.switchCandidate(-1);\n", + " break;\n", + "\n", + " case 73: // i\n", + " that.switchPage(-1);\n", + " break;\n", + "\n", + " case 76: // l\n", + " that.switchCandidate(1);\n", + " break;\n", + "\n", + " case 75: // k\n", + " that.switchPage(1);\n", + " break;\n", + "\n", + " case 84: // t\n", + " that.labelCandidate(true, true);\n", + " break;\n", + "\n", + " case 70: // f\n", + " that.labelCandidate(false, true);\n", + " break;\n", + " }\n", + " });\n", + "\n", + " // Show the first page and highlight the first candidate\n", + " this.$el.find(\"#viewer-page-0\").show();\n", + " this.switchCandidate(0);\n", + " },\n", + "\n", + " // Get candidate selector for currently selected candidate, escaping id properly\n", + " getCandidate: function() {\n", + " return this.$el.find(\".\"+this.cids[this.pid][this.cxid][this.cid]);\n", + " }, \n", + "\n", + " // Color the candidate correctly according to registered 
label, as well as set highlighting\n", + " markCurrentCandidate: function(highlight) {\n", + " var cid = this.cids[this.pid][this.cxid][this.cid];\n", + " var tags = this.$el.find(\".\"+cid);\n", + "\n", + " // Clear color classes\n", + " tags.removeClass(\"candidate-h\");\n", + " tags.removeClass(\"true-candidate\");\n", + " tags.removeClass(\"true-candidate-h\");\n", + " tags.removeClass(\"false-candidate\");\n", + " tags.removeClass(\"false-candidate-h\");\n", + " tags.removeClass(\"highlighted\");\n", + "\n", + " if (highlight) {\n", + " if (cid in this.labels) {\n", + " tags.addClass(String(this.labels[cid]) + \"-candidate-h\");\n", + " } else {\n", + " tags.addClass(\"candidate-h\");\n", + " }\n", + " \n", + " // If un-highlighting, leave with first non-null coloring\n", + " } else {\n", + " var that = this;\n", + " tags.each(function() {\n", + " var cids = $(this).attr('class').split(/\\s+/).map(function(item) {\n", + " return parseInt(item);\n", + " });\n", + " cids.sort();\n", + " for (var i in cids) {\n", + " if (cids[i] in that.labels) {\n", + " var label = that.labels[cids[i]];\n", + " $(this).addClass(String(label) + \"-candidate\");\n", + " $(this).removeClass(String(!label) + \"-candidate\");\n", + " break;\n", + " }\n", + " }\n", + " });\n", + " }\n", + "\n", + " // Extra highlighting css\n", + " if (highlight) {\n", + " tags.addClass(\"highlighted\");\n", + " }\n", + "\n", + " // Classes for showing direction of relation\n", + " if (highlight) {\n", + " this.$el.find(\".\"+cid+\"-0\").addClass(\"left-candidate\");\n", + " this.$el.find(\".\"+cid+\"-1\").addClass(\"right-candidate\");\n", + " } else {\n", + " this.$el.find(\".\"+cid+\"-0\").removeClass(\"left-candidate\");\n", + " this.$el.find(\".\"+cid+\"-1\").removeClass(\"right-candidate\");\n", + " }\n", + " },\n", + "\n", + " // Cycle through candidates and highlight, by increment inc\n", + " switchCandidate: function(inc) {\n", + " var N = this.cids[this.pid].length\n", + " var M = 
this.cids[this.pid][this.cxid].length;\n", + " if (N == 0 || M == 0) { return false; }\n", + "\n", + " // Clear highlighting from previous candidate\n", + " if (inc != 0) {\n", + " this.markCurrentCandidate(false);\n", + "\n", + " // Increment the cid counter\n", + "\n", + " // Move to next context\n", + " if (this.cid + inc >= M) {\n", + " while (this.cid + inc >= M) {\n", + " \n", + " // At last context on page, halt\n", + " if (this.cxid == N - 1) {\n", + " this.cid = M - 1;\n", + " inc = 0;\n", + " break;\n", + " \n", + " // Increment to next context\n", + " } else {\n", + " inc -= M - this.cid;\n", + " this.cxid += 1;\n", + " M = this.cids[this.pid][this.cxid].length;\n", + " this.cid = 0;\n", + " }\n", + " }\n", + "\n", + " // Move to previous context\n", + " } else if (this.cid + inc < 0) {\n", + " while (this.cid + inc < 0) {\n", + " \n", + " // At first context on page, halt\n", + " if (this.cxid == 0) {\n", + " this.cid = 0;\n", + " inc = 0;\n", + " break;\n", + " \n", + " // Increment to previous context\n", + " } else {\n", + " inc += this.cid + 1;\n", + " this.cxid -= 1;\n", + " M = this.cids[this.pid][this.cxid].length;\n", + " this.cid = M - 1;\n", + " }\n", + " }\n", + " }\n", + "\n", + " // Move within current context\n", + " this.cid += inc;\n", + " }\n", + " this.markCurrentCandidate(true);\n", + "\n", + " // Push this new cid to the model\n", + " this.model.set('_selected_cid', this.cids[this.pid][this.cxid][this.cid]);\n", + " this.touch();\n", + " },\n", + "\n", + " // Switch through contexts\n", + " switchContext: function(inc) {\n", + " this.markCurrentCandidate(false);\n", + "\n", + " // Iterate context on this page\n", + " var M = this.cids[this.pid].length;\n", + " if (this.cxid + inc < 0) {\n", + " this.cxid = 0;\n", + " } else if (this.cxid + inc >= M) {\n", + " this.cxid = M - 1;\n", + " } else {\n", + " this.cxid += inc;\n", + " }\n", + "\n", + " // Reset cid and set to first candidate\n", + " this.cid = 0;\n", + " 
this.switchCandidate(0);\n", + " },\n", + "\n", + " // Switch through pages\n", + " switchPage: function(inc) {\n", + " this.markCurrentCandidate(false);\n", + " this.$el.find(\".viewer-page\").hide();\n", + " if (this.pid + inc < 0) {\n", + " this.pid = 0;\n", + " } else if (this.pid + inc > this.nPages - 1) {\n", + " this.pid = this.nPages - 1;\n", + " } else {\n", + " this.pid += inc;\n", + " }\n", + " this.$el.find(\"#viewer-page-\"+this.pid).show();\n", + "\n", + " // Show pagination\n", + " this.$el.find(\"#page\").html(this.pid);\n", + "\n", + " // Reset cid and set to first candidate\n", + " this.cid = 0;\n", + " this.cxid = 0;\n", + " this.switchCandidate(0);\n", + " },\n", + "\n", + " // Label currently-selected candidate\n", + " labelCandidate: function(label, highlighted) {\n", + " var c = this.getCandidate();\n", + " var cid = this.cids[this.pid][this.cxid][this.cid];\n", + " var cl = String(label) + \"-candidate\";\n", + " var clh = String(label) + \"-candidate-h\";\n", + " var cln = String(!label) + \"-candidate\";\n", + " var clnh = String(!label) + \"-candidate-h\";\n", + "\n", + " // Toggle label highlighting\n", + " if (c.hasClass(cl) || c.hasClass(clh)) {\n", + " c.removeClass(cl);\n", + " c.removeClass(clh);\n", + " if (highlighted) {\n", + " c.addClass(\"candidate-h\");\n", + " }\n", + " this.labels[cid] = null;\n", + " this.send({event: 'delete_label', cid: cid});\n", + " } else {\n", + " c.removeClass(cln);\n", + " c.removeClass(clnh);\n", + " if (highlighted) {\n", + " c.addClass(clh);\n", + " } else {\n", + " c.addClass(cl);\n", + " }\n", + " this.labels[cid] = label;\n", + " this.send({event: 'set_label', cid: cid, value: label});\n", + " }\n", + "\n", + " // Set the label and pass back to the model\n", + " this.model.set('_labels_serialized', this.serializeDict(this.labels));\n", + " this.touch();\n", + " },\n", + "\n", + " // Serialization of hash maps, because traitlets Dict doesn't seem to work...\n", + " serializeDict: function(d) 
{\n", + " var s = [];\n", + " for (var key in d) {\n", + " s.push(key+\"~~\"+d[key]);\n", + " }\n", + " return s.join();\n", + " },\n", + "\n", + " // Deserialization of hash maps\n", + " deserializeDict: function(s) {\n", + " var d = {};\n", + " var entries = s.split(/,/);\n", + " var kv;\n", + " for (var i in entries) {\n", + " kv = entries[i].split(/~~/);\n", + " if (kv[1] == \"true\") {\n", + " d[kv[0]] = true;\n", + " } else if (kv[1] == \"false\") {\n", + " d[kv[0]] = false;\n", + " }\n", + " }\n", + " return d;\n", + " },\n", + " });\n", + "\n", + " return {\n", + " ViewerView: ViewerView\n", + " };\n", + "});\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from snorkel.viewer import SentenceNgramViewer\n", + "\n", + "# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook\n", + "# You should ignore this!\n", + "import os\n", + "if 'CI' not in os.environ:\n", + " sv = SentenceNgramViewer(c[:300], session, annotator_name=\"Tutorial Part 2 User\")\n", + "else:\n", + " sv = None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we render the `Viewer`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that we can **navigate using the provided buttons**, or **using the keyboard (hover over buttons to see controls)**, highlight candidates (even if they overlap), and also **apply binary labels** (more on where to use this later!). In particular, note that **the Viewer is synced dynamically with the notebook**, so that we can, for example, get the `Candidate` that is currently selected. Try it out!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spouse(Span(\"Lupo\", parent=12, chars=[98,101], words=[21,21]), Span(\"Catherine\", parent=12, chars=[24,32], words=[6,6]))\n" + ] + } + ], + "source": [ + "if 'CI' not in os.environ:\n", + " print unicode(sv.get_selected())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Repeating for development and test corpora\n", + "We will rerun the same operations for the other two news corpora: development and test. All we do for each is load in the `Corpus` object, collect the `Sentence` objects, and run them through the `CandidateExtractor`." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[========================================] 100%\n", + "\n", + "CPU times: user 1.82 s, sys: 272 ms, total: 2.09 s\n", + "Wall time: 1.88 s\n", + "[========================================] 100%\n", + "\n", + "CPU times: user 2.18 s, sys: 232 ms, total: 2.42 s\n", + "Wall time: 2.24 s\n" + ] + } + ], + "source": [ + "for corpus_name in ['News Development', 'News Test']:\n", + " corpus = session.query(Corpus).filter(Corpus.name == corpus_name).one()\n", + " sentences = set()\n", + " for document in corpus:\n", + " for sentence in document.sentences:\n", + " if number_of_people(sentence) < 5:\n", + " sentences.add(sentence)\n", + " \n", + " %time c = ce.extract(sentences, corpus_name + ' Candidates', session)\n", + " session.add(c)\n", + "session.commit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, in Part 3, we will annotate some candidates with labels so that we can evaluate performance." 
+ ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [default]", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}