From 58ae5cdc5d0bc721b56c21ca5bb4bdf0924d3e1b Mon Sep 17 00:00:00 2001 From: Susan Li Date: Mon, 7 Jan 2019 17:30:38 -0500 Subject: [PATCH] Add notebook --- Flair Practice.ipynb | 1178 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1178 insertions(+) create mode 100644 Flair Practice.ipynb diff --git a/Flair Practice.ipynb b/Flair Practice.ipynb new file mode 100644 index 0000000..255d726 --- /dev/null +++ b/Flair Practice.ipynb @@ -0,0 +1,1178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.data import Sentence\n", + "from flair.models import SequenceTagger" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('I love Berlin .')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "tagger = SequenceTagger.load('ner')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"I love Berlin .\" - 4 Tokens]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tagger.predict(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: \"I love Berlin .\" - 4 Tokens\n" + ] + } + ], + "source": [ + "print(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LOC-span [3]: \"Berlin\"\n" + ] + } + ], + "source": [ + "for entity in sentence.get_spans('ner'):\n", + " print(entity)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: \"The grass is green .\" 
- 5 Tokens\n" + ] + } + ], + "source": [ + "sentence = Sentence('The grass is green .')\n", + "print(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: 4 green\n" + ] + } + ], + "source": [ + "print(sentence.get_token(4))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: 4 green\n" + ] + } + ], + "source": [ + "print(sentence[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: 1 The\n", + "Token: 2 grass\n", + "Token: 3 is\n", + "Token: 4 green\n", + "Token: 5 .\n" + ] + } + ], + "source": [ + "for token in sentence:\n", + " print(token)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: \"The grass is green .\" - 5 Tokens\n" + ] + } + ], + "source": [ + "sentence = Sentence('The grass is green.', use_tokenizer = True)\n", + "print(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "sentence[3].add_tag('ner', 'color')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The grass is green .\n" + ] + } + ], + "source": [ + "print(sentence.to_tagged_string())" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.data import Label" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "tag: Label = sentence[3].get_tag('ner')" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"Token: 4 green\" is tagged as \"color\" with confidence score \"1.0\"\n" + ] + } + ], + "source": [ + "print(f'\"{sentence[3]}\" is tagged as \"{tag.value}\" with confidence score \"{tag.score}\"')" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('France is the current World Cup winner.')" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "sentence.add_label('sports')" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "sentence.add_labels(['sports', 'world cup'])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: \"France is the current World Cup winner.\" - 7 Tokens\n" + ] + } + ], + "source": [ + "print(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sports (1.0)\n", + "sports (1.0)\n", + "world cup (1.0)\n" + ] + } + ], + "source": [ + "for label in sentence.labels:\n", + " print(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: \"France is the current World Cup winner\" - 7 Tokens\n", + "sports (1.0)\n", + "world cup (1.0)\n" + ] + } + ], + "source": [ + "sentence = Sentence('France is the current World Cup winner', labels=['sports', 'world cup'])\n", + "print(sentence)\n", + "for label in sentence.labels:\n", + " print(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.models import SequenceTagger" + ] + }, + { + 
"cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "tagger = SequenceTagger.load('ner')" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('George Washington went to Washington .')" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"George Washington went to Washington .\" - 6 Tokens]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tagger.predict(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "George Washington went to Washington .\n" + ] + } + ], + "source": [ + "print(sentence.to_tagged_string())" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PER-span [1,2]: \"George Washington\"\n", + "LOC-span [5]: \"Washington\"\n" + ] + } + ], + "source": [ + "for entity in sentence.get_spans('ner'):\n", + " print(entity)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'text': 'George Washington went to Washington .', 'labels': [], 'entities': [{'text': 'George Washington', 'start_pos': 0, 'end_pos': 17, 'type': 'PER', 'confidence': 0.999337375164032}, {'text': 'Washington', 'start_pos': 26, 'end_pos': 36, 'type': 'LOC', 'confidence': 0.9998500347137451}]}\n" + ] + } + ], + "source": [ + "print(sentence.to_dict(tag_type='ner'))" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "tagger = SequenceTagger.load('frame')" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + 
"metadata": {}, + "outputs": [], + "source": [ + "sentence_1 = Sentence('George returned to Berlin to return his hat .')\n", + "sentence_2 = Sentence('He had a look at different hats .')" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"George returned to Berlin to return his hat .\" - 9 Tokens]" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tagger.predict(sentence_1)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"He had a look at different hats .\" - 8 Tokens]" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tagger.predict(sentence_2)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "George returned to Berlin to return his hat .\n", + "He had a look at different hats .\n" + ] + } + ], + "source": [ + "print(sentence_1.to_tagged_string())\n", + "print(sentence_2.to_tagged_string())" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "text = 'This is a sentence. This is another sentence. 
I love Berlin.'" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "from segtok.segmenter import split_single" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "sentences = [Sentence(sent, use_tokenizer=True) for sent in split_single(text)]" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"This is a sentence .\" - 5 Tokens,\n", + " Sentence: \"This is another sentence .\" - 5 Tokens,\n", + " Sentence: \"I love Berlin .\" - 4 Tokens]" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"This is a sentence .\" - 5 Tokens,\n", + " Sentence: \"This is another sentence .\" - 5 Tokens,\n", + " Sentence: \"I love Berlin .\" - 4 Tokens]" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tagger: SequenceTagger = SequenceTagger.load('ner')\n", + "tagger.predict(sentences)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.embeddings import WordEmbeddings\n", + "\n", + "# init embedding\n", + "glove_embedding = WordEmbeddings('glove')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('The grass is green .')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"The grass is green .\" - 5 Tokens]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "glove_embedding.embed(sentence)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: 1 The\n", + "tensor([-0.0382, -0.2449, 0.7281, -0.3996, 0.0832, 0.0440, -0.3914, 0.3344,\n", + " -0.5755, 0.0875, 0.2879, -0.0673, 0.3091, -0.2638, -0.1323, -0.2076,\n", + " 0.3340, -0.3385, -0.3174, -0.4834, 0.1464, -0.3730, 0.3458, 0.0520,\n", + " 0.4495, -0.4697, 0.0263, -0.5415, -0.1552, -0.1411, -0.0397, 0.2828,\n", + " 0.1439, 0.2346, -0.3102, 0.0862, 0.2040, 0.5262, 0.1716, -0.0824,\n", + " -0.7179, -0.4153, 0.2033, -0.1276, 0.4137, 0.5519, 0.5791, -0.3348,\n", + " -0.3656, -0.5486, -0.0629, 0.2658, 0.3020, 0.9977, -0.8048, -3.0243,\n", + " 0.0125, -0.3694, 2.2167, 0.7220, -0.2498, 0.9214, 0.0345, 0.4674,\n", + " 1.1079, -0.1936, -0.0746, 0.2335, -0.0521, -0.2204, 0.0572, -0.1581,\n", + " -0.3080, -0.4162, 0.3797, 0.1501, -0.5321, -0.2055, -1.2526, 0.0716,\n", + " 0.7056, 0.4974, -0.4206, 0.2615, -1.5380, -0.3022, -0.0734, -0.2831,\n", + " 0.3710, -0.2522, 0.0162, -0.0171, -0.3898, 0.8742, -0.7257, -0.5106,\n", + " -0.5203, -0.1459, 0.8278, 0.2706])\n", + "Token: 2 grass\n", + "tensor([-0.8135, 0.9404, -0.2405, -0.1350, 0.0557, 0.3363, 0.0802, -0.1015,\n", + " -0.5478, -0.3537, 0.0734, 0.2587, 0.1987, -0.1433, 0.2507, 0.4281,\n", + " 0.1950, 0.5346, 0.7424, 0.0578, -0.3178, 0.9436, 0.8145, -0.0824,\n", + " 0.6166, 0.7284, -0.3262, -1.3641, 0.1232, 0.5373, -0.5123, 0.0246,\n", + " 1.0822, -0.2296, 0.6039, 0.5541, -0.9610, 0.4803, 0.0022, 0.5591,\n", + " -0.1637, -0.8468, 0.0741, -0.6216, 0.0260, -0.5162, -0.0525, -0.1418,\n", + " -0.0161, -0.4972, -0.5534, -0.4037, 0.5096, 1.0276, -0.0840, -1.1179,\n", + " 0.3226, 0.4928, 0.9488, 0.2040, 0.5388, 0.8397, -0.0689, 0.3136,\n", + " 1.0450, -0.2267, -0.0896, -0.6427, 0.6443, -1.1001, -0.0096, 0.2668,\n", + " -0.3230, -0.6065, 0.0479, -0.1664, 0.8571, 0.2335, 0.2539, 1.2546,\n", + " 0.5472, -0.1980, -0.7186, 0.2076, -0.2587, -0.3650, 0.0834, 0.6932,\n", + " 
0.1574, 1.0931, 0.0913, -1.3773, -0.2717, 0.7071, 0.1872, -0.3307,\n", + " -0.2836, 0.1030, 1.2228, 0.8374])\n", + "Token: 3 is\n", + "tensor([-0.5426, 0.4148, 1.0322, -0.4024, 0.4669, 0.2182, -0.0749, 0.4733,\n", + " 0.0810, -0.2208, -0.1281, -0.1144, 0.5089, 0.1157, 0.0282, -0.3628,\n", + " 0.4382, 0.0475, 0.2028, 0.4986, -0.1007, 0.1327, 0.1697, 0.1165,\n", + " 0.3135, 0.2571, 0.0928, -0.5683, -0.5297, -0.0515, -0.6733, 0.9253,\n", + " 0.2693, 0.2273, 0.6636, 0.2622, 0.1972, 0.2609, 0.1877, -0.3454,\n", + " -0.4263, 0.1398, 0.5634, -0.5691, 0.1240, -0.1289, 0.7248, -0.2610,\n", + " -0.2631, -0.4360, 0.0789, -0.8415, 0.5160, 1.3997, -0.7646, -3.1453,\n", + " -0.2920, -0.3125, 1.5129, 0.5243, 0.2146, 0.4245, -0.0884, -0.1780,\n", + " 1.1876, 0.1058, 0.7657, 0.2191, 0.3582, -0.1164, 0.0933, -0.6248,\n", + " -0.2190, 0.2180, 0.7406, -0.4374, 0.1434, 0.1472, -1.1605, -0.0505,\n", + " 0.1268, -0.0144, -0.9868, -0.0913, -1.2054, -0.1197, 0.0478, -0.5400,\n", + " 0.5246, -0.7096, -0.3253, -0.1346, -0.4131, 0.3343, -0.0072, 0.3225,\n", + " -0.0442, -1.2969, 0.7622, 0.4635])\n", + "Token: 4 green\n", + "tensor([-6.7907e-01, 3.4908e-01, -2.3984e-01, -9.9652e-01, 7.3782e-01,\n", + " -6.5911e-04, 2.8010e-01, 1.7287e-02, -3.6063e-01, 3.6955e-02,\n", + " -4.0395e-01, 2.4092e-02, 2.8958e-01, 4.0497e-01, 6.9992e-01,\n", + " 2.5269e-01, 8.0350e-01, 4.9370e-02, 1.5562e-01, -6.3286e-03,\n", + " -2.9414e-01, 1.4728e-01, 1.8977e-01, -5.1791e-01, 3.6986e-01,\n", + " 7.4582e-01, 8.2689e-02, -7.2601e-01, -4.0939e-01, -9.7822e-02,\n", + " -1.4096e-01, 7.1121e-01, 6.1933e-01, -2.5014e-01, 4.2250e-01,\n", + " 4.8458e-01, -5.1915e-01, 7.7125e-01, 3.6685e-01, 4.9652e-01,\n", + " -4.1298e-02, -1.4683e+00, 2.0038e-01, 1.8591e-01, 4.9860e-02,\n", + " -1.7523e-01, -3.5528e-01, 9.4153e-01, -1.1898e-01, -5.1903e-01,\n", + " -1.1887e-02, -3.9186e-01, -1.7479e-01, 9.3451e-01, -5.8931e-01,\n", + " -2.7701e+00, 3.4522e-01, 8.6533e-01, 1.0808e+00, -1.0291e-01,\n", + " -9.1220e-02, 5.5092e-01, 
-3.9473e-01, 5.3676e-01, 1.0383e+00,\n", + " -4.0658e-01, 2.4590e-01, -2.6797e-01, -2.6036e-01, -1.4151e-01,\n", + " -1.2022e-01, 1.6234e-01, -7.4320e-01, -6.4728e-01, 4.7133e-02,\n", + " 5.1642e-01, 1.9898e-01, 2.3919e-01, 1.2550e-01, 2.2471e-01,\n", + " 8.2613e-01, 7.8328e-02, -5.7020e-01, 2.3934e-02, -1.5410e-01,\n", + " -2.5739e-01, 4.1262e-01, -4.6967e-01, 8.7914e-01, 7.2629e-01,\n", + " 5.3862e-02, -1.1575e+00, -4.7835e-01, 2.0139e-01, -1.0051e+00,\n", + " 1.1515e-01, -9.6609e-01, 1.2960e-01, 1.8388e-01, -3.0383e-02])\n", + "Token: 5 .\n", + "tensor([-0.3398, 0.2094, 0.4635, -0.6479, -0.3838, 0.0380, 0.1713, 0.1598,\n", + " 0.4662, -0.0192, 0.4148, -0.3435, 0.2687, 0.0446, 0.4213, -0.4103,\n", + " 0.1546, 0.0222, -0.6465, 0.2526, 0.0431, -0.1945, 0.4652, 0.4565,\n", + " 0.6859, 0.0913, 0.2188, -0.7035, 0.1679, -0.3508, -0.1263, 0.6638,\n", + " -0.2582, 0.0365, -0.1361, 0.4025, 0.1429, 0.3813, -0.1228, -0.4589,\n", + " -0.2528, -0.3043, -0.1121, -0.2618, -0.2248, -0.4455, 0.2991, -0.8561,\n", + " -0.1450, -0.4909, 0.0083, -0.1749, 0.2752, 1.4401, -0.2124, -2.8435,\n", + " -0.2796, -0.4572, 1.6386, 0.7881, -0.5526, 0.6500, 0.0864, 0.3901,\n", + " 1.0632, -0.3538, 0.4833, 0.3460, 0.8417, 0.0987, -0.2421, -0.2705,\n", + " 0.0453, -0.4015, 0.1139, 0.0062, 0.0367, 0.0185, -1.0213, -0.2081,\n", + " 0.6407, -0.0688, -0.5864, 0.3348, -1.1432, -0.1148, -0.2509, -0.4591,\n", + " -0.0968, -0.1795, -0.0634, -0.6741, -0.0689, 0.5360, -0.8777, 0.3180,\n", + " -0.3924, -0.2339, 0.4730, -0.0288])\n" + ] + } + ], + "source": [ + "for token in sentence:\n", + " print(token)\n", + " print(token.embedding)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Sentence: \"The grass is green .\" - 5 Tokens]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from flair.embeddings import CharacterEmbeddings\n", + "\n", + "# init embedding\n", + 
"embedding = CharacterEmbeddings()\n", + "\n", + "# create a sentence\n", + "sentence = Sentence('The grass is green .')\n", + "\n", + "# embed words in sentence\n", + "embedding.embed(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.embeddings import WordEmbeddings, CharacterEmbeddings\n", + "\n", + "glove_embedding = WordEmbeddings('glove')\n", + "character_embeddings = CharacterEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.embeddings import StackedEmbeddings\n", + "\n", + "stacked_embeddings = StackedEmbeddings(\n", + " embeddings = [glove_embedding, character_embeddings])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('The grass is green .')\n", + "stacked_embeddings.embed(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "StackedEmbeddings(\n", + " (list_embedding_0): WordEmbeddings()\n", + " (list_embedding_1): CharacterEmbeddings(\n", + " (char_embedding): Embedding(275, 25)\n", + " (char_rnn): LSTM(25, 25, bidirectional=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stacked_embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: 1 The\n", + "tensor([-3.8194e-02, -2.4487e-01, 7.2812e-01, -3.9961e-01, 8.3172e-02,\n", + " 4.3953e-02, -3.9141e-01, 3.3440e-01, -5.7545e-01, 8.7459e-02,\n", + " 2.8787e-01, -6.7310e-02, 3.0906e-01, -2.6384e-01, -1.3231e-01,\n", + " -2.0757e-01, 3.3395e-01, -3.3848e-01, -3.1743e-01, -4.8336e-01,\n", + " 1.4640e-01, -3.7304e-01, 3.4577e-01, 5.2041e-02, 
4.4946e-01,\n", + " -4.6971e-01, 2.6280e-02, -5.4155e-01, -1.5518e-01, -1.4107e-01,\n", + " -3.9722e-02, 2.8277e-01, 1.4393e-01, 2.3464e-01, -3.1021e-01,\n", + " 8.6173e-02, 2.0397e-01, 5.2624e-01, 1.7164e-01, -8.2378e-02,\n", + " -7.1787e-01, -4.1531e-01, 2.0335e-01, -1.2763e-01, 4.1367e-01,\n", + " 5.5187e-01, 5.7908e-01, -3.3477e-01, -3.6559e-01, -5.4857e-01,\n", + " -6.2892e-02, 2.6584e-01, 3.0205e-01, 9.9775e-01, -8.0481e-01,\n", + " -3.0243e+00, 1.2540e-02, -3.6942e-01, 2.2167e+00, 7.2201e-01,\n", + " -2.4978e-01, 9.2136e-01, 3.4514e-02, 4.6745e-01, 1.1079e+00,\n", + " -1.9358e-01, -7.4575e-02, 2.3353e-01, -5.2062e-02, -2.2044e-01,\n", + " 5.7162e-02, -1.5806e-01, -3.0798e-01, -4.1625e-01, 3.7972e-01,\n", + " 1.5006e-01, -5.3212e-01, -2.0550e-01, -1.2526e+00, 7.1624e-02,\n", + " 7.0565e-01, 4.9744e-01, -4.2063e-01, 2.6148e-01, -1.5380e+00,\n", + " -3.0223e-01, -7.3438e-02, -2.8312e-01, 3.7104e-01, -2.5217e-01,\n", + " 1.6215e-02, -1.7099e-02, -3.8984e-01, 8.7424e-01, -7.2569e-01,\n", + " -5.1058e-01, -5.2028e-01, -1.4590e-01, 8.2780e-01, 2.7062e-01,\n", + " 6.3108e-03, -2.5734e-01, -1.0127e-02, -3.2274e-02, 2.8811e-03,\n", + " -5.4744e-03, 1.3090e-01, -5.9536e-02, 1.6519e-01, -4.3329e-02,\n", + " -1.3551e-01, -3.7349e-02, 1.3456e-01, 2.2238e-01, -1.4624e-01,\n", + " 2.5576e-01, -3.7204e-02, 4.0126e-01, 2.4248e-01, 1.0761e-01,\n", + " -1.1448e-01, -7.3020e-02, -2.2720e-01, 2.5039e-02, -2.8237e-01,\n", + " 1.4390e-01, -1.1083e-01, -1.3574e-01, -4.3571e-02, -8.2276e-02,\n", + " 3.1797e-01, 1.0562e-01, 1.8005e-01, -2.1153e-01, -1.6619e-01,\n", + " -6.8136e-02, 2.4039e-01, -3.3454e-02, -2.0433e-01, 1.0134e-01,\n", + " -2.1497e-01, 1.5355e-02, 2.0674e-01, 1.4811e-01, -9.1302e-02,\n", + " 4.0148e-02, 8.6612e-02, 4.1856e-02, 8.1967e-02, -7.4910e-02],\n", + " grad_fn=)\n", + "Token: 2 grass\n", + "tensor([-0.8135, 0.9404, -0.2405, -0.1350, 0.0557, 0.3363, 0.0802, -0.1015,\n", + " -0.5478, -0.3537, 0.0734, 0.2587, 0.1987, -0.1433, 0.2507, 0.4281,\n", + " 0.1950, 
0.5346, 0.7424, 0.0578, -0.3178, 0.9436, 0.8145, -0.0824,\n", + " 0.6166, 0.7284, -0.3262, -1.3641, 0.1232, 0.5373, -0.5123, 0.0246,\n", + " 1.0822, -0.2296, 0.6039, 0.5541, -0.9610, 0.4803, 0.0022, 0.5591,\n", + " -0.1637, -0.8468, 0.0741, -0.6216, 0.0260, -0.5162, -0.0525, -0.1418,\n", + " -0.0161, -0.4972, -0.5534, -0.4037, 0.5096, 1.0276, -0.0840, -1.1179,\n", + " 0.3226, 0.4928, 0.9488, 0.2040, 0.5388, 0.8397, -0.0689, 0.3136,\n", + " 1.0450, -0.2267, -0.0896, -0.6427, 0.6443, -1.1001, -0.0096, 0.2668,\n", + " -0.3230, -0.6065, 0.0479, -0.1664, 0.8571, 0.2335, 0.2539, 1.2546,\n", + " 0.5472, -0.1980, -0.7186, 0.2076, -0.2587, -0.3650, 0.0834, 0.6932,\n", + " 0.1574, 1.0931, 0.0913, -1.3773, -0.2717, 0.7071, 0.1872, -0.3307,\n", + " -0.2836, 0.1030, 1.2228, 0.8374, 0.1004, 0.0290, 0.2366, 0.1697,\n", + " 0.1663, 0.1168, 0.1768, 0.2029, 0.2458, -0.2917, -0.2440, 0.2163,\n", + " 0.1219, -0.1865, -0.0176, -0.1864, 0.1176, 0.1054, 0.1579, -0.1860,\n", + " -0.2466, -0.1175, 0.0732, -0.2293, 0.1627, 0.0272, -0.0785, 0.0360,\n", + " -0.0057, 0.0218, -0.0729, 0.1934, 0.0903, -0.0927, -0.4069, 0.0892,\n", + " -0.0540, 0.1659, 0.0860, -0.0584, -0.2017, 0.0455, -0.0908, 0.1252,\n", + " -0.0151, 0.0822, -0.1524, -0.0566, -0.3361, 0.0536],\n", + " grad_fn=)\n", + "Token: 3 is\n", + "tensor([-0.5426, 0.4148, 1.0322, -0.4024, 0.4669, 0.2182, -0.0749, 0.4733,\n", + " 0.0810, -0.2208, -0.1281, -0.1144, 0.5089, 0.1157, 0.0282, -0.3628,\n", + " 0.4382, 0.0475, 0.2028, 0.4986, -0.1007, 0.1327, 0.1697, 0.1165,\n", + " 0.3135, 0.2571, 0.0928, -0.5683, -0.5297, -0.0515, -0.6733, 0.9253,\n", + " 0.2693, 0.2273, 0.6636, 0.2622, 0.1972, 0.2609, 0.1877, -0.3454,\n", + " -0.4263, 0.1398, 0.5634, -0.5691, 0.1240, -0.1289, 0.7248, -0.2610,\n", + " -0.2631, -0.4360, 0.0789, -0.8415, 0.5160, 1.3997, -0.7646, -3.1453,\n", + " -0.2920, -0.3125, 1.5129, 0.5243, 0.2146, 0.4245, -0.0884, -0.1780,\n", + " 1.1876, 0.1058, 0.7657, 0.2191, 0.3582, -0.1164, 0.0933, -0.6248,\n", + " -0.2190, 0.2180, 
0.7406, -0.4374, 0.1434, 0.1472, -1.1605, -0.0505,\n", + " 0.1268, -0.0144, -0.9868, -0.0913, -1.2054, -0.1197, 0.0478, -0.5400,\n", + " 0.5246, -0.7096, -0.3253, -0.1346, -0.4131, 0.3343, -0.0072, 0.3225,\n", + " -0.0442, -1.2969, 0.7622, 0.4635, -0.0254, -0.0266, 0.2035, 0.1153,\n", + " -0.0407, 0.1062, 0.0477, 0.1164, 0.2812, -0.2613, -0.2390, 0.2604,\n", + " 0.0625, -0.1660, -0.0306, -0.1705, 0.1613, 0.1041, 0.1519, -0.1656,\n", + " -0.2456, -0.0969, 0.1303, -0.0885, 0.1226, 0.0272, -0.0785, 0.0360,\n", + " -0.0057, 0.0218, -0.0729, 0.1934, 0.0903, -0.0927, -0.4069, 0.0892,\n", + " -0.0540, 0.1659, 0.0860, -0.0584, -0.2017, 0.0455, -0.0908, 0.1252,\n", + " -0.0151, 0.0822, -0.1524, -0.0566, -0.3361, 0.0536],\n", + " grad_fn=)\n", + "Token: 4 green\n", + "tensor([-6.7907e-01, 3.4908e-01, -2.3984e-01, -9.9652e-01, 7.3782e-01,\n", + " -6.5911e-04, 2.8010e-01, 1.7287e-02, -3.6063e-01, 3.6955e-02,\n", + " -4.0395e-01, 2.4092e-02, 2.8958e-01, 4.0497e-01, 6.9992e-01,\n", + " 2.5269e-01, 8.0350e-01, 4.9370e-02, 1.5562e-01, -6.3286e-03,\n", + " -2.9414e-01, 1.4728e-01, 1.8977e-01, -5.1791e-01, 3.6986e-01,\n", + " 7.4582e-01, 8.2689e-02, -7.2601e-01, -4.0939e-01, -9.7822e-02,\n", + " -1.4096e-01, 7.1121e-01, 6.1933e-01, -2.5014e-01, 4.2250e-01,\n", + " 4.8458e-01, -5.1915e-01, 7.7125e-01, 3.6685e-01, 4.9652e-01,\n", + " -4.1298e-02, -1.4683e+00, 2.0038e-01, 1.8591e-01, 4.9860e-02,\n", + " -1.7523e-01, -3.5528e-01, 9.4153e-01, -1.1898e-01, -5.1903e-01,\n", + " -1.1887e-02, -3.9186e-01, -1.7479e-01, 9.3451e-01, -5.8931e-01,\n", + " -2.7701e+00, 3.4522e-01, 8.6533e-01, 1.0808e+00, -1.0291e-01,\n", + " -9.1220e-02, 5.5092e-01, -3.9473e-01, 5.3676e-01, 1.0383e+00,\n", + " -4.0658e-01, 2.4590e-01, -2.6797e-01, -2.6036e-01, -1.4151e-01,\n", + " -1.2022e-01, 1.6234e-01, -7.4320e-01, -6.4728e-01, 4.7133e-02,\n", + " 5.1642e-01, 1.9898e-01, 2.3919e-01, 1.2550e-01, 2.2471e-01,\n", + " 8.2613e-01, 7.8328e-02, -5.7020e-01, 2.3934e-02, -1.5410e-01,\n", + " -2.5739e-01, 4.1262e-01, 
-4.6967e-01, 8.7914e-01, 7.2629e-01,\n", + " 5.3862e-02, -1.1575e+00, -4.7835e-01, 2.0139e-01, -1.0051e+00,\n", + " 1.1515e-01, -9.6609e-01, 1.2960e-01, 1.8388e-01, -3.0383e-02,\n", + " 2.3410e-01, -1.4150e-01, -1.4317e-01, -7.9950e-02, 2.2265e-01,\n", + " -3.1271e-02, 2.4928e-01, 9.5457e-02, 7.0562e-03, 8.6135e-02,\n", + " 1.3798e-01, -6.3350e-02, 9.9218e-02, -4.5819e-03, 1.9424e-01,\n", + " 3.0682e-01, 7.8153e-03, 1.2644e-01, 7.8239e-02, 9.1541e-02,\n", + " -3.2165e-02, -9.5144e-02, -1.1466e-01, -3.8280e-02, -7.9813e-02,\n", + " 7.5818e-03, 1.6530e-01, -7.5781e-02, -5.4557e-02, -7.6738e-02,\n", + " -4.6856e-02, -1.0195e-01, 9.9022e-02, 2.4027e-01, 1.0468e-02,\n", + " 1.9845e-01, -2.1230e-02, 7.1300e-02, 1.7585e-02, -9.3911e-03,\n", + " -9.7738e-02, 1.1224e-01, 3.7499e-02, -2.0135e-01, 8.5252e-02,\n", + " 6.1836e-02, -3.2621e-02, 1.1995e-02, -2.0415e-01, -2.8720e-02],\n", + " grad_fn=)\n", + "Token: 5 .\n", + "tensor([-3.3979e-01, 2.0941e-01, 4.6348e-01, -6.4792e-01, -3.8377e-01,\n", + " 3.8034e-02, 1.7127e-01, 1.5978e-01, 4.6619e-01, -1.9169e-02,\n", + " 4.1479e-01, -3.4349e-01, 2.6872e-01, 4.4640e-02, 4.2131e-01,\n", + " -4.1032e-01, 1.5459e-01, 2.2239e-02, -6.4653e-01, 2.5256e-01,\n", + " 4.3136e-02, -1.9445e-01, 4.6516e-01, 4.5651e-01, 6.8588e-01,\n", + " 9.1295e-02, 2.1875e-01, -7.0351e-01, 1.6785e-01, -3.5079e-01,\n", + " -1.2634e-01, 6.6384e-01, -2.5820e-01, 3.6542e-02, -1.3605e-01,\n", + " 4.0253e-01, 1.4289e-01, 3.8132e-01, -1.2283e-01, -4.5886e-01,\n", + " -2.5282e-01, -3.0432e-01, -1.1215e-01, -2.6182e-01, -2.2482e-01,\n", + " -4.4554e-01, 2.9910e-01, -8.5612e-01, -1.4503e-01, -4.9086e-01,\n", + " 8.2973e-03, -1.7491e-01, 2.7524e-01, 1.4401e+00, -2.1239e-01,\n", + " -2.8435e+00, -2.7958e-01, -4.5722e-01, 1.6386e+00, 7.8808e-01,\n", + " -5.5262e-01, 6.5000e-01, 8.6426e-02, 3.9012e-01, 1.0632e+00,\n", + " -3.5379e-01, 4.8328e-01, 3.4600e-01, 8.4174e-01, 9.8707e-02,\n", + " -2.4213e-01, -2.7053e-01, 4.5287e-02, -4.0147e-01, 1.1395e-01,\n", + " 6.2226e-03, 
3.6673e-02, 1.8518e-02, -1.0213e+00, -2.0806e-01,\n", + " 6.4072e-01, -6.8763e-02, -5.8635e-01, 3.3476e-01, -1.1432e+00,\n", + " -1.1480e-01, -2.5091e-01, -4.5907e-01, -9.6819e-02, -1.7946e-01,\n", + " -6.3351e-02, -6.7412e-01, -6.8895e-02, 5.3604e-01, -8.7773e-01,\n", + " 3.1802e-01, -3.9242e-01, -2.3394e-01, 4.7298e-01, -2.8803e-02,\n", + " 8.2464e-02, -1.7575e-01, -1.4336e-01, 3.9867e-03, -1.4155e-01,\n", + " -7.6877e-03, 1.0880e-03, -6.3159e-02, -7.6448e-02, 8.3365e-02,\n", + " 3.6257e-03, 7.6893e-03, 1.4932e-02, -3.5098e-03, 2.7587e-02,\n", + " -3.3187e-02, -6.8181e-03, 1.6592e-01, 2.3646e-02, 1.7029e-01,\n", + " -4.5547e-02, -6.0603e-02, 1.0320e-01, -8.0149e-02, -1.5537e-01,\n", + " 9.6964e-02, -1.3416e-01, -2.1076e-01, -1.3461e-01, 8.6052e-02,\n", + " -1.5016e-01, 1.9833e-01, -1.9856e-02, -1.9699e-01, 3.4966e-02,\n", + " -4.7196e-02, 8.6259e-02, 1.0409e-01, -7.2638e-02, 2.0218e-01,\n", + " -5.5694e-02, 7.0337e-02, 1.3896e-01, 1.0324e-01, -8.2287e-02,\n", + " -7.9263e-02, -6.4011e-02, 2.1714e-03, 5.0975e-02, -1.6845e-02],\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " grad_fn=)\n" + ] + } + ], + "source": [ + "for token in sentence:\n", + " print(token)\n", + " print(token.embedding)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\SusanLi\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\ipykernel_launcher.py:4: DeprecationWarning: Call to deprecated method __init__. (Use 'FlairEmbeddings' instead.) -- Deprecated since version 0.4.\n", + " after removing the cwd from sys.path.\n", + "C:\\Users\\SusanLi\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\ipykernel_launcher.py:5: DeprecationWarning: Call to deprecated method __init__. (Use 'FlairEmbeddings' instead.) 
-- Deprecated since version 0.4.\n", + " \"\"\"\n" + ] + } + ], + "source": [ + "from flair.embeddings import WordEmbeddings, CharLMEmbeddings, DocumentPoolEmbeddings, Sentence\n", + "\n", + "glove_embedding = WordEmbeddings('glove')\n", + "charlm_embedding_forward = CharLMEmbeddings('news-forward')\n", + "charlm_embedding_backward = CharLMEmbeddings('news-backward')\n", + "document_embeddings = DocumentPoolEmbeddings([glove_embedding, \n", + " charlm_embedding_forward, \n", + " charlm_embedding_backward])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('The grass is green . And the sky is blue .')" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "document_embeddings.embed(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-0.3197, 0.2621, 0.4037, ..., -0.0008, -0.0051, -0.0109]])\n" + ] + } + ], + "source": [ + "print(sentence.get_embedding())" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "document_embeddings = DocumentPoolEmbeddings([glove_embedding, \n", + " charlm_embedding_backward,\n", + " charlm_embedding_forward],\n", + " mode = 'min')" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings\n", + "\n", + "glove_embedding = WordEmbeddings('glove')\n", + "document_embeddings = DocumentLSTMEmbeddings([glove_embedding])" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = Sentence('The grass is green . 
And the sky is blue .')\n", + "document_embeddings.embed(sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0000, -0.2567, -0.3857, 0.0000, 0.0000, 0.4679, -0.0000, -0.0000,\n", + " -0.0000, 0.0413, 0.3378, -0.0000, -0.0000, -0.0000, 0.6527, -0.6511,\n", + " 1.0144, -0.1377, 0.5243, -0.5654, 0.0000, -0.0236, 0.1107, 0.0000,\n", + " -0.7132, -0.5130, -0.3489, -0.5734, 0.7072, 0.1158, -0.3548, 0.0000,\n", + " 0.0000, -0.1011, 0.0743, 0.5346, 0.2456, 0.3685, 0.0000, 0.1319,\n", + " -0.6749, -0.0000, 0.0000, -0.3798, 0.4302, 0.0000, 0.1881, 0.4432,\n", + " -0.0000, 0.6083, -0.2418, 0.5634, -0.7348, 0.7113, -0.3781, -0.4040,\n", + " 0.7722, -0.6238, 0.8772, 0.0000, 0.5456, 0.4980, 0.0000, 0.1653,\n", + " -0.0000, 0.0553, -0.8303, 0.5382, -0.0000, 0.0000, 0.1737, -0.2544,\n", + " -1.0751, 0.0816, 0.0000, -0.6108, 0.0000, 0.7551, -0.0000, -0.0000,\n", + " -0.0000, 0.0000, -0.2756, 0.0173, 0.0000, -0.0000, 0.0904, 0.0000,\n", + " 0.3185, -0.0000, 0.0000, -0.0000, -0.0000, -0.0000, 0.1771, -0.4003,\n", + " 0.0000, 0.0000, -0.6380, -0.3645, -0.0000, 0.0000, 0.0000, 0.0000,\n", + " 0.5596, 0.0000, -0.0000, -0.1360, 0.2858, -0.0000, -0.6948, -0.0000,\n", + " -1.0255, -0.1839, -0.5161, -0.0000, -0.0000, -0.0791, -0.3432, 0.5404,\n", + " 0.3125, 0.0000, 0.0000, -0.0419, 0.0000, -0.0000, 0.6848, 0.0000]],\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "print(sentence.get_embedding())" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'NLPTask' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m 
\u001b[1;32mfrom\u001b[0m \u001b[0mflair\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata_fetcher\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mNLPTaskDataFetcher\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mcorpus\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mNLPTaskDataFetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_corpus\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNLPTask\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mUD_ENGLISH\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;31mNameError\u001b[0m: name 'NLPTask' is not defined" + ] + } + ], + "source": [ + "from flair.data import TaggedCorpus\n", + "from flair.data_fetcher import NLPTask, NLPTaskDataFetcher\n", + "# NLPTask provides the dataset identifier passed to load_corpus.\n", + "corpus = NLPTaskDataFetcher.load_corpus(NLPTask.UD_ENGLISH)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "storage has wrong size: expected -1862414276 got 22700", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mflair\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mTextClassifier\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mflair\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mSentence\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mclassifier\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mTextClassifier\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'en-sentiment'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0msentence\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mSentence\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Flair is pretty neat!'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mclassifier\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msentence\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\flair\\models\\text_classification_model.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(model)\u001b[0m\n\u001b[0;32m 277\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 278\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmodel_file\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 279\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mTextClassifier\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_from_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\flair\\models\\text_classification_model.py\u001b[0m in \u001b[0;36mload_from_file\u001b[1;34m(cls, model_file)\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;32mreturn\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mloaded\u001b[0m \u001b[0mtext\u001b[0m \u001b[0mclassifier\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 105\u001b[0m \"\"\"\n\u001b[1;32m--> 106\u001b[1;33m \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mTextClassifier\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_load_state\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 107\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 108\u001b[0m model = TextClassifier(\n", + "\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\flair\\models\\text_classification_model.py\u001b[0m in \u001b[0;36m_load_state\u001b[1;34m(cls, model_file)\u001b[0m\n\u001b[0;32m 144\u001b[0m \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 145\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 146\u001b[1;33m \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m'cuda:0'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m'cpu'\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 147\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mstate\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 148\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module)\u001b[0m\n\u001b[0;32m 365\u001b[0m \u001b[0mf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;34m'rb'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 367\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 368\u001b[0m \u001b[1;32mfinally\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 369\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mnew_fd\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m_load\u001b[1;34m(f, map_location, pickle_module)\u001b[0m\n\u001b[0;32m 543\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdeserialized_storage_keys\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 544\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdeserialized_objects\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 545\u001b[1;33m \u001b[0mdeserialized_objects\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_set_from_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moffset\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf_should_read_directly\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 546\u001b[0m \u001b[0moffset\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 547\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mRuntimeError\u001b[0m: storage has wrong size: expected -1862414276 got 22700" + ] + } + ], + "source": [ + "from 
flair.models import TextClassifier\n", + "from flair.data import Sentence\n", + "classifier = TextClassifier.load('en-sentiment')\n", + "sentence = Sentence('Flair is pretty neat!')\n", + "classifier.predict(sentence)\n", + "# print sentence with predicted labels\n", + "print('Sentence above is: ', sentence.labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}